【论文精读】TransE 及其实现
TransE 及其實現
1. What is TransE?
TransE (Translating Embedding), an energy-based model for learning low-dimensional embeddings of entities.
核心思想:將 relationship 視為一個在 embedding space 的 translation。如果 (h, l, t) 存在,那么 h+l≈th + l \approx th+l≈t。
Motivation:一是在 Knowledge Base 中,層次化的關系是非常常見的,translation 是一種很自然的用來表示它們的變換;二是近期一些從 text 中學習 word embedding 的研究發現,一些不同類型的實體之間的 1-to-1 的 relationship 可以被 model 表示為在 embedding space 中的一種 translation。
2. Learning TransE
TransE 的訓練算法如下:
2.1 輸入參數
- training set SSS:用于訓練的三元組的集合,entity 的集合為 EEE,rel. 的集合為 LLL
- margin γ\gammaγ:損失函數中的間隔,這個在原 paper 中描述很模糊
- 每個 entity 或 rel. 的 embedding dim kkk
2.2 訓練過程
初始化:對每一個 entity 和 rel. 的 embedding vector 用 xavier_uniform 分布來初始化,然后對它們實施 L1 or L2 正則化。
loop:
- 在 entity embedding 被更新前進行一次歸一化,這是通過人為增加 embedding 的 norm 來防止 loss 在訓練過程中極小化。
- sample 出一個 mini-batch 的正樣本集合 SbatchS_{batch}Sbatch?
- 將 TbatchT_{batch}Tbatch? 初始化為空集,它表示本次 loop 用于訓練 model 的數據集
- for (h,l,t)∈Sbatch(h,l,t) \in S_{batch}(h,l,t)∈Sbatch? do:
- 根據 (h, l, t) 構造出一個錯誤的三元組 (h′,l,t′)(h', l, t')(h′,l,t′)
- 將 positive sample (h,l,t)(h,l,t)(h,l,t) 和 negative sample (h′,l,t′)(h',l,t')(h′,l,t′) 加入到 TbatchT_{batch}Tbatch? 中
- 計算 TbatchT_{batch}Tbatch? 每一對 positive sample 和 negative sample 的 loss,然后累加起來用于更新 embedding matrix。每一對的 loss 計算方式為:loss=[γ+d(h+l,t)?d(h′+l,t′)]+loss = [\gamma + d(h+l,t) - d(h'+l,t')]_+loss=[γ+d(h+l,t)?d(h′+l,t′)]+?
這個過程中,triplet 的 energy 就是指的 d(h+l,t)d(h+l,t)d(h+l,t),它衡量了 h+lh+lh+l 與 ttt 的距離,可以采用 L1 或 L2 norm,即 ∣∣h+r?t∣∣||h + r - t||∣∣h+r?t∣∣ 具體計算方式可見代碼實現。
loss 的計算中,[x]+=max?(0,x)[x]_+ = \max(0,x)[x]+?=max(0,x)。
關于 margin γ\gammaγ 的含義, 它相當于是一個正確 triple 與錯誤 triple 之前的間隔修正,margin 越大,則兩個 triple 之前被修正的間隔就越大,則對于 embedding 的修正就越嚴格。我們看 loss=[γ+d(h+l,t)?d(h′+l,t′)]+loss = [\gamma + d(h+l,t) - d(h'+l,t')]_+loss=[γ+d(h+l,t)?d(h′+l,t′)]+?,我們希望是 d(h+l,t)d(h+l,t)d(h+l,t) 越小越好,d(h′+l,t′)d(h'+l,t')d(h′+l,t′) 越大越好,假設 d(h+l,t)d(h+l,t)d(h+l,t) 處于理想情況下等于 0,那么由于 γ\gammaγ 的存在,d(h′+l,t′)d(h'+l,t')d(h′+l,t′) 如果不是很大的話,仍然會產生 loss,只有當 d(h′+l,t′)d(h'+l,t')d(h′+l,t′) 大于 γ\gammaγ 時才會讓 loss = 0,所以 γ\gammaγ 越大,對 embedding 的修正就越嚴格。
錯誤三元組的構造方法:將 (h,l,t)(h,l,t)(h,l,t) 中的頭實體、關系和尾實體其中之一隨機替換為其他實體或關系來得到。
2.3 評價指標
鏈接預測是用來預測三元組 (h,r,t) 中缺失實體 h, t 或 r 的任務,對于每一個缺失的實體,模型將被要求用所有的知識圖譜中的實體作為候選項進行計算,并進行排名,而不是單純給出一個最優的預測結果。
首先對于每個 testing triple,以預測 tail entity 為例,我們將 (h,r,t)(h,r,t)(h,r,t) 中的 t 用 KG 中的每個 entity 來代替,然后通過 fr(h,t)f_r(h,t)fr?(h,t) 來計算分數,這樣就可以得到一系列的分數,然后將這些分數排列。我們知道 f 函數值越小越好,那么在前面的排列中,排地越靠前越好。重點來了,我們去看每個 testing triple 中正確答案(也就是真實的 t)在上述序列中排多少位,比如 t1t_1t1? 排 100,t2t_2t2? 排 200,t3t_3t3? 排 60 …,之后對這些排名求平均,就得到 mean rank 值了。
還是按照上述進行 f 函數值排列,然后看每個 testing triple 正確答案是否排在序列的前十,如果在的話就計數 +1,最終 (排在前十的個數) / (總個數) 就等于 Hits@10。
在原論文中,由于這個 model 比較老了,其 baseline 也沒啥參考性,就不做研究了,具體的實驗可參考論文。
3. TransE 優缺點
優點:與以往模型相比,TransE 模型參數較少,計算復雜度低,卻能直接建立實體和關系之間的復雜語義聯系,在 WordNet 和 Freebase 等 dataset 上較以往模型的 performance 有了顯著提升,特別是在大規模稀疏 KG 上,TransE 的性能尤其驚人。
缺點:在處理復雜關系(1-N、N-1 和 N-N)時,性能顯著降低,這與 TransE 的模型假設有密切關系。假設有 (美國,總統,奧巴馬)和(美國,總統,布什),這里的“總統”關系是典型的 1-N 的復雜關系,如果用 TransE 對其進行學習,則會有:
那么這將會使奧巴馬和布什的 vector 變得相同。所以由于這些復雜關系的存在,導致 TransE 學習得到的實體表示區分性較低。
4. TransE 實現
這里選擇用 pytorch 來實現 TransE 模型。
4.1 __init__ 函數
其參數有:
- ent_num:entity 的數量
- rel_num:relationship 的數量
- dim:每個 embedding vector 的維度
- norm:在計算 d(h+l,t)d(h+l,t)d(h+l,t) 時是使用 L1 norm 還是 L2 norm,即 d(h+l,t)=∣∣h+l?t∣∣L1orL2d(h+l,t)=||h+l-t||_{L1 \ or \ L2}d(h+l,t)=∣∣h+l?t∣∣L1?or?L2?
- margin:損失函數中的間隔,是個 hyper-parameter
- α\alphaα:損失函數計算中的正則化項參數
初始化 embedding matrix 時,直接用 nn.Embedding 來完成,參數分別是 entity 的數量和每個 embedding vector 的維數,這樣得到的就是一個 ent_num * dim 大小的 Embedding Matrix。
torch.nn.init.xavier_uniform_ 是一個服從均勻分布的 Glorot 初始化器,在這里做的就是對 Embedding Matrix 中每個位置填充一個 xavier_uniform 初始化的值,這些值從均勻分布 U(?a,a)U(-a,a)U(?a,a) 中采樣得到,這里的 aaa 是:
a=gain×6fan_in+fan_outa = gain \times \sqrt{\frac{6}{fan\_in + fan\_out}}a=gain×fan_in+fan_out6??
在這里,對于 Embedding 這樣的二維矩陣來說,fan_in 和 fan_out 就是矩陣的長和寬,gain 默認為 1。其完整具體行為可參考 pytorch 初始化器文檔。
F.normalize(self.ent_embeddings.weight.data, 2, 1) 這一步就是對 ent_embeddings 的每一個值除以 dim = 1 上的 2 范數值,注意 ent_embeddings.weight.data 的 size 是 (ent_num, embs_dim)。具體來說就是這一步把每行都除以該行下所有元素平方和的開方,也就是 l←l/∣∣l∣∣l \leftarrow l / ||l||l←l/∣∣l∣∣。
損失函數這里先跳過,之后計算損失的步驟一同來看。
4.2 從 ent_idx 到 ent_embs
由于 network 的輸入是 ent_idx,因此需要將其根據 embedding matrix 轉換成 ent_embs。我們通過 get_ent_resps 函數來完成,其實就是個靜態查表的操作:
class TransE(nn.Module):...def get_ent_resps(self, ent_idx): #[batch]return self.ent_embeddings(ent_idx) # [batch, emb]4.3 計算 energy d(h+l,t)d(h+l, t)d(h+l,t)
它衡量了 h+lh+lh+l 與 ttt 的距離,可以采用 L1 或 L2 norm 來算,具體采用哪個由 __init__ 函數中的 self.norm 來決定:
class TransE(nn.Module):...def distance(self, h_idx, r_idx, t_idx):h_embs = self.ent_embeddings(h_idx) # [batch, emb]r_embs = self.rel_embeddings(r_idx) # [batch, emb]t_embs = self.ent_embeddings(t_idx) # [batch, emb]scores = h_embs + r_embs - t_embs# norm 是計算 loss 時的正則化項norms = (torch.mean(h_embs.norm(p=self.norm, dim=1) - 1.0)+ torch.mean(r_embs ** 2) +torch.mean(t_embs.norm(p=self.norm, dim=1) - 1.0)) / 3return scores.norm(p=self.norm, dim=1), norms4.4 計算 loss
self.criterion 是通過實例化 MarginRankingLoss 得到的,這個類的初始化接收 margin 參數,實例化得到 self.criterion,其計算方式如下:
criterion(x1,x2,y)=max?(0,?y×(x1?x2)+margin)criterion(x_1,x_2,y) = \max(0, -y \times (x_1 - x_2) + margin)criterion(x1?,x2?,y)=max(0,?y×(x1??x2?)+margin)
借助于此,我們可以實現計算 loss 的代碼:
class TransE(nn.Module):...def loss(self, positive_distances, negative_distances):target = torch.tensor([-1], dtype=torch.float, device=self.device)return self.criterion(positive_distances, negative_distances, target)positive_distances 就是 d(h+l,t)d(h+l,t)d(h+l,t),negative_distances 就是 d(h′+l,t′)d(h'+l, t')d(h′+l,t′),target = [-1],代入 criterion 的計算公式就是我們計算 一對正樣本和負樣本的 loss 了。
4.5 forward
class TransE(nn.Module):...def forward(self, ph_idx, pr_idx, pt_idx, nh_idx, nr_idx, nt_idx):pos_distances, pos_norms = self.scoring(ph_idx, pr_idx, pt_idx)neg_distances, neg_norms = self.scoring(nh_idx, nr_idx, nt_idx)tmp_loss = self.loss(pos_distances, neg_distances)tmp_loss += self.alpha * pos_norms # 正則化項tmp_loss += self.alpha * neg_norms # 正則化項return tmp_loss, pos_distances, neg_distances以上我們講完了 TransE 模型的定義,接下來就是講對 TransE 模型的訓練了,只要理解了 TransE 模型的定義,其訓練應該不是難事。
總結
以上是生活随笔為你收集整理的【论文精读】TransE 及其实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 简单了解TransE
- 下一篇: TransE 论文笔记