超细节!从源代码剖析Self-Attention知识点
?PaperWeekly 原創 ·?作者|海晨威
學校|同濟大學碩士生
研究方向|自然語言處理
在當前的 NLP 領域,Transformer / BERT 已然成為基礎應用,而 Self-Attention? 則是兩者的核心部分,下面嘗試用 Q&A 和源碼的形式深入 Self-Attention 的細節。
Q&A
1. Self-Attention 的核心是什么?
Self-Attention 的核心是用文本中的其它詞來增強目標詞的語義表示,從而更好的利用上下文的信息。
2. Self-Attention 的時間復雜度是怎么計算的?
Self-Attention 時間復雜度:,這里,n 是序列的長度,d 是 embedding 的維度,不考慮 batch 維。
Self-Attention 包括三個步驟:相似度計算,softmax 和加權平均。
它們分別的時間復雜度是:
相似度計算 可以看作大小為 和 的兩個矩陣相乘:,得到一個 的矩陣。
softmax 就是直接計算了,時間復雜度為 。
加權平均 可以看作大小為 和 的兩個矩陣相乘:,得到一個 的矩陣。
因此,Self-Attention 的時間復雜度是 。
這里再提一下 Tansformer 中的 Multi-Head Attention,多頭 Attention,簡單來說就是多個 Self-Attention 的組合,它的作用類似于 CNN 中的多核。
多頭的實現不是循環的計算每個頭,而是通過 transposes and reshapes,用矩陣乘法來完成的。
In practice, the multi-headed attention are done with transposes and reshapes rather than actual separate tensors. —— 來自 google BERT 源代碼注釋
Transformer/BERT 中把 d ,也就是 hidden_size/embedding_size 這個維度做了 reshape 拆分,可以去看 Google 的 TF 源碼或者上面的 pytorch 源碼:
hidden_size (d) = num_attention_heads (m) * attention_head_size (a),也即 d=m*a。
并將 num_attention_heads 維度 transpose 到前面,使得 Q 和 K 的維度都是 (m,n,a),這里不考慮 batch 維度。
這樣點積可以看作大小為 (m,n,a) 和 (m,a,n) 的兩個張量相乘,得到一個 (m,n,n) 的矩陣,其實就相當于 m 個頭,時間復雜度是 。
張量乘法時間復雜度分析參見:矩陣、張量乘法的時間復雜度分析 [1]。
因此 Multi-Head Attention 時間復雜度就是 ,而實際上,張量乘法可以加速,因此實際復雜度會更低一些。
3. 不考慮多頭的原因,self-attention中詞向量不乘QKV參數矩陣,會怎么樣?
對于 Attention 機制,都可以用統一的 query/key/value 模式去解釋,而對于? self-attention,一般會說它的 q=k=v,這里的相等實際上是指它們來自同一個基礎向量,而在實際計算時,它們是不一樣的,因為這三者都是乘了 QKV 參數矩陣的。那如果不乘,每個詞對應的 q,k,v 就是完全一樣的。
在 self-attention 中,sequence 中的每個詞都會和 sequence 中的每個詞做點積去計算相似度,也包括這個詞本身。
在相同量級的情況下,qi 與 ki 點積的值會是最大的(可以從“兩數和相同的情況下,兩數相等對應的積最大”類比過來)。
那在 softmax 后的加權平均中,該詞本身所占的比重將會是最大的,使得其他詞的比重很少,無法有效利用上下文信息來增強當前詞的語義表示。
而乘以 QKV 參數矩陣,會使得每個詞的 q,k,v 都不一樣,能很大程度上減輕上述的影響。
當然,QKV 參數矩陣也使得多頭,類似于 CNN 中的多核,去捕捉更豐富的特征/信息成為可能。
4. 在常規 attention 中,一般有 k=v,那 self-attention 可以嘛?
self-attention 實際只是 attention 中的一種特殊情況,因此 k=v 是沒有問題的,也即 K,V 參數矩陣相同。
擴展到 Multi-Head Attention 中,乘以 Q、K 參數矩陣之后,其實就已經保證了多頭之間的差異性了,在 q 和 k 點積 +softmax 得到相似度之后,從常規 attention 的角度,覺得再去乘以和 k 相等的 v 會更合理一些。
在 Transformer / BERT 中,完全獨立的 QKV 參數矩陣,可以擴大模型的容量和表達能力。
但采用 Q,K=V 這樣的參數模式,我認為也是沒有問題的,也能減少模型的參數,又不影響多頭的實現。
當然,上述想法并沒有做過實驗,為個人觀點,僅供參考。
源碼
在整個 Transformer / BERT 的代碼中,(Multi-Head Scaled Dot-Product) Self-Attention 的部分是相對最復雜的,也是 Transformer / BERT 的精髓所在,這里給出 Pytorch 版本的實現 [2],并對重要的代碼加上了注釋和維度說明。
話不多說,都在代碼里,它主要有三個部分:
初始化:包括有幾個頭,每個頭的大小,并初始化 QKV 三個參數矩陣。
class?SelfAttention(nn.Module):def?__init__(self,?config):super(SelfAttention,?self).__init__()if?config.hidden_size?%?config.num_attention_heads?!=?0:raise?ValueError("The?hidden?size?(%d)?is?not?a?multiple?of?the?number?of?attention?""heads?(%d)"?%?(config.hidden_size,?config.num_attention_heads))#?在Transformer/BERT中,這里的?all_head_size?就等于?config.hidden_size#?應該是一種簡化,為了從embedding到最后輸出維度都保持一致#?這樣使得多個attention頭合起來維度還是config.hidden_size#?而?attention_head_size?就是每個attention頭的維度,要保證可以整除self.num_attention_heads?=?config.num_attention_headsself.attention_head_size?=?int(config.hidden_size?/?config.num_attention_heads)self.all_head_size?=?self.num_attention_heads?*?self.attention_head_size#?三個參數矩陣self.query?=?nn.Linear(config.hidden_size,?self.all_head_size)self.key?=?nn.Linear(config.hidden_size,?self.all_head_size)self.value?=?nn.Linear(config.hidden_size,?self.all_head_size)self.dropout?=?nn.Dropout(config.attention_probs_dropout_prob)transposes and reshapes:這個函數主要是把維度大小為 [batch_size * seq_length * hidden_size] 的 q,k,v 向量變換成 [batch_size * num_attention_heads * seq_length * attention_head_size],便于后面做 Multi-Head Attention。
????def?transpose_for_scores(self,?x):"""shape?of?x:?batch_size?*?seq_length?*?hidden_size這個操作是把hidden_size分解為?self.num_attention_heads?*?self.attention_head_size然后再交換?seq_length?維度?和?num_attention_heads?維度為什么要做這一步:因為attention是要對query中的每個字和key中的每個字做點積,即是在 seq_length 維度上query和key的點積是?[seq_length?*?attention_head_size]?*?[attention_head_size?*?seq_length]=[seq_length?*?seq_length]"""#?這里是一個維度拼接:(1,2)+(3,4)?->?(1, 2, 3, 4)new_x_shape?=?x.size()[:-1]?+?(self.num_attention_heads,?self.attention_head_size)x?=?x.view(*new_x_shape)return?x.permute(0,?2,?1,?3)前向計算: 乘以 QKV 參數矩陣 —> transposes and reshapes —> 做 scaled —> 加 attention mask —> Softmax —> 加權平均 —> 維度恢復。
?def?forward(self,?hidden_states,?attention_mask):#?shape?of?hidden_states?and?mixed_*_layer:?batch_size?*?seq_length?*?hidden_sizemixed_query_layer?=?self.query(hidden_states)mixed_key_layer?=?self.key(hidden_states)mixed_value_layer?=?self.value(hidden_states)#?shape?of?*_layer:?batch_size?*?num_attention_heads?*?seq_length?*?attention_head_sizequery_layer?=?self.transpose_for_scores(mixed_query_layer)key_layer?=?self.transpose_for_scores(mixed_key_layer)value_layer?=?self.transpose_for_scores(mixed_value_layer)#?Take?the?dot?product?between?"query"?and?"key"?to?get?the?raw?attention?scores.#?shape?of?attention_scores:?batch_size?*?num_attention_heads?*?seq_length?*?seq_lengthattention_scores?=?torch.matmul(query_layer,?key_layer.transpose(-1,?-2))#?這里就是做?Scaled,將方差統一到1,避免維度的影響attention_scores?/=?math.sqrt(self.attention_head_size)#?shape?of?attention_mask:?batch_size?*?1?*?1?*?seq_length.?它可以自動廣播到和attention_scores一樣的維度#?我們初始輸入的attention_mask是:batch_size * seq_length,做了兩次unsqueeze之后得到當前的attention_maskattention_scores?=?attention_scores?+?attention_mask#?Normalize?the?attention?scores?to?probabilities.?Softmax?不改變維度#?shape?of?attention_scores:?batch_size?*?num_attention_heads?*?seq_length?*?seq_lengthattention_probs?=?nn.Softmax(dim=-1)(attention_scores)attention_probs?=?self.dropout(attention_probs)#?shape?of?value_layer:?batch_size?*?num_attention_heads?*?seq_length?*?attention_head_size#?shape?of?first?context_layer:?batch_size?*?num_attention_heads?*?seq_length?*?attention_head_size#?shape?of?second?context_layer:?batch_size?*?seq_length?*?num_attention_heads?*?attention_head_size# context_layer 維度恢復到:batch_size * seq_length * hidden_sizecontext_layer?=?torch.matmul(attention_probs,?value_layer)context_layer?=?context_layer.permute(0,?2,?1,?3).contiguous()new_context_layer_shape?=?context_layer.size()[:-2]?+?(self.all_head_size,)context_layer?=?context_layer.view(*new_context_layer_shape)return?context_layerAttention is all you need ! 希望這篇文章能讓你對 Self-Attention 有更深的理解。
參考文獻
[1]https://liwt31.github.io/2018/10/12/mul-complexity/
[2]https://github.com/hichenway/CodeShare/tree/master/bert_pytorch_source_code
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的超细节!从源代码剖析Self-Attention知识点的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 起底商汤校招需求TOP 10岗位 | 智
- 下一篇: win7卡安装更新失败怎么办啊 win7