【DKN】(六)KCNN.py
生活随笔
收集整理的這篇文章主要介紹了
【DKN】(六)KCNN.py
小編覺得挺不錯的,現在分享給大家,幫大家做個參考。
內容
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.general.attention.additive import AdditiveAttention

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class KCNN(torch.nn.Module):
    """Knowledge-aware CNN (KCNN) based on Kim CNN.

    Input a news sentence (e.g. its title), produce its embedding vector.

    The title is encoded as a stack of aligned channels: word embeddings,
    entity embeddings projected into word-embedding space, and (when
    ``config.use_context`` is true) context embeddings projected the same
    way. Convolutions with several window sizes run over that stack, and
    each window size's feature map is reduced with additive attention.
    """

    def __init__(self, config, pretrained_word_embedding,
                 pretrained_entity_embedding, pretrained_context_embedding):
        """Build the embedding tables, entity transform and convolutions.

        Args:
            config: configuration object; this class reads ``num_words``,
                ``word_embedding_dim``, ``num_entities``,
                ``entity_embedding_dim``, ``use_context``, ``num_filters``,
                ``window_sizes`` and ``query_vector_dim`` from it.
            pretrained_word_embedding: tensor used to initialise the word
                table (kept trainable), or ``None`` for a fresh table.
            pretrained_entity_embedding: same, for the entity table.
            pretrained_context_embedding: same, for the context table; only
                used when ``config.use_context`` is true.
        """
        super().__init__()
        self.config = config
        # Creation order below is kept identical to the original so random
        # initialisation and state_dict key order do not change.
        self.word_embedding = self._make_embedding(
            pretrained_word_embedding, config.num_words,
            config.word_embedding_dim)
        self.entity_embedding = self._make_embedding(
            pretrained_entity_embedding, config.num_entities,
            config.entity_embedding_dim)
        if config.use_context:
            # Context vectors are looked up with the same entity ids, so the
            # table uses the entity vocabulary size and dimension.
            self.context_embedding = self._make_embedding(
                pretrained_context_embedding, config.num_entities,
                config.entity_embedding_dim)
        # Learnable projection tanh(v @ M + b) mapping entity-space vectors
        # into word-embedding space, so they can be stacked as extra
        # channels alongside the word vectors.
        self.transform_matrix = nn.Parameter(
            torch.empty(config.entity_embedding_dim,
                        config.word_embedding_dim).uniform_(-0.1, 0.1))
        self.transform_bias = nn.Parameter(
            torch.empty(config.word_embedding_dim).uniform_(-0.1, 0.1))
        num_channels = 3 if config.use_context else 2
        # One convolution per window size. The kernel spans the full word
        # embedding width, so each filter slides only along the title.
        self.conv_filters = nn.ModuleDict({
            str(x): nn.Conv2d(num_channels, config.num_filters,
                              (x, config.word_embedding_dim))
            for x in config.window_sizes
        })
        self.additive_attention = AdditiveAttention(config.query_vector_dim,
                                                    config.num_filters)

    @staticmethod
    def _make_embedding(pretrained, num_embeddings, embedding_dim):
        """Return a trainable embedding table whose padding index is 0.

        Wraps ``pretrained`` (not frozen) when given; otherwise creates a new
        ``num_embeddings`` x ``embedding_dim`` table.
        """
        if pretrained is None:
            return nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        return nn.Embedding.from_pretrained(pretrained, freeze=False,
                                            padding_idx=0)

    def _transform(self, entity_space_vector):
        """Project entity-space vectors into word space: tanh(v @ M + b)."""
        return torch.tanh(
            torch.matmul(entity_space_vector, self.transform_matrix)
            + self.transform_bias)

    def forward(self, news):
        """Encode a batch of news titles.

        Args:
            news:
                {
                    "title": batch_size * num_words_title (word ids),
                    "title_entities": batch_size * num_words_title
                        (entity ids)
                }
        Returns:
            final_vector: batch_size, len(window_sizes) * num_filters
        """
        # Send inputs to the device the model's parameters live on, not the
        # module-level ``device``: the two differ when the model is kept on
        # CPU on a CUDA-capable machine or placed on a different GPU.
        param_device = self.word_embedding.weight.device
        title = news["title"].to(param_device)
        title_entities = news["title_entities"].to(param_device)

        # Each channel: batch_size, num_words_title, word_embedding_dim
        channels = [
            self.word_embedding(title),
            self._transform(self.entity_embedding(title_entities)),
        ]
        if self.config.use_context:
            channels.append(
                self._transform(self.context_embedding(title_entities)))
        # batch_size, 2 or 3, num_words_title, word_embedding_dim
        multi_channel_vector = torch.stack(channels, dim=1)

        pooled_vectors = []
        for x in self.config.window_sizes:
            # Conv output is batch_size, num_filters, num_words_title + 1 - x, 1
            # (kernel width == embedding width); drop the trailing 1.
            convoluted = self.conv_filters[str(x)](
                multi_channel_vector).squeeze(dim=3)
            activated = F.relu(convoluted)
            # batch_size, num_filters
            # Additive attention is used here instead of the pooling in the
            # paper; it receives batch_size, positions, num_filters.
            pooled = self.additive_attention(activated.transpose(1, 2))
            pooled_vectors.append(pooled)
        # batch_size, len(window_sizes) * num_filters
        return torch.cat(pooled_vectors, dim=1)
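最後的卷積有必要說一下:對 window_sizes 中的每個窗口大小 x,都有一個對應的 Conv2d。它的輸入通道數為 2(單詞向量、轉換後的實體向量),在 use_context 為真時為 3(再加上轉換後的上下文向量);卷積核大小為 (x, word_embedding_dim)。由於卷積核寬度等於整個詞向量維度,輸出形狀為 batch_size × num_filters × (num_words_title + 1 − x) × 1,因此先用 squeeze(dim=3) 去掉最後一維,再經過 ReLU。論文在這裡使用池化,而此實現改用加性注意力(AdditiveAttention),把每個窗口大小的特徵彙聚成 batch_size × num_filters 的向量。最後把所有窗口大小的結果在 dim=1 上拼接,得到 batch_size × (len(window_sizes) × num_filters) 的新聞表示。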
總結
以上是生活随笔為你收集整理的【DKN】(六)KCNN.py 的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 测试服务器性能常用算法,服务器性能剖析(
- 下一篇: AcWing之从尾到头打印链表