ResT解读
最近的一篇基于Transformer的工作,由南京大學(xué)的研究者提出一種高效的視覺Transformer結(jié)構(gòu),設(shè)計(jì)思想類似ResNet,稱為ResT,這是我個(gè)人覺得值得關(guān)注的一篇工作。
簡(jiǎn)介
ResT是一個(gè)高效的多尺度視覺Transformer結(jié)構(gòu),可以作為圖像識(shí)別的通用骨干網(wǎng)絡(luò),它采用類似ResNet的設(shè)計(jì)思想,分階段捕獲不同尺度的信息。不同于現(xiàn)有的Transformer方法只使用標(biāo)準(zhǔn)的Transformer block來處理具有固定分辨率的原始圖像,ResT有著幾個(gè)優(yōu)勢(shì):提出一種內(nèi)存高效的多頭自注意力,使用深度卷積進(jìn)行內(nèi)存壓縮,并且跨注意力頭的維度投影交互同時(shí)保持多頭的多樣性能力;將位置編碼構(gòu)建為空間注意力,它可以以更加靈活的方式處理任意尺寸的輸入而無需插值或者微調(diào);不同于直接在每個(gè)階段開始進(jìn)行序列化,而是將patch embedding設(shè)計(jì)為一系列重疊的有stride的卷積操作。作者在圖像分類以及下游任務(wù)中驗(yàn)證了ResT的性能,結(jié)果表明,ResT大幅度優(yōu)于當(dāng)前SOTA骨干網(wǎng)絡(luò),在ImageNet數(shù)據(jù)集上,同等計(jì)算量前提下,所提方法取得了優(yōu)于PVT、Swin。
-
論文標(biāo)題
ResT: An Efficient Transformer for Visual Recognition
-
論文地址
https://arxiv.org/abs/2105.13677
-
論文源碼
https://github.com/wofmanaf/ResT
介紹
用于提取圖像特征的骨干網(wǎng)絡(luò)(backbone)在計(jì)算機(jī)視覺任務(wù)中至關(guān)重要,好的特征有利于下游任務(wù)的展開,如圖像分類、目標(biāo)檢測(cè)、實(shí)例分割等。如今,計(jì)算機(jī)視覺中主要有兩種骨干網(wǎng)絡(luò)結(jié)構(gòu),一種是卷積神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),一種是Transformer結(jié)構(gòu),它們都是堆疊多個(gè)塊(block)來捕獲特征信息的。
CNN block通常是一個(gè)bottleneck結(jié)構(gòu),可以定義為堆疊的1x1卷積、3x3卷積和1x1卷積配合一個(gè)殘差連接,如下圖的(a)所示。兩個(gè)1x1卷積分別用于通道降維和通道升維,保證3x3卷積處理的特征圖通道數(shù)不會(huì)太高。CNN骨干網(wǎng)絡(luò)通常更快一些,這主要得益于參數(shù)共享、局部信息聚合以及維度縮減,然而,受限于有限且固定的感受野,卷積網(wǎng)絡(luò)在那些需要長(zhǎng)程依賴的場(chǎng)景中效果并不好,比如實(shí)例分割中,從一個(gè)更大的鄰域中收集并關(guān)聯(lián)目標(biāo)間的關(guān)系是很重要的。
為了克服這些限制,能夠捕獲長(zhǎng)程信息的Transformer結(jié)構(gòu)最近被探索用于設(shè)計(jì)骨干網(wǎng)絡(luò)。不同于CNN網(wǎng)絡(luò),Transformer網(wǎng)絡(luò)首先是將圖片切分為一系列塊(patch,也叫token),然后將這些token和位置編碼相加來表示粗糙的空間信息,最終采用堆疊的Transformer block來捕獲特征信息。一個(gè)標(biāo)準(zhǔn)的Transformer block由一個(gè)多頭自注意力(multi-head self-attention,MSA)和一個(gè)前饋神經(jīng)網(wǎng)絡(luò)(feed-forward network,FFN)構(gòu)成,其中MSA通過query-key-value分解來建模token之間的全局依賴,FFN則用來學(xué)習(xí)更寬泛的表示。Transformer block的結(jié)構(gòu)如上圖的(b)所示,它能夠根據(jù)圖像內(nèi)容自適應(yīng)調(diào)整感受野。
雖然相比于CNN backbone。Transformer backbone潛力巨大,但它依然有四個(gè)主要的缺點(diǎn)如下。
在這篇論文中,作者提出一種高效的通用backbone ResT(以ResNet命名),該結(jié)構(gòu)可以解決上述的問題,這個(gè)結(jié)構(gòu)會(huì)在下一節(jié)具體說明。
ResT
上圖所示的即為ResT的結(jié)構(gòu)圖,可以看到,它和ResNet有著非常類似的pipeline,即采用一個(gè)stem模塊來提取底層特征,然后跟著四個(gè)stage捕獲多尺度特征。每個(gè)stage由三個(gè)組件構(gòu)成,一個(gè)patch embedding模塊,一個(gè)position encoding模塊以及L個(gè)efficient Transformer block。具體而言,在每個(gè)stage的開始,patch embedding模塊用來減少輸入token的分辨率并且拓展通道數(shù)。位置編碼模塊則被融合進(jìn)來用于抑制位置信息并且加強(qiáng)patch embedding的特征提取能力。這兩個(gè)階段完成之后,輸入token被送入efficient Transformer block。
Rethinking of Transformer Block
標(biāo)準(zhǔn)的Transformer block包含兩個(gè)子層,分別是MSA和FFN,每個(gè)子層包圍著一個(gè)殘差連接。在MSA和FFN前,先經(jīng)過了一個(gè)layer normalization(下面簡(jiǎn)稱LN)。假定輸入token為x∈Rn×dm\mathrm{x} \in \mathbb{R}^{n \times d_{m}}x∈Rn×dm?,這里的nnn和dmd_mdm?分別表示空間維度和通道維度,每個(gè)Transformer block的輸出表示如下。
y=x′+FFN(LN(x′)),and?x′=x+MSA(LN(x))\mathrm{y}=\mathrm{x}^{\prime}+\mathrm{FFN}\left(\mathrm{LN}\left(\mathrm{x}^{\prime}\right)\right), \text { and } \mathrm{x}^{\prime}=\mathrm{x}+\mathrm{MSA}(\mathrm{LN}(\mathrm{x})) y=x′+FFN(LN(x′)),?and?x′=x+MSA(LN(x))
對(duì)上面的式子,我們先來看MSA,它首先通過三組線性投影獲取query Q\mathbf{Q}Q、key K\mathbf{K}K和value V\mathbf{V}V,每組投影有kkk個(gè)線性層(即heads)將dmd_mdm?映射到dkd_kdk?的空間中,這里dk=dm/kd_{k}=d_{m} / kdk?=dm?/k。為了描述方便,后續(xù)所有的說明都是基于k=1k=1k=1,因此MSA可以簡(jiǎn)化為單頭注意力(SA),token序列之間的全局關(guān)系可以定義為下式,每個(gè)head的輸出concatenate到一起之后經(jīng)過線性投影得到最終輸出。可以得知,MSA的計(jì)算復(fù)雜度為O(2dmn2+4dm2n)\mathcal{O}\left(2 d_{m} n^{2}+4 d_{m}^{2} n\right)O(2dm?n2+4dm2?n),它根據(jù)輸入token的空間維度或者通道維度次方級(jí)變化。
SA(Q,K,V)=Softmax?(QKTdk)V\mathrm{SA}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\operatorname{Softmax}\left(\frac{\mathbf{Q K}^{\mathrm{T}}}{\sqrt{d_{k}}}\right) \mathbf{V} SA(Q,K,V)=Softmax(dk??QKT?)V
接著,來看FFN,它主要用于特征轉(zhuǎn)換和非線性,通常由兩個(gè)線性層和一個(gè)非線性激活函數(shù)構(gòu)成,第一層將輸入的通道數(shù)從dmd_mdm?拓展到dfd_fdf?,第二層則從dfd_fdf?降到dmd_mdm?。數(shù)學(xué)上表示如下式,其中W1∈Rdm×df\mathbf{W}_{1} \in \mathbb{R}^{d_{m} \times d_{f}}W1?∈Rdm?×df?且W2∈Rdf×dm\mathbf{W}_{2} \in \mathbb{R}^{d_{f} \times d_{m}}W2?∈Rdf?×dm?為兩個(gè)線性層的權(quán)重,b1∈Rdf\mathbf{b}_{1} \in \mathbb{R}^{d_{f}}b1?∈Rdf?和b2∈Rdm\mathbf{b}_{2} \in \mathbb{R}^{d_{m}}b2?∈Rdm?則是相應(yīng)的偏置項(xiàng),σ(?)\sigma(\cdot)σ(?)表示GELU激活函數(shù)。標(biāo)準(zhǔn)的Transformer block中,通道數(shù)通常4倍擴(kuò)大,即df=4dmd_{f}=4 d_{m}df?=4dm?。FFN的計(jì)算代價(jià)為8ndm28 n d_{m}^{2}8ndm2?。
FFN(x)=σ(xW1+b1)W2+b2\mathrm{FFN}(\mathrm{x})=\sigma\left(\mathrm{x} \mathbf{W}_{1}+\mathbf{b}_{1}\right) \mathbf{W}_{2}+\mathbf{b}_{2} FFN(x)=σ(xW1?+b1?)W2?+b2?
Efficient Transformer Block
如上面所述,MSA有兩個(gè)缺點(diǎn),第一是其計(jì)算量是二次方倍的,這給訓(xùn)練和推理都帶來了不小的負(fù)擔(dān);第二,MSA中的每個(gè)head只負(fù)責(zé)輸入token序列的一個(gè)子集,當(dāng)通道數(shù)比較少的時(shí)候這個(gè)會(huì)損害模型的表現(xiàn)。
為了解決這些問題,作者提出了一種高效的多頭自注意力模塊,如上圖所示。和MSA類似,EMSA首先采用一組投影獲取query Q\mathbf{Q}Q。為了壓縮內(nèi)存,2D輸入的token x∈Rn×dm\mathrm{x} \in \mathbb{R}^{n \times d_{m}}x∈Rn×dm?會(huì)被沿著空間維度reshape為3D形式(x^∈Rdm×h×w\hat{\mathrm{x}} \in \mathbb{R}^{d_{m} \times h \times w}x^∈Rdm?×h×w)然后送入深度可分離卷積中按照因子sss降低寬高,為了簡(jiǎn)單,sss根據(jù)kkk自適應(yīng)為s=8/ks=8 / ks=8/k,卷積核尺寸、stride和padding分別是s+1s+1s+1、sss和s/2s/2s/2。然后,下采樣后的token map為x^∈Rdm×h/s×w/s\hat{\mathrm{x}} \in \mathbb{R}^{d_{m} \times h / s \times w / s}x^∈Rdm?×h/s×w/s,它被reshape為2D的形式,也就是x^∈Rn′×dm,n′=h/s×w/s\hat{\mathbf{x}} \in \mathbb{R}^{n^{\prime} \times d_{m}}, n^{\prime}=h / s \times w / sx^∈Rn′×dm?,n′=h/s×w/s,然后x^\hat{x}x^送入兩組投影層獲得key K\mathbf{K}K和value V\mathbf{V}V。再然后,采用下面的式子計(jì)算qkv之間的注意力函數(shù),式子中的Conv表示標(biāo)準(zhǔn)的1x1卷積,它用于建模不同head之間的交互,通過這個(gè)方法attention的結(jié)果依賴于所有的key和query,然而,這將削弱 MSA 聯(lián)合處理來自不同位置的不同表示子集的信息的能力。為了重建這種多樣性能力,在點(diǎn)擊矩陣后添加了一個(gè)LN,也就是Softmax之后。
EMSA?(Q,K,V)=IN?(Softmax?(Conv?(QKTdk)))V\operatorname{EMSA}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\operatorname{IN}\left(\operatorname{Softmax}\left(\operatorname{Conv}\left(\frac{\mathbf{Q} \mathbf{K}^{\mathrm{T}}}{\sqrt{d_{k}}}\right)\right)\right) \mathbf{V} EMSA(Q,K,V)=IN(Softmax(Conv(dk??QKT?)))V
最后,所有head的輸出concatenate到一起經(jīng)過投影得到最終輸出。這就是整個(gè)EMSA塊的計(jì)算過程,其實(shí)對(duì)照上圖就能理解得很明白了,EMSA的計(jì)算代價(jià)為O(2dmn2s2+2dm2n(1+1s2)+dmn(s+1)2s2+k2n2s2)\mathcal{O}\left(\frac{2 d_{m} n^{2}}{s^{2}}+2 d_{m}^{2} n\left(1+\frac{1}{s^{2}}\right)+d_{m} n \frac{(s+1)^{2}}{s^{2}}+\frac{k^{2} n^{2}}{s^{2}}\right)O(s22dm?n2?+2dm2?n(1+s21?)+dm?ns2(s+1)2?+s2k2n2?),假定s=1s=1s=1的話這個(gè)復(fù)雜度是遠(yuǎn)遠(yuǎn)低于原始的MSA的,特別是較淺的stage時(shí),nnn相對(duì)高一些。
當(dāng)然,EMSA之后也添加了FFN以進(jìn)行特征變換和非線性,因此最終effcient Transformer block的輸出如下。
y=x′+FFN(LN(x′)),and?x′=x+EMSA?(LN(x))\mathrm{y}=\mathrm{x}^{\prime}+\mathrm{FFN}\left(\mathrm{LN}\left(\mathrm{x}^{\prime}\right)\right), \text { and } \mathrm{x}^{\prime}=\mathrm{x}+\operatorname{EMSA}(\mathrm{LN}(\mathrm{x})) y=x′+FFN(LN(x′)),?and?x′=x+EMSA(LN(x))
Patch Embedding
知道了最核心的EMSA,接下來是關(guān)于Patch Embedding的內(nèi)容。在標(biāo)準(zhǔn)的Transformer中,一個(gè)token序列的embedding作為輸入,以ViT為例,3D圖像x∈R3×h×w\mathrm{x} \in \mathbb{R}^{3 \times h \times w}x∈R3×h×w為輸入,它被按照patch size為p×pp \times pp×p進(jìn)行切分。這些patch被展平為2D然后被映射為隱嵌入x∈Rn×c\mathrm{x} \in \mathbb{R}^{n \times c}x∈Rn×c(其中n=hw/p2n=h w / p^{2}n=hw/p2)。然而,這種直接的標(biāo)記化難以捕獲底層特征信息(比如邊緣、角點(diǎn))。此外,ViT中的token序列長(zhǎng)度是固定的,這使其難以進(jìn)行下游任務(wù)(比如目標(biāo)檢測(cè)、實(shí)例分割)適配,因?yàn)檫@些任務(wù)往往需要多尺度特征圖。
為了解決上述問題,作者構(gòu)建了一種高效的多尺度backbone,名為ResT以進(jìn)行密集預(yù)測(cè)任務(wù)。如上文所述,每個(gè)階段的efficient Transformer block在一個(gè)確定的尺度和分辨率上跨空間和通道進(jìn)行操作,因此,patch embedding模塊需要減少空間分辨率的同時(shí)拓展通道維度。
和ResNet類似,stem模塊(也就是第一個(gè)patch embedding模塊)以4的縮減因子縮小高度和寬度。為了是圓通很少的參數(shù)高效捕獲底層特征,作者引入了一個(gè)簡(jiǎn)單但有效的方式,即堆疊3個(gè)3x3卷積層(padding為1),stride分別為2、1、2,前兩層緊跟一個(gè)BN和ReLU層。在stage2、stage3和stage4,patch embedding模塊被用來4倍下采樣空間維度并且2倍通道維度。這可以通過標(biāo)準(zhǔn)的3x3卷積以stride2和padding1實(shí)現(xiàn)。在stage2中,patch embedding模塊將輸入分辨率從h/4×w/4×ch / 4 \times w / 4 \times ch/4×w/4×c調(diào)整到h/8×w/8×2ch / 8 \times w / 8 \times 2 ch/8×w/8×2c。
Position Encoding
位置編碼對(duì)于序列順序非常關(guān)鍵,在ViT中,一系列可學(xué)習(xí)參數(shù)被加(加法)到輸入token上編碼位置信息,假定x∈Rn×c\mathrm{x} \in \mathbb{R}^{n \times c}x∈Rn×c為輸入,θ∈Rn×c\theta \in \mathbb{R}^{n \times c}θ∈Rn×c為位置參數(shù),編碼后的輸入可以表示如下。
x^=x+θ\hat{\mathrm{x}}=\mathrm{x}+\theta x^=x+θ
但是,位置的長(zhǎng)度需要和輸入token的長(zhǎng)度一致,這就限制了很多應(yīng)用的場(chǎng)景,因此需要一種可以根據(jù)輸入改變長(zhǎng)度的位置編碼。回顧上面的式子,其實(shí)相加操作非常類似逐像素對(duì)輸入加權(quán)。假定θ\thetaθ和xxx相關(guān),即θ=GL(x)\theta=\mathrm{GL}(\mathrm{x})θ=GL(x),這里的GL(?)\mathrm{GL}(\cdot)GL(?)表示組線性操作且組數(shù)為ccc,上式就被修改為下面的式子,θ\thetaθ可以通過更靈活的注意力機(jī)制獲得。
x^=x+GL(x)\hat{\mathrm{x}}=\mathrm{x}+\mathrm{GL}(\mathrm{x}) x^=x+GL(x)
因此,作者這里提出了一種簡(jiǎn)單高效的像素級(jí)注意力(pixel-wise attention,PA)來編碼位置。具體而言,PA采用3x3深度卷積操作來獲得像素級(jí)權(quán)重,然后使用sigmoid激活,最終使用PA獲得的位置編碼如下式。
x^=PA(x)=x?σ(DWConv?(x))\hat{\mathrm{x}}=\mathrm{PA}(\mathrm{x})=\mathrm{x} * \sigma(\operatorname{DWConv}(\mathrm{x})) x^=PA(x)=x?σ(DWConv(x))
由于每個(gè)stage的輸入token通過卷積得到,可以將位置編碼嵌入到patch embedding模塊中,整體結(jié)果見下圖。注:這里的PA可以采用任意空間注意力替換,這使得ResT中的PE極為靈活。
Classification Head
分類head的設(shè)計(jì)非常簡(jiǎn)單,一個(gè)池化接線性層即可,在圖像分類任務(wù)上的模型結(jié)構(gòu)如下圖所示。
實(shí)驗(yàn)
圖像分類、目標(biāo)檢測(cè)、實(shí)例分割的結(jié)果如下,超越了PVT、Swin等。
此外作者還對(duì)各個(gè)模塊進(jìn)行了消融實(shí)驗(yàn),具體可以查看論文。
總結(jié)
這篇文章作者提出了一種新的Transformer架構(gòu)的視覺backbone,它可以捕獲多尺度特征因而非常適用于密集預(yù)測(cè)任務(wù)。作者壓縮了標(biāo)準(zhǔn) MSA 的內(nèi)存,并在保持多樣性能力的同時(shí)對(duì)多頭之間的交互進(jìn)行建模。 為了處理任意輸入圖像,作者進(jìn)一步將位置編碼重新設(shè)計(jì)為空間注意力。本文也只是我本人從自身出發(fā)對(duì)這篇文章進(jìn)行的解讀,想要更詳細(xì)理解的強(qiáng)烈推薦閱讀原論文。最后,如果我的文章對(duì)你有所幫助,歡迎一鍵三連,你的支持是我不懈創(chuàng)作的動(dòng)力。
總結(jié)
- 上一篇: SiamMOT解读
- 下一篇: 0006-ZigZag Conversi