Vision Transformer (ViT) PyTorch Code Walkthrough (with Diagrams)
The Vision Transformer has recently brought the Transformer architecture from NLP into computer vision and swept the major CV leaderboards. Based on the original Vision Transformer paper and its PyTorch implementation, this post walks through the entire ViT codebase.
Readers not yet familiar with the original Transformer may want to read Attention Is All You Need first; for a video explanation I recommend Hung-yi Lee's lectures (YouTube, BiliBili), which I personally find very clear.
Without further ado, let's get started.
The figure below is the overall ViT architecture diagram; we will refer to it throughout the code walkthrough:
Below are the equations given in the paper, which are an important reference for our walkthrough:
$$\mathbf{z}_0=[\mathbf{x}_{class};\,\mathbf{x}^1_p\mathbf{E};\,\mathbf{x}^2_p\mathbf{E};\,\dots;\,\mathbf{x}^N_p\mathbf{E}]+\mathbf{E}_{pos},\quad \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\ \mathbf{E}_{pos}\in\mathbb{R}^{(N+1)\times D}\tag{1}$$

$$\mathbf{z}'_\ell=\mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1}))+\mathbf{z}_{\ell-1}\tag{2}$$

$$\mathbf{z}_\ell=\mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell))+\mathbf{z}'_\ell\tag{3}$$

$$\mathbf{y}=\mathrm{LN}(\mathbf{z}_L^0)\tag{4}$$
Importing the required packages
```python
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
```

These are all packages commonly used when building networks in PyTorch. The exceptions are einops and einsum, which are less common in convolutional-network code; readers unfamiliar with them can refer to the blog post "einops and einsum: powerful tools for operating directly on tensors".
The pair function
```python
def pair(t):
    return t if isinstance(t, tuple) else (t, t)
```

This checks whether t is a tuple: if it is, t is returned directly; if not, t is duplicated into the tuple (t, t) and returned.
It handles the case where the image size or patch size is given as an int (e.g. 224), turning it into a tuple of equal values (e.g. (224, 224)).
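A quick check of both cases:

```python
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

print(pair(224))         # (224, 224) -- an int is duplicated
print(pair((224, 196)))  # (224, 196) -- a tuple passes through unchanged
```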
PreNorm
```python
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
```

PreNorm corresponds to the yellow Norm layer at the bottom of the architecture diagram. The parameter dim is the feature dimension, and fn is the function applied after the normalization, one of the Attention and FeedForward modules below, corresponding to Equations (2) and (3) respectively.
$$\mathbf{z}'_\ell=\mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1}))+\mathbf{z}_{\ell-1}\tag{2}$$

$$\mathbf{z}_\ell=\mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell))+\mathbf{z}'_\ell\tag{3}$$
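A minimal stand-alone sketch of how PreNorm composes with a residual connection (a plain nn.Linear stands in here for the Attention / FeedForward blocks of the real model):

```python
import torch
from torch import nn

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# Any dim-preserving module works as fn; nn.Linear(32, 32) is a toy stand-in
block = PreNorm(32, nn.Linear(32, 32))
x = torch.randn(2, 5, 32)
out = block(x) + x  # the "+ x" residual is exactly the pattern of Eqs. (2) and (3)
print(out.shape)    # torch.Size([2, 5, 32])
```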
FeedForward
```python
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
```

The FeedForward layer consists of two linear layers with a GELU activation and Dropout, corresponding to the blue MLP in the diagram. The parameters dim and hidden_dim are the input/output dimension and the hidden-layer dimension respectively, and dropout is the probability p of the dropout operation.
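As a quick sanity check (my own sketch, with example sizes dim=64 and hidden_dim=128), the MLP maps dim → hidden_dim → dim, so its output shape matches its input, which is what makes the later residual addition `ff(x) + x` valid:

```python
import torch
from torch import nn

# Same layer stack as FeedForward above, written inline
ff = nn.Sequential(
    nn.Linear(64, 128),
    nn.GELU(),
    nn.Dropout(0.),
    nn.Linear(128, 64),
    nn.Dropout(0.),
)
x = torch.randn(2, 5, 64)
out = ff(x)
print(out.shape)  # torch.Size([2, 5, 64]) -- same shape as the input
```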
Attention
```python
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # (b, n(65), dim*3) ---> 3 * (b, n, dim)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)  # q, k, v: (b, h, n, dim_head(64))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```

Attention is the core component of the Transformer, corresponding to the green Multi-Head Attention in the diagram. The parameter heads is the number of heads in the multi-head self-attention, and dim_head is the dimension of each head.
This layer implements the classic Transformer attention formula:
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
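A single-head numeric sketch of this formula (toy shapes of my own choosing): compute QKᵀ/√d_k, softmax each row into a distribution over the keys, then take the weighted sum of the values:

```python
import torch
import torch.nn.functional as F

d_k = 4
q = torch.randn(3, d_k)  # 3 query tokens
k = torch.randn(3, d_k)  # 3 key tokens
v = torch.randn(3, d_k)  # 3 value tokens

scores = q @ k.t() / d_k ** 0.5   # (3, 3) scaled pairwise similarities
attn = F.softmax(scores, dim=-1)  # each row is a distribution over the keys
out = attn @ v                    # (3, 4) weighted sum of the values

print(attn.sum(dim=-1))  # every row sums to 1
print(out.shape)         # torch.Size([3, 4])
```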
Transformer
```python
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
```

With these layers defined, we can assemble the full Transformer block, i.e. the entire right half of the diagram, the Transformer Encoder. Thanks to the groundwork above, the implementation of the whole block is remarkably concise.
The parameter depth is the number of times the Transformer block is repeated; the other parameters are the same as introduced for the individual layers above.
I have also annotated the diagram to show how its parts correspond to the code.
ViT
```python
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # nn.Parameter() defines a learnable parameter
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)  # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dim
        b, n, _ = x.shape  # b is the batch size, n the number of patches, _ the values per patch

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)  # (1, 1, dim) -> (batchSize, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # concatenate cls_token with the patch tokens: (b, 65, dim)
        x += self.pos_embedding[:, :(n + 1)]  # add the positional embedding (a direct elementwise add): (b, 65, dim)
        x = self.dropout(x)

        x = self.transformer(x)  # (b, 65, dim)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]  # (b, dim)

        x = self.to_latent(x)  # Identity: (b, dim)
        return self.mlp_head(x)  # (b, num_classes)
```

The comments in forward() give the shape of each intermediate tensor, for reference and comparison.
Everything before x is fed into the transformer corresponds to the preprocessing of Equation (1):
$$\mathbf{z}_0=[\mathbf{x}_{class};\,\mathbf{x}^1_p\mathbf{E};\,\mathbf{x}^2_p\mathbf{E};\,\dots;\,\mathbf{x}^N_p\mathbf{E}]+\mathbf{E}_{pos},\quad \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\ \mathbf{E}_{pos}\in\mathbb{R}^{(N+1)\times D}\tag{1}$$
The positional embedding and class token are defined via nn.Parameter(), which registers the given Tensor in the module's parameter list so that it is trained and updated along with the rest of the model. Readers unfamiliar with nn.Parameter() can refer to the blog post "torch.nn.Parameter() in PyTorch explained".
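A minimal illustration of this registration behavior (the Toy module and its sizes are hypothetical): tensors wrapped in nn.Parameter show up in named_parameters() and require gradients, while a plain tensor attribute would not.

```python
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Mirrors the ViT definitions, with made-up sizes (65 tokens, dim 128)
        self.pos_embedding = nn.Parameter(torch.randn(1, 65, 128))
        self.cls_token = nn.Parameter(torch.randn(1, 1, 128))

m = Toy()
names = sorted(n for n, _ in m.named_parameters())
print(names)  # ['cls_token', 'pos_embedding'] -- both registered
print(all(p.requires_grad for p in m.parameters()))  # True -- both trainable
```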
As we know, only the output of cls_token is sent to the MLP head for the final prediction (as shown in the red box in the figure above); the outputs of all the other image patches are discarded. This is implemented by the following line:
```python
x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]  # (b, dim)
```

If the pooling mode is set to 'mean', all tokens are average-pooled and the result is sent to the MLP head. The default, however, is self.pool='cls', meaning no average pooling is performed: following the ViT design, only the cls_token is used, i.e. x[:, 0] takes just the first token (the cls_token).
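The two pooling modes can be compared directly on a toy tensor (shapes of my own choosing) standing in for the transformer output:

```python
import torch

# Toy transformer output: batch of 2, 1 cls token + 64 patch tokens, dim 128
x = torch.randn(2, 65, 128)

cls_out = x[:, 0]         # 'cls' pooling: keep only the class token's output
mean_out = x.mean(dim=1)  # 'mean' pooling: average over all 65 tokens

print(cls_out.shape)   # torch.Size([2, 128])
print(mean_out.shape)  # torch.Size([2, 128])
```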
Finally, mlp_head produces the prediction scores for each class.
I have also drawn a simple diagram of the signal flow through the whole model, which can be read alongside the shape annotations in the code comments:
Meaning of the symbols in the figure: $H, W, C$ are the height, width, and number of channels of an input image; $h, w$ are the height and width of a patch, so the number of patches in the image is $\frac{H}{h}\times\frac{W}{w}$, denoted $N_p$; $D$ is the embedding dimension dim; and $N_c$ is the number of classes.
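A worked example of this patch-count arithmetic, using a 256×256 image cut into 32×32 patches (the same sizes as the common ViT configuration):

```python
# Patch count N_p = (H/h) * (W/w)
H, W = 256, 256  # image height and width
h, w = 32, 32    # patch height and width

N_p = (H // h) * (W // w)
print(N_p)  # 64 patches; with the cls token the sequence length is N_p + 1 = 65
```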
This completes the definition of the ViT model; to train it, simply instantiate a ViT in your training script. The following script verifies that the model runs correctly.
```python
model_vit = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

img = torch.randn(16, 3, 256, 256)

preds = model_vit(img)
print(preds.shape)  # (16, 1000)
```

Questions or objections are welcome in the comments.