【PyTorch】 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)
生活随笔
收集整理的這篇文章主要介紹了
【PyTorch】 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
【PyTorch】 99%程序員都不知道, 深度學習還能這樣玩
- 概述
- 遷移學習
- 入住 GitHub
- 項目詳解
- get_data.py (獲取數據)
- get_model (獲取模型)
- 參數詳解
- 使用說明
- 訓練 MNIST
- 訓練 Fashion MNIST
- 訓練 CIFAR 10
- 訓練 CIFAR 100
- 訓練自己的數據
概述
你還在為訓練無從下手而苦惱么?
你還在為模型訓練時間漫長而痛苦么?
你還在為模型準確率提升困難在深夜一個人啜泣么?
今天教大家一個方法, 使得我們的模型起跑線上直接甩開別人幾條街. 隔壁王叔叔都學會了!
遷移學習
遷移學習 (Transfer Learning) 是把已學訓練好的模型參數用作新訓練模型的起始參數.
入住 GitHub
經過幾天的日夜狂肝, 本人完成了在 GitHub 上的第一個項目. 把遷移學習封裝成了一個有手就能用的黑盒模型.
大家只要替換自己的數據集就可以實現多個可選模型遷移學習并自動保存. 就是兩個字簡單
項目詳解
GitHub 鏈接
get_data.py (獲取數據)
目前支持 MNIST, Fashion MNIST, CIFAR 10 和 CIFAR 100 數據集.
可以在```get_data.py``下自行替換成自己需要的數據集:
傳入數據的格式為:
data_loader = {"train": train_loader, "valid": test_loader}get_model (獲取模型)
目前支持:
- resnet18
- resnet34
- resnet50
- resnet101
- resnet152
- alexnet
- squeezenet
- vgg11
- vgg13
- vgg16
- vgg19
替換模型的方法:
python main.py --model_name "模型名稱"例如, 使用 vgg 13:
python main.py --model_name vgg13例如, 使用 resnet 152:
python main.py --model_name resnet152參數詳解
必填參數:
- model_name: 模型名稱, 類型為 string
- num_classes: 輸出類別數, 類型為 int (例如 MNIST 是 10 分類, CIFAR 100 是 100 分類)
重要參數:
- data_name: 數據名稱, 類型為 string, 默認為 CIFAR10
- data_gray: 是否為灰度圖, 類型為 boolean, 默認為 False
- num_epochs: 迭代次數, 類型為 int, 默認為 20
- batch_size: 一個批次的樣本數目, 默認為 512
可選參數 (不建議修改):
- feature_exact: 是否凍層, 類型為 boolean, 默認為 False
- use_pretrained: 是否使用預訓練權重, 類型為 boolean, 默認為 True
- pretrained_model_path: 預訓練權重, 類型為 string, 默認為 pretrained_model/
- model_save_path: 模型保存路徑, 類型為 string, 默認為 “checkpoint/”
- visualize: 模型可視化, 類型為 boolean, 默認為 True
使用說明
首先我們需要cd到文件路徑, 例如:
cd C:\Users\Windows\Desktop\Project\transfer_learning-main訓練 MNIST
使用 resnet18 訓練 MNIST 數據集:
python main.py --data_name MNIST --data_gray True --model_name resnet18 --num_classes 10 --batch_size 512訓練 Fashion MNIST
使用 resnet34 訓練 Fashion MNIST 數據集:
python main.py --data_name FashionMNIST --data_gray True --model_name resnet34 --num_classes 10 --batch_size 512訓練 CIFAR 10
使用 resnet50 訓練 CIFAR 10 數據集:
python main.py --data_name CIFAR10 --model_name resnet50 --num_classes 10 --batch_size 512訓練 CIFAR 100
使用 resnet152 訓練 CIFAR 10 數據集:
python main.py --data_name CIFAR100 --model_name resnet152 --num_classes 100 --batch_size 512訓練自己的數據
python main.py --data_name other --model_name ? --num_classes ? --batch_size ? --epochs ?總結
以上是生活随笔為你收集整理的【PyTorch】 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)的全部內容,希望文章能夠幫你解決所遇到的問題。