[NLP] 相对位置编码(二) Relative Positional Encodings - Transformer-XL
1. Motivation
在Transformer-XL中,由于設(shè)計(jì)了segments,如果仍采用transformer模型中的絕對(duì)位置編碼的話,將不能區(qū)分處不同segments內(nèi)同樣相對(duì)位置的詞的先后順序。
比如對(duì)于$segment_i$的第k個(gè)token,和$segment_j$的第k個(gè)token的絕對(duì)位置編碼是完全相同的。
鑒于這樣的問題,transformer-XL中采用了相對(duì)位置編碼。
2. Relative Positional Encodings
paper中,由對(duì)絕對(duì)位置編碼變換推導(dǎo)出新的相對(duì)位置編碼方式。
vanilla Transformer中的絕對(duì)位置編碼
它對(duì)每個(gè)index的token都通過sin/cos變換,為其唯一指定了一個(gè)位置編碼。該位置編碼將與input的embedding求sum之后作為transformer的input。
那么如果將該位置編碼應(yīng)用在transformer-xl會(huì)怎樣呢?
其中$\tau$表示第$\tau$個(gè)segment,?是當(dāng)前segment的序列$s_{\tau}$的word embedding sequence, $L$是序列長(zhǎng),$d$是每個(gè)word embedding的維度。$U_{1:L}$表示該segment中每個(gè)token的絕對(duì)位置編碼組成的序列。
可以看到對(duì)于$h_{\tau + 1}$和$h_{\tau}$,其在位置編碼表示是完全相同的,都是$U_{1:L}$,這樣就會(huì)造成motivation中所述的無法區(qū)分在不同segments中相對(duì)位置相同的tokens.
3. Transformer-XL中的相對(duì)位置編碼
transformer-xl中沒有采用vanilla transformer中的將位置編碼靜態(tài)地與embedding結(jié)合的方式;而是沿用了shaw et al.2018的相對(duì)位置編碼中通過將位置信息注入到求Attention score的過程中,即將相對(duì)位置信息編碼入hidden state中。
為什么要這么做呢?paper中給出的解釋是:
1) 位置編碼在概念上講,是為模型提供了時(shí)間線索或者說是關(guān)于如何收集信息的"bias"。出于同樣的目的,除了可以在初始的embedding中加入這樣的統(tǒng)計(jì)上的bias, 也可以在計(jì)算每層的Attention score時(shí)加入同樣的信息。
2) 以相對(duì)而非絕對(duì)的方式定義時(shí)間偏差更為直觀和通用。比如對(duì)于一個(gè)query vector $q_{\tau,i}$ 與 key vectors $k_{\tau, \leq i}$做attention時(shí),這個(gè)query 并不需要知道每一個(gè)key vector在序列中的絕對(duì)的位置來決定segment的時(shí)序。它只需要知道每一對(duì)$k_{\tau,j}$ 和其本身$q_{\tau,i}$的相對(duì)距離(比如,i - j)就足夠。
因此,在實(shí)際中可以創(chuàng)建一個(gè)相對(duì)位置編碼的encodings矩陣 $R \in \mathbb{R} ^ {L_{max} \times d}$,其中第i行 $R_i$表示兩個(gè)pos(比如位置pos_q, pos_k)之間的相對(duì)距離為i. (可以參考我在參考鏈接3中的介紹,以下圖示便是一個(gè)簡(jiǎn)單的說明例子.
但是圖示中的i表示query的位置pos, 與$R_i$ 中的i不同。如果以該圖示為例,當(dāng)pos_q = i, pos_k = i - 4時(shí), 相對(duì)位置為 0, 二者的相對(duì)位置編碼是 $R_0$。
--------------------------------------------------------------------------------------------------
Transformer-XL的相對(duì)位置編碼方式是對(duì)Shaw et al.,2018 和 Huang et al.2018提出模型的改進(jìn)。它由采用絕對(duì)編碼計(jì)算Attention score的表達(dá)式出發(fā),進(jìn)行了改進(jìn)3項(xiàng)改變。
若采用絕對(duì)位置編碼,hidden state的表達(dá)式為:
,
那么對(duì)應(yīng)的query,key的attention score表達(dá)式為:
(應(yīng)用乘法分配率, query的embedding 分別與 key的embedding, positional encoding相乘相加;之后 query的positional encoding分別與 key的embedding, positional encoding相乘相加)
(其中i是query的位置index,j是key的位置index) (WE, WU是對(duì)embedding進(jìn)行l(wèi)inear projection的表示,細(xì)節(jié)內(nèi)容可以參看attention is all you need 中對(duì)multi-head attention的介紹)
,
Transformer-XL 對(duì)上式進(jìn)行了改進(jìn):
?
改進(jìn)1) $Uj \rightarrow R_{i - j}$.
首先將 $A_{i, j} ^ {abs}$ 中的key vector的絕對(duì)位置編碼 $U_j$ 替換為了相對(duì)位置編碼 $R_{i - j}$ 其中 $R$是一個(gè)沒有需要學(xué)習(xí)的參數(shù)的sinusoid encoding matrix,如同Vaswani et al., 2017提出的一樣。
該改進(jìn)既可以避免不同segments之間由于tokens在各自segment的index相同而產(chǎn)生的時(shí)序沖突的問題。
改進(jìn)2)? $(c) : U_i^{T} W_q ^ {T} \rightarrow?{\color{red} u}? \in \mathbb{R}^d$;$(d) : U_i^{T} W_q ^ {T} \rightarrow {\color{red} v} \in \mathbb{R}^d$
在改進(jìn)1中將key的絕對(duì)位置編碼轉(zhuǎn)換為相對(duì)位置編碼,在改進(jìn)2中則對(duì)query的絕對(duì)位置編碼進(jìn)行了替換。因?yàn)闊o論query在序列中的絕對(duì)位置如何,其相對(duì)于自身的相對(duì)位置都是一樣的。這說明attention bias的計(jì)算與query在序列中的絕對(duì)位置無關(guān),應(yīng)當(dāng)保持不變. 所以這里將$A_{i, j} ^ {abs}$ 中的c,d項(xiàng)中的$U_i^{T} W_q ^ {T}$分別用一個(gè)可學(xué)習(xí)參數(shù)$u \in \mathbb{R}^d$,$v \in \mathbb{R}^d$替換。
改進(jìn)3) $W_{k} \rightarrow W_{k, E}$, $W_{k, R}$
在vanilla transformer模型中,對(duì)query, key分別進(jìn)行線性映射時(shí),query 對(duì)應(yīng)$W_q$矩陣,key對(duì)應(yīng)$W_k$矩陣,由于input 是 embedding 與 positional encoding的相加,也就相當(dāng)于
$query_{embedding} W_q + query_{pos encoding} W_q$得到query的線性映射后的表征;
$key_{embedding} W_q + key_{pos encoding} W_q$ 得到key的線性映射后的表征。
可以看出,在vanilla transformer中對(duì)于embedding和positional encoding都是采用的同樣的線性變換。
在改進(jìn)3中,則將key的embedding和positional encoding 分別采用了不同的線性變換。其中$W_{k,E}$對(duì)應(yīng)于key的embedding線性映射矩陣,$W_{k,R}$對(duì)應(yīng)與key的positional encoding的線性映射矩陣。
在這樣的參數(shù)化定義后,每一項(xiàng)都有了一個(gè)直觀上的表征含義,(a)表示基于內(nèi)容content的表征,(b)表示基于content的位置偏置,(c)表示全局的content的偏置,(d)表示全局的位置偏置。
與shaw的RPR的對(duì)比
shaw的RPR可以參考我在參考鏈接3中的介紹。這里給出論文中的表達(dá)式:其中$a_{i,j}$是query i, key j的相對(duì)位置編碼矩陣$A$中的對(duì)應(yīng)編碼。
attention score: (在key的表征中加入相對(duì)位置信息)
softmax計(jì)算權(quán)值系數(shù):
attention score * (value + 的output:(在value的表征中加入相對(duì)位置信息)
1) 對(duì)于$e_{ij}$可以用乘法分配率拆解來看,那么其相當(dāng)于transforerm-xl中的(a)(b)兩項(xiàng)。也就是在shaw的模型中未考慮加入(c)(d)項(xiàng)的全局內(nèi)容偏置和全局位置偏置。
2) 還是拆解$e_{ij}$來看,涉及到一項(xiàng)為$x_iW^Q(a_{ij}^K)^T$,是直接用 query的線性映射后的表征 與 相對(duì)位置編碼相乘;而在transformer-xl中,則是與query的線性映射后的表征 與 相對(duì)位置編碼也進(jìn)行線性映射后的表征 相乘。
優(yōu)勢(shì):
paper中指出,shaw et al用單一的相對(duì)位置編碼矩陣 與 transformer-xl中的$W_kR$相比,丟失掉了在原始的?sinusoid positional?encoding (Vaswani et al., 2017)中的歸納偏置。而XL中的這種表征方式則可以更好地利用sinusoid 的inductive bias。
----------------------------為什么XL中的這種表征方式則可以更好地利用sinusoid 的inductive bias?--------------------------------------------------------------------
有幾個(gè)問題:原始的?sinusoid positional?encoding (Vaswani et al., 2017)中的歸納偏置是什么呢?為什么shaw et al 把它丟失了呢?為什么transformer-xl可以適用呢?
這里需要搞清楚:
1. 為什么在vanilla transformer中使用sinusoid?
2. shaw et al.2018中的相對(duì)位置編碼Tensor是什么?
3. transformer-xl的相對(duì)位置編碼矩陣是什么?
對(duì)于1,sinusoid函數(shù)具有并不受限于序列長(zhǎng)度仍可以較好表示位置信息的特點(diǎn)。
We chose the sinusoidal version because?it may allow the model to extrapolate to sequence lengths longer than the ones encountered?during training. ~Attention is all you need.
為什么不用學(xué)得參數(shù)而采用sinusoid函數(shù)呢?sinusoidal函數(shù)并不受限于序列長(zhǎng)度,其可以在遇到訓(xùn)練集中未出現(xiàn)過的序列長(zhǎng)度時(shí)仍能很好的“extrapolate.” (外推),這體現(xiàn)了其具有一些inductive bias。
對(duì)于2,shaw et al.2018中的相對(duì)位置編碼Tensor是兩個(gè)需要參數(shù)學(xué)習(xí)的tensor.?
相對(duì)位置編碼矩陣是設(shè)定長(zhǎng)度為 2K + 1的(K是窗口大小) ,維度為$d_a$的2個(gè)tensor(分別對(duì)應(yīng)與key的RPR和value的RPR),其第i行表示相對(duì)距離為i的query,key(或是query, value)的相對(duì)位置編碼。這兩個(gè)tensor的參數(shù)都是需要訓(xùn)練學(xué)習(xí)的。那么顯然其是受限于最大長(zhǎng)度的。在RPR中規(guī)定了截?cái)嗟拇翱诖笮?#xff0c;在遇到超出窗口大小的情況時(shí),由于直接被截?cái)喽赡軄G失信息。
對(duì)于3,transformer-xl的相對(duì)位置編碼矩陣是一個(gè)sinusoid矩陣,不需要參數(shù)學(xué)習(xí)。
在transformer-xl中雖然也是引入了相對(duì)位置編碼矩陣,但是這個(gè)矩陣不同于shaw et al.2018。該矩陣$R_{i,j}$是一個(gè)sinusoid encoding 的矩陣(sinusoid 是借鑒的vanilla transformer中的),不涉及參數(shù)的學(xué)習(xí)。
具體實(shí)現(xiàn)可以參看代碼,這里展示了pytorch版本的位置編碼的代碼:
1 class PositionalEmbedding(nn.Module): 2 def __init__(self, demb): 3 super(PositionalEmbedding, self).__init__() 4 5 self.demb = demb 6 7 inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 8 self.register_buffer('inv_freq', inv_freq) 9 10 def forward(self, pos_seq, bsz=None): 11 sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 12 pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 13 14 if bsz is not None: 15 return pos_emb[:,None,:].expand(-1, bsz, -1) 16 else: 17 return pos_emb[:,None,:]其中$demb$是embedding的維度。
sinusoid的shape:[batch_size, seq_length × (d_emb / 2)]
sin,cos concat之后,pos_emb的shape:[batch_size, seq_length × d_emb]
pos_emb[:,None,:]之后的shape:[batch_size, 1, seq_length × d_emb]
那么綜合起來看,transformer-xl的模型的hidden states表達(dá)式為:
4. 高效計(jì)算方法
在該表達(dá)式中,在計(jì)算$W_{k,R}R_{i-j}$時(shí),需要對(duì)每一對(duì)(i,j)進(jìn)行計(jì)算,時(shí)間復(fù)雜度是$O(n^2)$。paper中提出了高效的計(jì)算方法,使其降為$O(n).$
核心算法:發(fā)現(xiàn)(b)項(xiàng)組成的矩陣的行列之間的關(guān)系,構(gòu)建一個(gè)矩陣,將其按行左移,恰好是(b)項(xiàng)矩陣$B$,而所構(gòu)建的矩陣只需要$O(n)$時(shí)間。
由于相對(duì)距離(i-j)的變化范圍是[0, M + L - 1] (其中M是memory的長(zhǎng)度,L是當(dāng)前segment的長(zhǎng)度)
那么令:
那么將(b)項(xiàng)應(yīng)用與所有的(i,j)可得一個(gè)$L \times (M + L)$的矩陣 $B$: (其中q是對(duì)E經(jīng)過$W_q$映射變換后的表示)
看這些帶紅線的部分,是不是只有q的下標(biāo)不一樣!
如果我們定義$\widetilde{B}$:
對(duì)比$B$與$\widetilde{B}$發(fā)現(xiàn),將$\widetilde{B}$的第i行左移 $L - 1 - i$個(gè)單位即為$B$。而$\widetilde{B}$的計(jì)算僅涉及到兩個(gè)矩陣的相乘,因此$B$的計(jì)算也僅需要求$qQ^T$之后按行左移即可得到,時(shí)間復(fù)雜度降為$O(n)$!
同理,可以求(d)項(xiàng)的矩陣D。
?
這樣將B,D原本需要$O(n^2)$的復(fù)雜度,降為了$O(n)$.
5. 總結(jié)
Transformer-XL針對(duì)其需要對(duì)segment中相對(duì)位置的token加入位置信息的特點(diǎn),將vanilla transformer中的絕對(duì)位置編碼方式,改進(jìn)為相對(duì)位置編碼。改進(jìn)中涉及到位置編碼矩陣的替換、query全局向量替換、以及為key的相對(duì)位置編碼和embedding分別采用了不同的線性映射矩陣W。
transformer-xl與shaw et al.2018的相對(duì)編碼方式亦有區(qū)別。1. shaw et al.2018的相對(duì)編碼矩陣是一個(gè)需要學(xué)習(xí)參數(shù)的tensor,受限于相對(duì)距離的窗口長(zhǎng)度設(shè)置;而transformer-xl的相對(duì)編碼矩陣是一個(gè)無需參數(shù)學(xué)習(xí)的使用sinusoid表示的矩陣,可以更好的generalize到訓(xùn)練集中未出現(xiàn)長(zhǎng)度的長(zhǎng)序列中;2. 相比與shaw et al.2018,transformer-xl的attention score中引入了基于content的bias,和基于位置的bias。
另外在計(jì)算優(yōu)化上,transformer-xl提出了一種高效計(jì)算(b)(d)矩陣運(yùn)算的方法。通過構(gòu)造可以在$O(n)$時(shí)間內(nèi)計(jì)算的新矩陣,并將其項(xiàng)左移構(gòu)建出目標(biāo)矩陣B,D的計(jì)算方式,將時(shí)間復(fù)雜度由$O(n^2)$降為$O(n)$。
?
參考:
1.?Transformer-XL: Attentive Language Models?Beyond a Fixed-Length Context:?https://arxiv.org/pdf/1901.02860.pdf
2.?Self-Attention with Relative Position Representations (shaw et al.2018):?https://arxiv.org/pdf/1803.02155.pdf
3.?[NLP] 相對(duì)位置編碼(一) Relative Position Representatitons (RPR) - Transformer?https://www.cnblogs.com/shiyublog/p/11185625.html
?[支付寶] 感謝您的捐贈(zèng)!
That's been one of my mantras - focus and simplicity. Simple can be harder than complex: you have to work hard to get your thinking clean to make it simple. But it's worth it in the end beacuse once you get there, you can move mountains. ~ Steve Jobs
轉(zhuǎn)載于:https://www.cnblogs.com/shiyublog/p/11236212.html
創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎(jiǎng)勵(lì)來咯,堅(jiān)持創(chuàng)作打卡瓜分現(xiàn)金大獎(jiǎng)總結(jié)
以上是生活随笔為你收集整理的[NLP] 相对位置编码(二) Relative Positional Encodings - Transformer-XL的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: vue项目遇到error This li
- 下一篇: MySQL过滤相同binlog_通过Li