从零学习SwinTransformer
論文信息
論文名稱:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原論文地址: https://arxiv.org/abs/2103.14030
官方開源代碼地址:https://github.com/microsoft/Swin-Transformer
本篇博客參考文章:從零學(xué)習(xí)SwinTransformer
名詞解答
?M/2?: 表示向下取整
像素通道: 參考-圖像的通道數(shù)問題
描述一個像素點,如果是灰度,那么只需要一個數(shù)值來描述它,就是單通道。
如果一個像素點,有RGB三種顏色來描述它,就是三通道。
四通道圖像,R、G、B加上一個A通道,表示透明度。一般叫做alpha通道,表示透明度的。
2通道圖像不常見,通常在程序處理中會用到,如傅里葉變換,可能會用到,一個通道為實數(shù),一個通道為虛數(shù),主要是編程方便。
dense prediction理解:
標(biāo)注出圖像中每個像素點的對象類別,要求不但給出具體目標(biāo)的位置,還要描繪物體的邊界,如圖像分割、語義分割、邊緣檢測等等
WindowPatchToken理解:
假設(shè)輸入圖片的尺寸為224X224,先劃分成多個大小為4x4像素的小片,每個小片之間沒有交集。224/4=56,那么一共可以劃分56x56個小片。每一個小片就叫一個patch,每一個patch將會被看成一個token,所以patch=token。而一張圖被劃分為7x7個window,每個window之間也沒有交集。那么每個window就會包含8x8個patch。
一張圖有224(pixel)X224(pixel)= 56(個patch)x56(個patch)x4(pixel)x4(pixel)=7(個window)x7(個window)x8(個patch)x8(個patch)x4(pixel)x4(pixel)
不懂可看下圖,圖中本想畫出224*224像素的圖片,奈何畫不了,所以就畫一部分就可以說明劃分關(guān)系。其中圖中每個小方塊中包含4x4個像素,紅色框起來的8x8個patch為一個window,這樣的window一張圖片有7x7個,每個window之間也沒有交集。那么每個window就會包含8x8個patch。
疑問: 圖中那4x4個patch組成的叫啥:好像啥也不叫,只是順手畫出來的。哦,下圖中的4x4個patch組成的東西是一個window的一個計算單元,一個window中有4x4個這樣的單元,用于計算self-attention,意思就是在計算self-attention時是8x8個patch作為一個單元去跟別人計算的。
網(wǎng)絡(luò)整體框架
原論文中給出的關(guān)于Swin Transformer(Swin-T)網(wǎng)絡(luò)的架構(gòu)圖。通過圖(a)可以看出整個框架的基本流程如下:
輸入image
假設(shè)模型的輸入是一張224x224x3 的圖片
Patch Partition詳解
首先將圖片輸入到Patch Partition模塊中進(jìn)行分塊,即每4x4個相鄰的像素劃分為一個patch,即在上面畫的圖,將圖片劃分一個一個的patch。
詳解每4x4個像素(3通道)如何展平為1x1個patch(48通道)?
大概以上的圖形就可示意每4x4個像素(3通道)如何展平為1x1個patch(48通道)。
對224x224個像素(3通道)的圖片都這樣處理,會得到(56,56)個patch,48通道,特征圖大小為(56,56,48)見下圖第二個帶綠色部分的立體圖。
再復(fù)述一遍以上的過程:將上面4x4=16個像素的圖然后在通道(channel)方向展平(flatten)。假設(shè)輸入的是RGB三通道的圖片,那么每個patch就有4x4=16個像素,然后每個像素有R、G、B三個值,所以展平后是16x3=48,所以通過Patch Partition后圖像的shape由 [H, W, 3]=(224,224,3)變成了 [H/4, W/4, 48]=(56,56,48)。上圖中的最左下角的第一張圖片是將一張圖劃分patch之后的樣子,簡略示意(224,224)個pixel,3通道,展平之后展平為(56,56)個patch,48個通道。
之后就沒像素什么事了,后來都是在patch上的討論。
Linear Embedding詳解
參考:Swin Transformer全方位解讀【ICCV2021最佳論文】
(56,56,48)————Linear Embedding———>(56,56,96)
這個層優(yōu)點感覺像是卷積層,只不過是對通道數(shù)進(jìn)行卷積
Linear Embeding層對每個patch的channel數(shù)據(jù)做線性變換,由48變成C=(96),即圖像shape再由 [H/4, W/4, 48]=(56,56,48)變成了 [H/4, W/4, C]=(56,56,96)。每個圖片被劃分為56x56=3136個patch,每個patch又被編碼成96維的向量。
這一步在代碼上實現(xiàn)十分簡單,就是一個Conv2D,把步長和kernel size都設(shè)置為patch的長度即可,可看:
這步以后再flatten一下,就可以把56x56x96變?yōu)?196x96。
Swin Transformer Block(W-MSA、SW-MSA)詳解
與標(biāo)準(zhǔn)transformer不同的就是紫色部分的兩個框,分別是W-MSA和SW-MSA。
W-MSA表示,在window內(nèi)部的Multi-Head Self-Attention,就是把window當(dāng)做獨立的全局來計算window中每個token兩兩注意力。
SW-MSA與W-MSA的一丟丟不一樣,就是將window的覆蓋范圍偏移一下,原文設(shè)置為window的邊長的一半。
Swin Transformer Block注意這里的Block其實有兩種結(jié)構(gòu),如圖(b)中所示,這兩種結(jié)構(gòu)的不同之處僅在于一個使用了W-MSA結(jié)構(gòu),一個使用了SW-MSA結(jié)構(gòu)。而且這兩個結(jié)構(gòu)是成對使用的,先使用一個W-MSA結(jié)構(gòu)再使用一個SW-MSA結(jié)構(gòu)。所以你會發(fā)現(xiàn)堆疊Swin Transformer Block的次數(shù)都是偶數(shù)(因為成對使用)。
W-MSA詳解
全稱為Window based Multi-head Self Attention。一張圖平分為7x7個window,這些window互相都沒有overlap。然后,每個window包含一定數(shù)量的token,直接對這些token計算window內(nèi)部的自注意力。 以分而治之的方法,遠(yuǎn)遠(yuǎn)降低了標(biāo)準(zhǔn)transformer的計算復(fù)雜度。以第1層為例,7x7個window,每個window包含8x8個patch,相當(dāng)于把標(biāo)準(zhǔn)transformer應(yīng)用在window上,而不是全圖上。
在Swin Transformer中使用了Windows Multi-Head Self-Attention(W-MSA)的概念,將特征圖劃分成了多個不相交的區(qū)域(Window),Multi-Head Self-Attention只在每個窗口(Window)內(nèi)進(jìn)行。相對于Vision Transformer中直接對整個(Global)特征圖進(jìn)行Multi-Head Self-Attention,這樣做的目的是能夠減少計算量的,尤其是在淺層特征圖很大的時候。這樣做雖然減少了計算量但也會隔絕不同窗口之間的信息傳遞,所以在論文中作者又提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通過此方法能夠讓信息在相鄰的窗口中進(jìn)行傳遞。
引入Windows Multi-head Self-Attention(W-MSA)模塊是為了減少計算量。如下圖所示,左側(cè)使用的是普通的Multi-head Self-Attention(MSA)模塊,對于feature map中的每個像素(或稱作token,patch)在Self-Attention計算過程中需要和所有的像素去計算。但在圖右側(cè),在使用Windows Multi-head Self-Attention(W-MSA)模塊時,首先將feature map按照MxM(例子中的M=2)大小劃分成一個個Windows,然后單獨對每個Windows內(nèi)部進(jìn)行Self-Attention。
兩者的計算量具體差多少呢?原論文中有給出下面兩個公式,這里忽略了Softmax的計算復(fù)雜度。:
h代表feature map的高度
w代表feature map的寬度
C代表feature map的深度
M代表每個窗口(Windows)的大小
W-MSA模塊計算量詳解
好像在計算計算量的時候不計算加法的次數(shù)。hw是第一個矩陣的行數(shù),第一個C是一個行與列的計算量,第二個C是后一個矩陣的列數(shù)個前面的小計算量。
即矩陣L(ab)與P(bd)相乘,計算為abd,即(第一個矩陣的行數(shù))x(第一個矩陣的列數(shù))x(第二個矩陣的列數(shù))。
SW-MSA詳解
在局部window內(nèi)計算Self-Attention確實可以極大地降低計算復(fù)雜度,但是其也缺失了窗口之間的信息交互,降低了模型的表示能力。為了引入Cross-Window Connection,SwinTransformer采用了一種移位窗口劃分的方法來實現(xiàn)這一目標(biāo),窗口會在連續(xù)兩個SwinTransformer Blocks交替移動,使得不同Windows之間有機(jī)會進(jìn)行交互。
采用W-MSA模塊時,只會在每個窗口內(nèi)進(jìn)行自注意力計算,所以窗口與窗口之間是無法進(jìn)行信息傳遞的。為了解決這個問題,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模塊,即進(jìn)行偏移的W-MSA。如下圖所示,左側(cè)使用的是剛剛講的W-MSA(假設(shè)是第L層),那么根據(jù)之前介紹的W-MSA和SW-MSA是成對使用的,那么第L+1層使用的就是SW-MSA(右側(cè)圖)。根據(jù)左右兩幅圖對比能夠發(fā)現(xiàn)窗口(Windows)發(fā)生了偏移(可以理解成窗口從左上角分別向右側(cè)和下方各偏移了?M/2?個像素)。看下偏移后的窗口(右側(cè)圖),比如對于第一行第2列的2x4的窗口,它能夠使第L層的第一排的兩個窗口信息進(jìn)行交流。再比如,第二行第二列的4x4的窗口,他能夠使第L層的四個窗口信息進(jìn)行交流,其他的同理。那么這就解決了不同窗口之間無法進(jìn)行信息交流的問題。
根據(jù)上圖,可以發(fā)現(xiàn)通過將窗口進(jìn)行偏移后,由原來的4個窗口變成9個窗口了。后面又要對每個窗口內(nèi)部進(jìn)行MSA,這樣做感覺又變麻煩了。為了解決這個麻煩,作者又提出而了Efficient batch computation for shifted configuration,一種更加高效的計算方法。下面是原論文給的示意圖。
下圖左側(cè)是剛剛通過偏移窗口后得到的新窗口,右側(cè)是為了方便大家理解,對每個窗口加上了一個標(biāo)識。然后0對應(yīng)的窗口標(biāo)記為區(qū)域A,3和6對應(yīng)的窗口標(biāo)記為區(qū)域B,1和2對應(yīng)的窗口標(biāo)記為區(qū)域C。
然后先將區(qū)域A和C移到最下方。
接著,再將區(qū)域A和B移至最右側(cè)。
移動完后,4是一個單獨的窗口;將5和3合并成一個窗口;7和1合并成一個窗口;8、6、2和0合并成一個窗口。這樣又和原來一樣是4個4x4的窗口了,所以能夠保證計算量是一樣的。這里肯定有人會想,把不同的區(qū)域合并在一起(比如5和3)進(jìn)行MSA,這信息不就亂竄了嗎?是的,為了防止這個問題,在實際計算中使用的是masked MSA即帶蒙板mask的MSA,這樣就能夠通過設(shè)置蒙板來隔絕不同區(qū)域的信息了。關(guān)于mask如何使用,可以看下下面這幅圖,下圖是以上面的區(qū)域5和區(qū)域3為例。
對于該窗口內(nèi)的每一個像素(或稱token,patch)在進(jìn)行MSA計算時,都要先生成對應(yīng)的query(q),key(k),value(v)。假設(shè)對于上圖的像素0而言,得到q0后要與每一個像素的k進(jìn)行匹配(match),假設(shè)α0,0代表q0與像素0對應(yīng)的k0進(jìn)行匹配的結(jié)果,那么同理可以得到α0,0至α0,15。按照普通的MSA計算,接下來就是SoftMax操作了。但對于這里的masked MSA,像素0是屬于區(qū)域5的,我們只想讓它和區(qū)域5內(nèi)的像素進(jìn)行匹配。那么我們可以將像素0與區(qū)域3中的所有像素匹配結(jié)果都減去100(例如α0,2,α0,3,α0,6 ,α 0,7等等),由于α的值都很小,一般都是零點幾的數(shù),將其中一些數(shù)減去100后再通過SoftMax得到對應(yīng)的權(quán)重都等于0了。所以對于像素0而言實際上還是只和區(qū)域5內(nèi)的像素進(jìn)行了MSA。那么對于其他像素也是同理。注意,在計算完后還要把數(shù)據(jù)給挪回到原來的位置上(例如上述的A,B,C區(qū)域)。
SW-MSA模塊計算量詳解
LayerNorm
相對位置偏置(relative position bias)
Patch Merging詳解
在每個Stage中首先要通過一個Patch Merging層進(jìn)行下采樣(Stage1除外)。如下圖所示,假設(shè)輸入Patch Merging的是一個4x4大小的單通道特征圖(feature map),Patch Merging會將每個2x2的相鄰像素劃分為一個patch,然后將每個patch中相同位置(同一顏色)像素給拼在一起就得到了4個feature map。接著將這四個feature map在深度方向進(jìn)行concat拼接,然后在通過一個LayerNorm層。最后通過一個全連接層在feature map的深度方向做線性變化,將feature map的深度由C變成C/2。通過這個簡單的例子可以看出,通過Patch Merging層后,feature map的高和寬會減半,深度會翻倍。
Relative Position Bias詳解
真正使用到的可訓(xùn)練參數(shù)B是保存在relative position bias table表里的,這個表的長度是等于(2M?1)×(2M?1)的。那么上述公式中的相對位置偏執(zhí)參數(shù)B是根據(jù)上面的相對位置索引表根據(jù)查relative position bias table表得到的,如上圖所示。
參考:
Swin Transformer全方位解讀【ICCV2021最佳論文】
Swin-Transformer網(wǎng)絡(luò)結(jié)構(gòu)詳解
2021-Swin Transformer Attention機(jī)制的詳細(xì)推導(dǎo)
(附加:CSDN上傳圖片去水印方法)
總結(jié)
以上是生活随笔為你收集整理的从零学习SwinTransformer的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何在Python中删除字符串中的所有反
- 下一篇: image caption优秀链接