vit transformer中的cls_token
1.源碼
# timm.model.vision_transformer def forward_head(self, x, pre_logits: bool = False):'''# self.global_pool == 'avg'則取所有token的均值作為一個類別的表征# self.global_pool == 'token'則取第一個cls_token作為一個類別的表征'''if self.global_pool: # [bs,token,dim] -> [bs,dim] 經過gapx = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) # bs, dim=768 -> bs, class_numreturn x if pre_logits else self.head(x)2.說明
假設我們將原始圖像切分成共9個小圖像塊,最終的輸入序列長度卻是10,也就是說我們這里人為的增加了一個向量進行輸入,我們通常將人為增加的這個向量稱為 Class Token。
我們可以想象,如果沒有這個向量,也就是將9個向量(1~9)輸入 Transformer 結構中進行編碼,我們最終會得到9個編碼向量,可對于圖像分類任務而言,我們應該選擇哪個輸出向量進行后續分類呢?
方案一,即vit的方案:ViT算法提出了一個可學習的嵌入向量 Class Token( 向量0),將它與9個向量一起輸入到 Transformer 結構中,輸出10個編碼向量,然后用這個 Class Token 進行分類預測。即,基于添加的cls_token執行類別預測,位置在所有token的第一個位置token[0],見編碼中的x[:,0]
方案二,取除了cls_token之外的所有token的均值作為類別特征表示,即編碼中的x[:, self.num_tokens:].mean(dim=1)
?3.思考
根據自注意機制,每個patch token一定程度上聚合了全局信息,但是主要是自身特征。ViT論文還使用了所有token取平均的方式,這意味每個patch對預測的貢獻相同,似乎不太合理?。實際上,這樣做的效果基本和引入cls_token差不多。
參考:
?vit 中的 cls_token 與 position_embed 理解_mingqian_chu的博客-CSDN博客_cls token
ViT為何引入cls_token_gltangwq的博客-CSDN博客_cls token
總結
以上是生活随笔為你收集整理的vit transformer中的cls_token的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 用友系统检查iis服务器不符,安装用友U
- 下一篇: PCB多层板设计总结