Transformer性能优化:运算和显存
?作者 |?王晗煒
單位 |?中科院信工所ASCII LAB
研究方向 |?自然語(yǔ)言處理
概述
Transformer [1]?在如今的深度學(xué)習(xí)領(lǐng)域有著不可或缺的地位,它被廣泛應(yīng)用于自然語(yǔ)言處理、圖像處理等領(lǐng)域,有著極大的影響力。自注意力機(jī)制作為 Transformer 模型的核心,其幾乎不含歸納偏置的特性在足量數(shù)據(jù)的基礎(chǔ)上帶給了 Transformer 模型強(qiáng)大的建模能力,也給 Transformer 帶來(lái)了一系列效率問(wèn)題:運(yùn)算和顯存的限制使得其無(wú)法在長(zhǎng)序列問(wèn)題建模上應(yīng)用。針對(duì)此問(wèn)題,許多工作對(duì) Transformer 的結(jié)構(gòu)進(jìn)行了魔改,使得其性能得到優(yōu)化,下面將對(duì)部分有代表性的工作進(jìn)行介紹。
經(jīng)典Transformer結(jié)構(gòu)
經(jīng)典的 Transformer 結(jié)構(gòu)包含 Encoder 和 Decoder 兩個(gè)部分,主要的組成組件可以分為以下三個(gè)小塊:多頭自注意力(Multi-Head Self-Attention)、前饋神經(jīng)網(wǎng)絡(luò)(Position-wise Feed-forward)以及殘差連接(Residual Connect)。
▲ Transformer模型結(jié)構(gòu)圖
2.1 Multi-Head Self-Attention
自注意力機(jī)制是一種基于縮放點(diǎn)積的注意力機(jī)制,其將原始序列的輸入向量投影至三個(gè)不同的空間,作為 query、key 和 value,每個(gè)序列中的輸入都會(huì)對(duì)整個(gè)序列進(jìn)行注意力計(jì)算,包括自身。在 Transformer 中這個(gè)過(guò)程還引入了多頭機(jī)制,即對(duì)輸入向量進(jìn)行維度的切分,在每個(gè)頭的空間中進(jìn)行同樣的操作,最后再將每個(gè)頭的信息拼接起來(lái),以便學(xué)到更豐富的信息。
2.2 Position-wise Feed-forward
在 Transformer 的每個(gè)小模塊中,多頭自注意力模塊后面都會(huì)接上一個(gè)全連接的前向神經(jīng)網(wǎng)絡(luò)模塊,其對(duì)輸入向量的維度進(jìn)行放大再縮小(一般為放大四倍再縮小,例如 BERT),以下為模塊的公式描述:
2.3 Residual Connect
為了使模型能夠堆疊更多的子模塊,完成深度網(wǎng)絡(luò)的訓(xùn)練而避免梯度消失等問(wèn)題,Transformer 內(nèi)每一個(gè)自注意力模塊和前向神經(jīng)網(wǎng)絡(luò)模塊均會(huì)伴有殘差連接模塊:
2.4 復(fù)雜度分析
通過(guò)對(duì)以上三個(gè)模塊的簡(jiǎn)單回顧,我們可以發(fā)現(xiàn) Transformer 內(nèi)主要運(yùn)算資源消耗集中在 Mulit-Head Attention 模塊和 FFN 模塊,稍加分析我們可以得知其時(shí)間復(fù)雜度和空間復(fù)雜度如下:
▲ self-attention 和 FFN 模塊復(fù)雜度分析,T 代表序列長(zhǎng)度,D 表示隱層向量的維度
從表中的分析可以看出,當(dāng)輸入序列較短的時(shí)候,模型的主要計(jì)算開銷集中在 FFN 模塊(復(fù)雜度與隱層向量維度的平方成正比),而當(dāng)輸入序列較長(zhǎng)時(shí),模型的計(jì)算開銷則會(huì)轉(zhuǎn)移至 Multi-Head Self-Attention 模塊(復(fù)雜度與序列長(zhǎng)度的平方成正比)。目前基于 Transformer 結(jié)構(gòu)的 BERT 等預(yù)訓(xùn)練模型在一定維度下已經(jīng)能夠達(dá)到較好的效果,因此 FFN 模塊的計(jì)算開銷一般能夠被承受,但是當(dāng)輸入序列過(guò)長(zhǎng)時(shí),模型則無(wú)法進(jìn)行處理,這也是目前 Transformer 架構(gòu)模型的一個(gè)通病。下面我們將介紹一些工作針對(duì)此問(wèn)題對(duì) Transformer 模型的改進(jìn)。
基于遞歸連接的改進(jìn)
使用經(jīng)典的 Transformer 模型時(shí),當(dāng)遇到輸入為長(zhǎng)序列,很多模型采用的方式都是直接將序列進(jìn)行截?cái)?#xff0c;像 Bert 這種預(yù)訓(xùn)練模型一般只保留前 512 個(gè)字符,或者其他的模型將文本劃分為多個(gè) segments,訓(xùn)練的時(shí)候?qū)γ總€(gè) segment 單獨(dú)處理,segments之間沒有聯(lián)系,這便會(huì)導(dǎo)致以下兩個(gè)問(wèn)題:
文本最長(zhǎng)語(yǔ)義依賴關(guān)系取決于 segment 的長(zhǎng)度,不同 segment 之間沒有關(guān)聯(lián)
分割出來(lái)的 segments 語(yǔ)義不完整,存在一句話分隔在兩個(gè) segment 之中的情況
針對(duì)這一問(wèn)題,Transformer-XL [2] 對(duì)原始的 Transformer 結(jié)構(gòu)做了以下兩點(diǎn)改進(jìn):
提出片段級(jí)遞歸機(jī)制,引入一個(gè)記憶模塊,循環(huán)用來(lái)建模片段之間的聯(lián)系。使得 Transformer 能夠?qū)﹂L(zhǎng)距離依賴進(jìn)行建模,解決上下文碎片化問(wèn)題
提出相對(duì)位置編碼機(jī)制,代替絕對(duì)位置編碼。避免使用記憶模塊時(shí)出現(xiàn)混淆
首先我們可以結(jié)合論文中的圖片來(lái)分析經(jīng)典的 Transformer 是如何對(duì)長(zhǎng)文本編碼的:
▲ Transformer分段訓(xùn)練和測(cè)試
通過(guò)這張圖可以清晰地看見在訓(xùn)練過(guò)程中每個(gè) segment 會(huì)分別編碼,相互之間沒有任何聯(lián)系。在測(cè)試時(shí)為了保證測(cè)試的完整性,其對(duì) segment 的劃分更為緊密,不同 segment 之間的重疊度很高,造成了運(yùn)算資源開銷巨大。
而為了解決這一問(wèn)題,Transformer-XL 引入了一個(gè) memory 狀態(tài),在對(duì)當(dāng)前 segment 進(jìn)行處理的時(shí)候,緩存并利用上一個(gè) segment 中所有 layer 的隱向量序列,綜合兩個(gè) segment 的信息進(jìn)行 attention 等操作。為了節(jié)約計(jì)算資源的開銷,上一個(gè) segment 的所有隱層向量只參與前向計(jì)算,不進(jìn)行反向傳播,具體計(jì)算過(guò)程如下:?
通過(guò)此機(jī)制,Transformer-XL 能夠?qū)⑿蛄械乃?segments 連接起來(lái),在訓(xùn)練測(cè)試過(guò)程中保持前后文的關(guān)聯(lián)。并且可以顯著地提高測(cè)試時(shí)的效率,其過(guò)程可以如下圖所示:
▲ Transformer-XL的訓(xùn)練和測(cè)試過(guò)程
在引入片段級(jí)遞歸機(jī)制后,雖然能夠建模出截?cái)嗟奈谋镜年P(guān)聯(lián)性,但是也帶來(lái)了一個(gè)額外的問(wèn)題:每個(gè) segment 都添加相同的位置編碼,多個(gè) segments 之間無(wú)法區(qū)分位置關(guān)系。針對(duì)此問(wèn)題,Transfomer-XL 放棄使用絕對(duì)位置編碼,而是采用相對(duì)位置編碼,在計(jì)算當(dāng)前位置隱向量的時(shí)候,考慮與之依賴 token 的相對(duì)位置關(guān)系。
首先將包含絕對(duì)位置編碼的 Attention 計(jì)算展開:
可以發(fā)現(xiàn)展開后的表達(dá)式與位置信息相關(guān)的僅為 ,如果以一個(gè) 的相對(duì)位置視角來(lái)看, 對(duì)于所有位置均為一個(gè)定值,變化的只是 的值,因此在使用相對(duì)位置編碼時(shí)我們可以將公式改寫為以下的形勢(shì):
其中 和 為隨機(jī)初始化的可學(xué)習(xí)參數(shù),在計(jì)算 self-attention 時(shí),由于 query 所有位置對(duì)應(yīng)的 query 向量是一樣的,因此不管的 query 位置如何,對(duì)不同單詞的 attention 偏差應(yīng)保持相同。 為計(jì)算出來(lái)的位置向量編碼,同 Transformer 里的計(jì)算方式一致。
在原論文中作者使用字符級(jí)的語(yǔ)言模型任務(wù)對(duì) Transformer-XL 進(jìn)行了評(píng)估,與當(dāng)時(shí)的一些基線模型相比取得了最優(yōu)的結(jié)果:
▲ 部分語(yǔ)言模型實(shí)驗(yàn)結(jié)果實(shí)驗(yàn)
同時(shí),如同前面分析的一致,Transformer-XL 在進(jìn)行長(zhǎng)文本推理時(shí)有著十分明顯的速度優(yōu)勢(shì),并且隨著序列長(zhǎng)度的增大愈發(fā)明顯:
▲ 推理加速同序列長(zhǎng)度的關(guān)系
總體而言,Transformer-XL 通過(guò)引入遞歸連接機(jī)制和相對(duì)位置編碼機(jī)制,使得 Transformer 模型有了能夠處理長(zhǎng)序列的能力,無(wú)論在模型效果還是性能上均有可觀的提升,但是其對(duì)于長(zhǎng)序列的解決思路是停留在分段這一個(gè)前提上的,并沒有真正得一次性處理全部長(zhǎng)度的序列,只不過(guò)通過(guò)增加一些額外的空間開銷來(lái)讓實(shí)驗(yàn)性能和效率有一定提升,也就是空間換性能和時(shí)間,從這個(gè)角度上來(lái)看 Transformer-XL 其實(shí)并不算優(yōu)化了復(fù)雜度,反而增加了復(fù)雜度,提升的速度也只是相對(duì)于截?cái)嗍降脑?Transformer 速度的提升。
基于稀疏注意力的改進(jìn)
Transformer 強(qiáng)大的性能可以說(shuō)很大程度上來(lái)源于其特殊的 Self-Attention 注意力,而其巨大的計(jì)算開銷也來(lái)源于此,可謂禍福相依,接下來(lái)的改進(jìn)均將圍繞這一機(jī)制進(jìn)行,即如何在保證模型性能的前提下盡量減少 Attention 過(guò)程的計(jì)算開銷。
4.1 Sparse Transformer
引入稀疏注意力是很多工作的解決思路,其中 Sparse Transformer [3] 是較為早期的工作,其關(guān)注長(zhǎng)序列的生成問(wèn)題,針對(duì) Transformer 因效率問(wèn)題無(wú)法應(yīng)用于長(zhǎng)序列生成的任務(wù),提出一種稀疏注意力機(jī)制來(lái)解決此問(wèn)題,最終在圖像、文本、音頻三個(gè)模態(tài)上驗(yàn)證了其效果。
首先作者為引入稀疏注意力機(jī)制做了一些實(shí)驗(yàn)進(jìn)行支撐:其構(gòu)建了一個(gè)由原始 Transformer 模塊深層全注意力網(wǎng)絡(luò),用于圖像的像素級(jí)自回歸生成任務(wù),即把圖像的像素點(diǎn)按照從上到下從左到右的方式當(dāng)成一個(gè)序列,然后在序列上去做自回歸。其在生成過(guò)程中對(duì)網(wǎng)絡(luò)的注意力情況進(jìn)行了觀察,如下圖所示,底部的黑色部分表示還沒有生成的部分,白色凸顯的部分則是注意力權(quán)重高的地方。下圖是比較低的層次的注意力,可以看到,低層次的時(shí)候主要關(guān)注的還是局部區(qū)域的部分。
▲ 淺層注意力
在網(wǎng)絡(luò)的 19 層和 20 層,注意力的分布也發(fā)生了變化,開始關(guān)注橫向和縱向的像素點(diǎn):
▲ 19層和20層的注意力
而還有一些層的注意力甚至學(xué)習(xí)到了物體本身的一些特征,例如它們的邊緣:
▲ 某些層的注意力
而我們通過(guò)對(duì)這些層注意力的觀察可以得知不同層的注意力關(guān)注的區(qū)域是不一樣的,但是無(wú)論其關(guān)注的是哪一部分的信息,注意力權(quán)重高的區(qū)域始終很小,即注意力是稀疏的,這為稀疏注意力機(jī)制的引入提供了實(shí)驗(yàn)依據(jù)。
作者在設(shè)計(jì)稀疏注意力之前先對(duì) Attention 的計(jì)算過(guò)程進(jìn)行了重新定義:
這里同原始 Attention 范式的區(qū)別主要在于引入了 Attend 項(xiàng),即當(dāng)前位置關(guān)注的范圍,這里 表示 需要關(guān)注的范圍,對(duì)于自回歸生成任務(wù)中的 Transformer,。根據(jù)這個(gè)范式,我們?cè)诙x不同的稀疏注意力時(shí)就是在定義 。
論文中定義了兩種 Attention 機(jī)制,第一種為 Strided Attention,其描述如下:
▲ Strided Attention模式圖
上面的圖代表是二維的圖像像素圖,下面的圖是序列的注意力矩陣圖。這個(gè)注意力其實(shí)就分成兩個(gè)部分,每個(gè)字符關(guān)注前面若干個(gè)字符,然后每隔相同的間距關(guān)注一個(gè)字符。對(duì)應(yīng)到圖像里面就是一個(gè)關(guān)注行,一個(gè)關(guān)注列。作者認(rèn)為這種 attention 機(jī)制比較傾向于有規(guī)律的數(shù)據(jù),比如圖像音頻,不適合文本。
另外論文中還設(shè)計(jì)了一種 Fixed Attention,描述如下:
▲ Fixed Attention模式圖
這里 c?是一個(gè)超參數(shù)。既然是針對(duì)文本的,我們看這個(gè) attention 其實(shí)可以把左邊的二維圖的每一行看成一個(gè)句子,其實(shí)就變成了每個(gè) token 關(guān)注自己句子之前的 token,然后關(guān)注之前句子特定位置的 token。相當(dāng)于一個(gè)是局部注意力,一個(gè)是全局注意力。
可以看出每一種 Attention 模式都包含了多種關(guān)注方式,因此論文給出了三種 attention 結(jié)合的方式:
不同層使用不同的注意力
將所有的注意力方式融合
不同的頭使用不同的注意力
論文直接表明第三種方式的效果最好。作者在三個(gè)模態(tài)的數(shù)據(jù)集上均進(jìn)行了自回歸任務(wù)的實(shí)驗(yàn),均達(dá)到了較好的效果,甚至在速度和性能上都超過(guò)了原始的 self-attention。
▲ 三個(gè)模態(tài)上Sparse Transformer的性能結(jié)果
▲ 文本和圖像上Sparse Transformer的運(yùn)行效率
由實(shí)驗(yàn)結(jié)果還可以看出作者設(shè)計(jì)的注意力模式確實(shí)分別使用于圖像和文本。至此為對(duì) Sparse Transformer 關(guān)于稀疏注意力部分的簡(jiǎn)單介紹,該模型提出較早,雖然效果不如后面一些同樣基于稀疏注意力的模型,但是其范式的總結(jié)對(duì)后續(xù)模型的發(fā)展有著深遠(yuǎn)的意義。
4.2 Longformer
Longformer [4] 和之前介紹的 Sparse Transformer 都可被視為同一類改動(dòng),也是設(shè)計(jì)不同的 Attention 范圍,但是 Longformer是專門為 NLP 任務(wù)設(shè)計(jì)的,因此除了評(píng)測(cè)模型本身的性能,其也在很多 NLP 下游任務(wù)上進(jìn)行了實(shí)驗(yàn)。
Longformer 設(shè)計(jì)了三種 Attention 模式,分為 Sliding Window、Dilated Sliding Window 和 Global Attention:
Sliding Window:對(duì)每個(gè) token 設(shè)置一個(gè) w 大小的窗口,每次進(jìn)行 Attention 操作時(shí)僅對(duì)前后 w/2 個(gè)位置的 token
Dilated Sliding Window:為了能夠注意到更遠(yuǎn)距離的 token,還可以對(duì)窗口進(jìn)行挖洞,每隔 d 個(gè)位置關(guān)注一個(gè)字符
Global Attention:在捕獲到局部信息的同時(shí),也不能丟失全局的信息,因此存在一些字符能夠關(guān)注到所有的字符,并且所有字符也能關(guān)注到此字符
▲ Longformer中的Attention機(jī)制
除此之外論文中給出了三個(gè)使用這些 Attention 的細(xì)節(jié):
窗口的大小會(huì)隨著模型層數(shù)的加深變大
多頭注意力中,可以另某些頭使用 Sliding window,另外一些頭使用 Dilated sliding window
對(duì)于不同的任務(wù),設(shè)計(jì)不同的位置使用 global attention
在實(shí)驗(yàn)部分,Longformer 首先和之前的模型一樣在自回歸任務(wù)上進(jìn)行了實(shí)驗(yàn):
▲ Longformer 在字符級(jí)自回歸任務(wù)上的性能
可以看出 Longformer 較之于之前的模型有一個(gè)更好的效果,而如前文所提到的,Longformer 是專門為 NLP 任務(wù)設(shè)計(jì)的 Transformer 模型,因此其在下游任務(wù)上的表現(xiàn)也很重要。因?yàn)橹恍薷?Attention 的范圍,不修改模型的結(jié)構(gòu),可以直接在預(yù)訓(xùn)練模型 Roberta 的 checkpoint 上繼續(xù)訓(xùn)練并且運(yùn)用到下游任務(wù)上,這里預(yù)訓(xùn)練的任務(wù)也是 MLM。將窗口大小設(shè)置為 512,添加了額外的 position embedding 到 4096 的大小,通過(guò)多次復(fù)制其 512 個(gè)位置嵌入來(lái)對(duì)其進(jìn)行初始化。
隨后其在三個(gè)下游任務(wù)上進(jìn)行了 Finetuning:
Question answering:問(wèn)題和候選答案使用 Global Attention
Coreference Resolution:不使用 Global Attention
Document Classification:[CLS] 使用 Global Attention
其實(shí)驗(yàn)結(jié)果如下所示:
▲ Longformer在下游任務(wù)中的效果
可以看出 Longformer 在下游的任務(wù)上的表現(xiàn)基本都超過(guò)了 Roberta,而且下游任務(wù)的數(shù)據(jù)集文本越長(zhǎng) Longformer 的優(yōu)勢(shì)更為明顯。
4.3 Big Bird
還是在這個(gè)稀疏注意力的模式上,谷歌又繼續(xù)推出了?[5],除了之前的 window attention 和 global attention,還加入了一種 random attention,讓模型不局限于人為增加的強(qiáng)先驗(yàn)。而對(duì)于 global attention 的使用,其還提出了兩種方式:
Big Bird-ITC:選擇已經(jīng)存在于序列的詞匯使用 global attention
Big Bird-ETC:添加額外的特殊字符使用 global attention
該模型沒有單獨(dú)對(duì)模型效率進(jìn)行評(píng)估,而是在大量下游任務(wù)上進(jìn)行了實(shí)驗(yàn)并取得了良好的效果:
▲ Big Bird在問(wèn)答任務(wù)上的表現(xiàn)
總體而言,這三種都可以看成是一類稀疏 attention,都是通過(guò)引入一些先驗(yàn)知識(shí)來(lái)限定 attention 的范圍,從而提升效率,看起來(lái)是直觀有效的,但是存在兩個(gè)問(wèn)題:
通用性降低:self-attention 相對(duì)于 CNN、RNN 這類模型一大優(yōu)勢(shì)就是具有更廣泛的歸納偏置,也就是能夠處理更一般化的信息,做了這些限定之后其實(shí)算是一種倒退
稀疏 Attention 計(jì)算優(yōu)化困難:這種稀疏 attention 的做法雖然看上去很簡(jiǎn)單,但是真正在實(shí)現(xiàn)的時(shí)候是沒那么方便的,前面兩種都設(shè)計(jì)到 cuda 內(nèi)核的修改,后面的 big bird 設(shè)計(jì)了一種很復(fù)雜的分塊計(jì)算方法,實(shí)際的運(yùn)算優(yōu)化度不如理論上高
4.4 Reformer
前面幾種人為設(shè)計(jì)的稀疏 Attention 模式不一定能夠滿足待處理數(shù)據(jù)的特點(diǎn),而且不具有普適性,而 Reformer [6] 使用了一種局部哈希的方式來(lái)獲取每個(gè)字符需要關(guān)注的范圍。在論文的開始,其提出了三個(gè) Transformer 中存在的效率問(wèn)題并給出了自己的解決思路:
自注意力機(jī)制消耗的資源隨著序列長(zhǎng)度的增長(zhǎng)而平方倍增長(zhǎng):局部哈希
模型占用的內(nèi)存隨模型深度呈線性增長(zhǎng):可逆殘差
Feed-Forward 部分參數(shù)過(guò)多:分塊計(jì)算
首先我們可以介紹一下局部哈希算法,如下圖所示,對(duì)于空間中的點(diǎn),先將其投影到一個(gè)圓上,然后將分成八個(gè)區(qū)域,每個(gè)區(qū)域都代表一個(gè)獨(dú)立的值。隨機(jī)轉(zhuǎn)動(dòng)圓,記錄下投影后的點(diǎn)所在區(qū)域的值;那么經(jīng)過(guò)多次轉(zhuǎn)動(dòng)后,就為一個(gè)點(diǎn)得到了多個(gè)值,這些值就是點(diǎn)的哈希值。
▲ 局部哈希示例
上圖中的上半部分是兩個(gè)不相似的點(diǎn)組成的例子,它們的哈希值差別很大。下半部分則是兩個(gè)相似的點(diǎn),它們的哈希值也是一樣的。通過(guò)這個(gè)簡(jiǎn)單的示例我們就能發(fā)現(xiàn),在空間中距離越近的點(diǎn),其哈希完之后哈希值相同的概率也會(huì)越高,這就是局部敏感哈希的原理 [7]。
隨后我們可以通過(guò)矩陣乘法隨機(jī)構(gòu)造一個(gè)哈希函數(shù):
接著把原始 Attention 公式進(jìn)行重寫:
這里 表示 mask 項(xiàng), 表示歸一化項(xiàng), 嘖表示當(dāng)前字符注意的區(qū)域,這里規(guī)定為當(dāng)前字符的 query 的哈希值和序列中 key 的哈希值相同的位置進(jìn)行 Attention 操作。經(jīng)過(guò)這些定義,我們可以將原序列按哈希值進(jìn)行排序,并分為 個(gè)桶:
▲ 分桶哈希示意圖
上圖中左邊為序列的表示,可以觀察到其首先將序列進(jìn)行哈希然后排序。這首先便會(huì)存在一個(gè)問(wèn)題,就是 query 和 key 均通過(guò)一個(gè)哈希函數(shù)進(jìn)行哈希,因此在一個(gè)桶中,可能會(huì)發(fā)生 query 很多但是 key 很少的問(wèn)題,甚至,會(huì)有 qeury 很多而 key 不存在的問(wèn)題。為了解決這個(gè)問(wèn)題,這里讓 query 和 key 在同一空間,即生成 Q 和 K 的矩陣是同一個(gè)。這樣在避免了剛剛的問(wèn)題的同時(shí)又引入了一個(gè)新的問(wèn)題:同一個(gè) token 的 key 值和 query 值的點(diǎn)積值會(huì)遠(yuǎn)遠(yuǎn)高于不同 token 之間的值,因此論文中的做法是不計(jì)算本身的 Attention 值。
看到這里 Reformer 對(duì) Self-Attention 的優(yōu)化思路已經(jīng)很清晰了,就是講相似的 token 置于一個(gè)桶中,僅對(duì)在桶中進(jìn)行 Attention 操作。然而在實(shí)現(xiàn)的過(guò)程中,我們需要保證桶內(nèi)的元素分布均衡,否則在進(jìn)行計(jì)算優(yōu)化的時(shí)候會(huì)帶來(lái)很大的麻煩。在這里為了保證計(jì)算的均衡,采用的策略是分塊,如果一個(gè)桶的元素跨塊的話,則讓后面的塊再去 attend 到前一個(gè)塊中的元素(僅考慮自回歸任務(wù)時(shí)),如上圖左圖所示。由此可見 Reformer 的加速僅在模型結(jié)構(gòu)上就能夠?qū)崿F(xiàn),較之之前幾種方式更為方便。
另外,在局部敏感哈希中,隨機(jī)的次數(shù)越多,得到的哈希桶就越準(zhǔn)確,所以哈希值可以做多輪。這樣,P 就成為多次哈希的值的全集:
隨后論文針對(duì)局部哈希 Attention 的優(yōu)化做了一個(gè)很有意思的實(shí)驗(yàn),該實(shí)驗(yàn)的任務(wù)是重復(fù)句子的單詞預(yù)測(cè),即將兩個(gè)相同的句子通過(guò)特殊字符拼接在一起,以 的格式作為序列輸入,每次通過(guò)前一個(gè)句子預(yù)測(cè)后面一個(gè)句子。其使用了一個(gè)單層的 Transformer 模型,通過(guò)訓(xùn)練和測(cè)試時(shí)不同輪數(shù)的局部哈希進(jìn)行了效果對(duì)比,實(shí)驗(yàn)結(jié)果如下:
▲ 重復(fù)句預(yù)測(cè)實(shí)驗(yàn)
可以發(fā)現(xiàn)無(wú)論在訓(xùn)練時(shí)還是測(cè)試時(shí),哈希的輪數(shù)越多模型的效果越好,這證明了多輪哈希的合理性和重要性,同時(shí)通過(guò)這個(gè)實(shí)驗(yàn)我們也可以進(jìn)一步認(rèn)識(shí)到:局部哈希函數(shù)無(wú)需訓(xùn)練,只需要隨機(jī)初始化,并且訓(xùn)練和測(cè)試過(guò)程的哈希輪數(shù)和函數(shù)也可以不一致,自由度非常高。
在完成了對(duì) Self-Attention 的優(yōu)化后,論文還提出了一種可逆殘差的方式來(lái)解決模型過(guò)深帶來(lái)的性能開銷。在訓(xùn)練網(wǎng)絡(luò)的時(shí)候,一般需要記錄每層的激活值,用來(lái)在反向傳播的時(shí)候進(jìn)行計(jì)算。所以每增加一層,內(nèi)存也會(huì)隨之增長(zhǎng)。一般的殘差網(wǎng)絡(luò)的結(jié)構(gòu)如下:
我們?cè)诜聪騻鞑サ倪^(guò)程中無(wú)法通過(guò) y 復(fù)原 x 和 F(x) 的值,因此需要將其記錄下來(lái),而論文通過(guò)對(duì)?進(jìn)行拆分巧妙地避免了這一問(wèn)題:
在 Transformer 的模塊中 和 不用由 拆分,只需分別代表 Self-Attention 和 FeedForward 的輸入即可:
同時(shí),對(duì)于 FeedForward 部分,論文也進(jìn)行了優(yōu)化。該模塊的全稱為 Position-wise FeedForward,如名字所示,改模塊對(duì)序列中的每個(gè) token 的操作是一樣的,也可以理解為 token 之間不會(huì)進(jìn)行交互,因此當(dāng)輸入序列過(guò)長(zhǎng)時(shí),完全可以講序列分段輸入,來(lái)緩解內(nèi)存上的開銷。
論文針對(duì)自己優(yōu)化的模塊在一些基礎(chǔ)任務(wù)上進(jìn)行了實(shí)驗(yàn),來(lái)驗(yàn)證模型優(yōu)化的有效性:
▲ 驗(yàn)證共享q、k參數(shù)和可逆殘差的效果
▲不同輪數(shù)哈希取得的效果
▲? 不同輪數(shù)哈希對(duì)應(yīng)的模型效率
總體而言,Reformer 通過(guò)實(shí)驗(yàn)證實(shí)了自己模型結(jié)構(gòu)優(yōu)化的有效性,多輪哈希在帶來(lái)效果提升的同時(shí)也會(huì)一定程度上降低模型的效率。相較于之前介紹的集中稀疏 Attention 機(jī)制的模型,Reformer 不依賴人工設(shè)計(jì)先驗(yàn)知識(shí),模型的適用性更廣,并且能夠在模型代碼層面進(jìn)行優(yōu)化,不涉及到cuda內(nèi)核的修改。但是在模型的設(shè)計(jì)和實(shí)現(xiàn)上過(guò)于復(fù)雜,在時(shí)間和空間復(fù)雜度上反復(fù)橫跳,從最后的實(shí)驗(yàn)結(jié)果也可以看出模型只在超長(zhǎng)序列上效果較為明顯,能夠應(yīng)用的場(chǎng)景還是十分有限的。
基于低秩分解的改進(jìn)
同樣是解決 Self-Attention 部分的復(fù)雜度問(wèn)題,Linformer?[8]?從另外一個(gè)角度給出了自己的優(yōu)化方案。同 Sparse Transformer 一樣,其也先對(duì)自己的理論進(jìn)行了實(shí)驗(yàn)的論證。論文在在 Wiki103 和 IMDB 兩個(gè)數(shù)據(jù)集上,用 Roberta-large 模型上計(jì)算出的 Attention 矩陣做了奇異值分解,然后從下圖左兩圖中可以看出,前 128 維的奇異值累計(jì)值已經(jīng)占了到了 0.9 左右。
而在右圖中可以看到,越高層,128 個(gè)奇異值累積值就越高。在第 11 層,128個(gè)奇異值累積起來(lái)達(dá)到了0.96。因而說(shuō)明了,雖然 Attention 的計(jì)算結(jié)果是一個(gè) N x N 的矩陣,但其實(shí)一個(gè)低秩矩陣比如 N x 128 可能就已經(jīng)足夠存儲(chǔ) Attention 的所有信息。因此直接對(duì)矩陣進(jìn)行合理降維就能夠在保證模型效果的前提下完成 Self-Attention 過(guò)程復(fù)雜度從平方級(jí)到線性的轉(zhuǎn)變。
作者在論文中花了很多篇幅證明降維的合理性,在這里我們不過(guò)多贅述,感興趣的讀者可以自行閱讀原文。根據(jù)以上的結(jié)論,我們很容易想到在訓(xùn)練和測(cè)試的過(guò)程對(duì) K、V 的矩陣進(jìn)行 SVD 分解就能夠解決效率問(wèn)題,但實(shí)際上再測(cè)試過(guò)程中進(jìn)行 SVD 分解還是比較麻煩,因此論文中給出了一種十分簡(jiǎn)單粗暴的方式來(lái)代替這一步驟:直接給 key 和 value 矩陣加一層線性變換,將其的長(zhǎng)度變?yōu)橐粋€(gè)定值,其模型結(jié)構(gòu)如下所示:
▲ Linformer 模型結(jié)構(gòu)圖
由公式和模型圖均可知 Linformer 在 V 和 K 參與 Self-Attention 計(jì)算之前,通過(guò)兩個(gè)投影矩陣將其序列長(zhǎng)度那一維降至一個(gè)固定值,因此無(wú)論原始序列的輸入長(zhǎng)度是多少,最后均會(huì)變成一個(gè) k 長(zhǎng)度的向量,最后 Self-Attention 的時(shí)間復(fù)雜度也會(huì)變成線性復(fù)雜度。
論文在 MLM 預(yù)訓(xùn)練任務(wù)和 NLP 下游任務(wù)上均做了實(shí)驗(yàn):
▲?MLM 預(yù)訓(xùn)練任務(wù)實(shí)驗(yàn)
▲?下游任務(wù)實(shí)驗(yàn)
總體而言 Linformer 無(wú)論是在運(yùn)行效率還是下游任務(wù)的效果上都有著不錯(cuò)的表現(xiàn),實(shí)現(xiàn)的方式也十分簡(jiǎn)單,但是這種投影的方式卻有一個(gè)致命的缺陷,即無(wú)法做自回歸的生成任務(wù),因?yàn)橥队跋喈?dāng)于把序列信息都雜糅至一個(gè)定長(zhǎng)的向量,模型無(wú)法通過(guò)之前 causal mask 的方式將信息給掩蓋,這也是論文只進(jìn)行了 MLM 預(yù)訓(xùn)練任務(wù)的原因。
基于線性注意力的改進(jìn)
在線性注意力的探索上,除了對(duì) K、V 矩陣進(jìn)行下采樣之外,還有一些工作給出了另一種視角。讓我們繼續(xù)回到注意力的計(jì)算過(guò)程,其本質(zhì)上是三個(gè)矩陣的連乘,而矩陣的連乘是滿足結(jié)合律的,我們正常的計(jì)算順序是 ,前一步我們會(huì)得到一個(gè) 的矩陣,這一步將導(dǎo)致我的時(shí)間復(fù)雜度為 ,而如果我們以順序 計(jì)算,復(fù)雜度則會(huì)變成 。
▲ Attention 計(jì)算順序與復(fù)雜度
然而 softmax 的存在卻讓這一步的操作無(wú)法實(shí)現(xiàn),因?yàn)槠湫枰葘? 值指數(shù)化并歸一化,因此沒有辦法先計(jì)算后面兩個(gè)矩陣的乘積。然而一定需要 softmax 嗎?
首先我們先來(lái)思考一下 Self-Attention 的本質(zhì),其使用點(diǎn)積相乘并進(jìn)行 softmax 其實(shí)想得到的僅僅是 token 之間的相似度,因此其計(jì)算過(guò)程可以如下描述:
softmax 的目的其實(shí)只是使得注意力的值恒為正且滿足歸一化,而這兩個(gè)條件通過(guò)其他的方式顯然是可以實(shí)現(xiàn)的。Linear Transformer [9] 通過(guò)一種核函數(shù)的方式巧妙地替換了 softmax,其認(rèn)為只需要找一個(gè)恒為正的核函數(shù)將 Q 和 K 的值進(jìn)行映射便可完成以上過(guò)程:
這里 代表核函數(shù),其定義如下:
如此轉(zhuǎn)換我們就可以將softmax函數(shù)刪去從而使得整個(gè)運(yùn)算的復(fù)雜度變?yōu)?。
之后谷歌推出的 Performer [10]?的解決思路其實(shí)也和 Linear Transformer 是一致的,其工作的最大亮點(diǎn)在于為 softmax 找到了一個(gè)更為優(yōu)美的映射來(lái)替代并給出了理論證明,而不像之前 Linear Transformer 那樣僅通過(guò)核函數(shù)的替換顯得有些空穴來(lái)潮。其映射表示如下 [11]:
由公式可知其使用了一種向量采樣的方式來(lái)表示 ,隨后為了使采樣出的向量能夠表征更多的信息,即讓采出來(lái)的向量線性無(wú)關(guān),其對(duì)采樣出的向量使用了正交化技術(shù)。
Performer 也在一些任務(wù)上進(jìn)行了測(cè)評(píng),也是一些長(zhǎng)序列相關(guān)的任務(wù),我們可以簡(jiǎn)單了解一下其大致性能:
▲ Performer 訓(xùn)練和推理速度
▲?Performer 在蛋白質(zhì)序列預(yù)測(cè)任務(wù)上的性能
模型綜合評(píng)測(cè)
上文已經(jīng)對(duì) Transformer 的魔改模型分類別介紹了不少,通過(guò)這些介紹我們可能能夠大致了解這些工作改進(jìn)的思路和大概方法,但是很難知道這些模型在使用時(shí)的差異。針對(duì)這一問(wèn)題,谷歌推出了一個(gè)評(píng)測(cè)框架 LRA [12] 對(duì)一眾 Transformer 改進(jìn)模型進(jìn)行了統(tǒng)一評(píng)測(cè)。
LRA 旨在用一些(形式)簡(jiǎn)單、通用、有一定挑戰(zhàn)性的多個(gè)任務(wù)對(duì)這些模型進(jìn)行長(zhǎng)序列任務(wù)的評(píng)測(cè),其中包括:
Long Listops
Byte-Level Text Classification
Byte-Level Document Retrieval
Image Classification on Sequences of Pixels
Pathfinder (Long-Range Spatial Dependency)
Pathfinder-X (Long-Range Spatial Dependencies with Extreme Lengths)
以下為模型在這些任務(wù)上的表現(xiàn):
▲ LRA性能測(cè)試效果
▲?LRA速度測(cè)試效果
由結(jié)果可以大體看出 Linformer、Linear Transformer、Performer 效率很高,Big bird Reformer 效率很低,但是 Big Bird 的性能還是不錯(cuò)的??傮w的性能還可以由以下這張圖看出,其中縱軸是效果,橫軸是速度,圓圈的大小代表所需要的顯存。理論上來(lái)說(shuō),越靠近右上方的模型越好,圓圈越小的模型越好。
▲ Transformer 們的性能-速度-顯存圖
總結(jié)
本文大體介紹了若干種高效 Transformer 的改進(jìn)版本,均集中在對(duì)長(zhǎng)序列任務(wù)的處理上,里面有一些筆者對(duì)這些模型的思考,其中更多的細(xì)節(jié)感興趣的讀者可以自行閱讀原文。
參考文獻(xiàn)
[1] Transformer?https://arxiv.org/abs/1706.03762
[2] Transformer-XL?https://arxiv.org/abs/1901.02860
[3] Sparse Transformer?https://arxiv.org/abs/1904.10509
[4] Longformer?https://arxiv.org/abs/2004.05150
[5] Big Bird?https://arxiv.org/abs/2007.14062
[6] Reformer?https://arxiv.org/abs/2001.04451
[7] Reformer: 局部敏感哈希、可逆殘差和分塊計(jì)算帶來(lái)的高效?https://mp.weixin.qq.com/s?__biz=MzI4ODg3NDY2NQ==&mid=2247483911&idx=1&sn=8d98a214d455a55650bb589830b08dae&chksm=ec368bc1db4102d7d54216e917ec22f83b47df55153ef4c3f83aaafd68a0dd60caf836fb712a&scene=178&cur_album_id=1464771644039610372#rd
[8] Linformer?https://arxiv.org/abs/2006.04768
[9] Linear Transformer?https://arxiv.org/abs/2006.16236
[10] Performer?https://arxiv.org/abs/2009.14794
[11]《Performer:用隨機(jī)投影將Attention的復(fù)雜度線性化 》 https://kexue.fm/archives/7921
[12] LRA?https://arxiv.org/abs/2011.04006
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)術(shù)熱點(diǎn)剖析、科研心得或競(jìng)賽經(jīng)驗(yàn)講解等。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來(lái)。
📝?稿件基本要求:
? 文章確系個(gè)人原創(chuàng)作品,未曾在公開渠道發(fā)表,如為其他平臺(tái)已發(fā)表或待發(fā)表的文章,請(qǐng)明確標(biāo)注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無(wú)版權(quán)問(wèn)題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競(jìng)爭(zhēng)力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來(lái)稿請(qǐng)備注即時(shí)聯(lián)系方式(微信),以便我們?cè)诟寮x用的第一時(shí)間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長(zhǎng)按添加PaperWeekly小編
🔍
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的Transformer性能优化:运算和显存的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 英雄杀不可以聊天(为什么英雄杀里我不能聊
- 下一篇: 关于泥质防水层质量控制的说法,正确的是(