SwinTransformer Detailed Design Document
1. File Descriptions
Model.py: builds the model
My_dataset.py: dataset handling
Predict.py: predicts the class of an input image
Train.py: trains the network
Utils.py: utility functions
2. Project Structure and Function Design
Classes in Model.py
class DropPath(nn.Module): def forward(self, x)
class PatchEmbed(nn.Module): def forward(self, x)
class PatchMerging(nn.Module): def forward(self, x, H, W)
class Mlp(nn.Module): def forward(self, x)
class WindowAttention(nn.Module): def forward(self, x, mask: Optional[torch.Tensor] = None)
class SwinTransformerBlock(nn.Module): def forward(self, x, attn_mask)
class BasicLayer(nn.Module): def create_mask(self, x, H, W), def forward(self, x, H, W)
class SwinTransformer(nn.Module): def _init_weights(self, m), def forward(self, x)
Functions in Model.py
def drop_path_f(x, drop_prob: float = 0., training: bool = False)
def window_partition(x, window_size: int)
def window_reverse(windows, window_size: int, H: int, W: int)
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs)
def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs)
def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs)
def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs)
def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs)
def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs)
def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs)
def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs)
My_dataset.py (class only)
class MyDataSet(Dataset): def __len__(self), def __getitem__(self, item), @staticmethod def collate_fn(batch)
Predict.py (functions only)
def main()
if __name__ == '__main__': main()
Train.py (functions only)
def main(args)
if __name__ == '__main__': ... main(opt)
Utils.py (functions only)
def read_split_data(root: str, val_rate: float = 0.2)
def plot_data_loader_image(data_loader)
def write_pickle(list_info: list, file_name: str)
def read_pickle(file_name: str) -> list
def train_one_epoch(model, optimizer, data_loader, device, epoch)
@torch.no_grad() def evaluate(model, data_loader, device, epoch)
Swin-Transformer Paper Code Introduction
1. Development Environment
- Python 3.6
- torch 1.7.1
- GPU
2. Functional Design
Description of the experimental dataset:
Data source (a minimal download sketch follows the flower class list below):
http://download.tensorflow.org/example_images/flower_photos.tgz
Images of 5 flower classes are classified:
3670 images were found in the dataset.
2939 images for training.
731 images for validation.
Daisy
Dandelion
Roses
Sunflowers
Tulips
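As an aside, a minimal sketch of how the archive above could be downloaded and unpacked is given below; the file and folder names are illustrative only, and the project itself simply assumes the extracted flower_photos folder is available locally.

```python
import os
import tarfile
import urllib.request

DATA_URL = "http://download.tensorflow.org/example_images/flower_photos.tgz"
ARCHIVE = "flower_photos.tgz"

# Download the archive once, then extract it; extraction creates
# ./flower_photos/<class_name>/*.jpg with one sub-folder per flower class.
if not os.path.exists("flower_photos"):
    if not os.path.exists(ARCHIVE):
        urllib.request.urlretrieve(DATA_URL, ARCHIVE)
    with tarfile.open(ARCHIVE, "r:gz") as tar:
        tar.extractall(path=".")
```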
3. File Descriptions
Model.py: builds the model
My_dataset.py: dataset handling
Predict.py: predicts the class of an input image
Train.py: trains the network
Utils.py: utility functions
Classes in Model.py
DropPath: implements drop path (stochastic depth), randomly dropping a block's residual branch during training
PatchEmbed: splits the input image into patches and linearly embeds them (patch partition + linear embedding)
PatchMerging: concatenates each 2x2 group of neighbouring patches in the feature map and applies a linear projection to downsample it (a minimal sketch is given after this list)
Mlp: the feed-forward MLP applied after the attention module inside each SwinTransformerBlock
WindowAttention: computes self-attention within each window
SwinTransformerBlock: builds a single Swin Transformer block; depending on its shift_size it performs either W-MSA or SW-MSA, and the two alternate within a stage
BasicLayer: one stage of the network, stacking several SwinTransformerBlocks and building the attention mask for SW-MSA via create_mask
SwinTransformer: builds the full classification model by composing the classes above, from patch partition and linear embedding (the PatchEmbed class), through four stages of SwinTransformerBlocks with PatchMerging applied between stages, to the flattened feature vector that is fed to the classification head
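To make the PatchMerging step concrete, the following is a minimal sketch of its core logic, assuming the usual Swin layout where x has shape [B, H*W, C] and H, W are even (the project class additionally handles padding for odd sizes); it is an illustration, not the project file verbatim.

```python
import torch
import torch.nn as nn

class PatchMergingSketch(nn.Module):
    """Downsample a [B, H*W, C] feature map to [B, (H/2)*(W/2), 2C]."""
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H: int, W: int):
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        # gather the four pixels of every 2x2 neighbourhood
        x0 = x[:, 0::2, 0::2, :]   # top-left
        x1 = x[:, 1::2, 0::2, :]   # bottom-left
        x2 = x[:, 0::2, 1::2, :]   # top-right
        x3 = x[:, 1::2, 1::2, :]   # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)   # [B, H/2, W/2, 4C]
        x = x.view(B, -1, 4 * C)                  # flatten the spatial dims
        x = self.reduction(self.norm(x))          # linear projection 4C -> 2C
        return x
```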
Functions in Model.py
window_partition: partitions a feature map into non-overlapping windows (both window helpers are sketched after the list below)
window_reverse: merges the windows back into a feature map
Factory functions that define the different model variants and are used to instantiate the model:
swin_tiny_patch4_window7_224
swin_small_patch4_window7_224
swin_base_patch4_window7_224
swin_base_patch4_window12_384
swin_base_patch4_window7_224_in22k
swin_base_patch4_window12_384_in22k
swin_large_patch4_window7_224_in22k
swin_large_patch4_window12_384_in22k
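The two window helpers are small enough to sketch in full. The version below assumes [B, H, W, C] feature maps with H and W divisible by window_size and mirrors the behaviour described above; it is written here as an illustration rather than copied from Model.py.

```python
import torch

def window_partition(x, window_size: int):
    """Split a [B, H, W, C] map into [num_windows*B, window_size, window_size, C] windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, C)

def window_reverse(windows, window_size: int, H: int, W: int):
    """Inverse of window_partition: merge the windows back into a [B, H, W, C] map."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

# round-trip check on a dummy stage-1 feature map (56x56, 96 channels, 7x7 windows)
feat = torch.randn(2, 56, 56, 96)
win = window_partition(feat, 7)                      # [2*64, 7, 7, 96]
assert torch.equal(window_reverse(win, 7, 56, 56), feat)
```

The factory functions listed above simply instantiate SwinTransformer with the hyper-parameters of the corresponding paper configuration; for the 5-class flower task one would presumably call something like model = swin_tiny_patch4_window7_224(num_classes=5).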
My_dataset.py (class only)
MyDataSet(Dataset): implements the methods that return a single dataset item (__getitem__) and the dataset size (__len__)
collate_fn(batch), a @staticmethod: passed to the DataLoader to convert a batch of images into tensors and stack them together (see the sketch below)
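A condensed sketch of what such a dataset class typically looks like is shown below; it is simplified (the real MyDataSet may add error handling, e.g. rejecting non-RGB images), but the collate_fn logic is the stacking step described above.

```python
import torch
from PIL import Image
from torch.utils.data import Dataset

class MyDataSetSketch(Dataset):
    """Simplified stand-in for MyDataSet: stores image paths and class labels."""
    def __init__(self, images_path, images_class, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)        # expected to return a [C, H, W] tensor
        return img, self.images_class[item]

    @staticmethod
    def collate_fn(batch):
        # batch is a list of (image_tensor, label) pairs produced by __getitem__
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)  # [B, C, H, W]
        labels = torch.as_tensor(labels)     # [B]
        return images, labels
```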
Predict.py (functions only)
main(): predicts the class of a single image and displays the image together with the predicted probability for each class (a prediction sketch follows below)
The if __name__ == '__main__': guard calls main() to start the prediction.
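A minimal sketch of what main() in Predict.py does is given below; the paths, class names, module import, and preprocessing are assumptions for illustration (the real script may, for example, read the class names from a JSON file).

```python
import torch
from PIL import Image
from torchvision import transforms

from model import swin_tiny_patch4_window7_224   # factory defined in Model.py (module name assumed)

weights_path = "./weights/model-best.pth"        # hypothetical checkpoint path
img_path = "./tulip.jpg"                         # hypothetical test image
class_names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# preprocess the image and add the batch dimension: [1, 3, 224, 224]
img = data_transform(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)

model = swin_tiny_patch4_window7_224(num_classes=5).to(device)
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()

with torch.no_grad():
    probs = torch.softmax(model(img).squeeze(0), dim=0)

for name, p in zip(class_names, probs):
    print(f"{name:12s}: {p.item():.3f}")
```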
Train.py (functions only)
main(args): loads the training and validation sets, preprocesses the images (resizing them for the model input), instantiates the model, trains it, and saves the weights.
The if __name__ == '__main__': guard defines and parses the command-line arguments, then calls main(opt) to train the classification model (see the sketch below).
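A rough sketch of this entry point is shown below; the argument names and default values are assumptions for illustration, not necessarily the exact flags used by the project's Train.py.

```python
import argparse

def main(args):
    # placeholder for the real training routine in Train.py:
    # read_split_data -> build datasets/loaders -> build SwinTransformer
    # -> train_one_epoch / evaluate loop -> save the best weights
    print(args)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_classes", type=int, default=5)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--data-path", type=str, default="./flower_photos")
    parser.add_argument("--weights", type=str, default="",
                        help="path to pretrained weights; '' trains from scratch")
    parser.add_argument("--device", default="cuda:0", help="device id, e.g. cuda:0 or cpu")

    opt = parser.parse_args()
    main(opt)
```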
Utils.py (functions only)
read_split_data: reads the image paths and their class labels, and splits them into a training set and a validation set
train_one_epoch: defines the loss function torch.nn.CrossEntropyLoss(), runs one epoch of training, and returns the loss and accuracy (a sketch follows this list)
evaluate: decorated with @torch.no_grad(); runs one pass over the validation loader and returns the validation loss and accuracy
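A minimal sketch of train_one_epoch using the loss function named above is given below; the accumulator names are illustrative, and the real Utils.py may additionally print per-step progress. evaluate follows the same pattern under @torch.no_grad(), without the backward pass and optimizer step.

```python
import torch

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training epoch; return (mean loss per step, accuracy)."""
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1, device=device)   # accumulated loss over all steps
    accu_num = torch.zeros(1, device=device)    # number of correctly classified samples
    sample_num = 0

    for step, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        sample_num += images.shape[0]

        pred = model(images)                    # [B, num_classes] logits
        loss = loss_function(pred, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        accu_loss += loss.detach()
        accu_num += (pred.argmax(dim=1) == labels).sum()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
```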
4. Workflow
Run train.py to train the model; after training for a number of epochs, the highest accuracy reaches 96.6%.
5. Demo
Run predict.py to predict the class of a single image.