【图像分类】如何使用 mmclassification 训练自己的分类模型
生活随笔
收集整理的這篇文章主要介紹了
【图像分类】如何使用 mmclassification 训练自己的分类模型
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
文章目錄
- 一、數據準備
- 二、模型修改
- 三、模型訓練
- 四、模型效果可視化
- 五、如何分別計算每個類別的精確率和召回率
MMclassification 是一個分類工具庫,這篇文章是簡單記錄一下如何用該工具庫來訓練自己的分類模型,包括數據準備,模型修改,模型訓練,模型測試等等。
MMclassification鏈接:https://github.com/open-mmlab/mmclassification
安裝:https://mmclassification.readthedocs.io/en/latest/install.html
訓練:https://mmclassification.readthedocs.io/en/latest/getting_started.html
一、數據準備
MMclassification 支持 ImageNet 和 cifar 兩種數據格式,我們以 ImageNet 為例來看看數據結構:
|- imagenet | |- classmap.txt | |- train | | |- cls1 | | |- cls2 | | |- cls3 | | |- ... | |- train.txt | |- val | | |- images | |- val.txt假設我們要訓練一個貓狗二分類模型,則需要組織的形式如下:
|- dog_cat_dataset | |- classmap.txt | |- train | | |- dog | | |- cat | |- train.txt | |- val | | |- images | |- val.txt其中,classmap.txt 中的內容如下:
dog 0 cat 1二、模型修改
假設使用 resnet18 來訓練,則我們需要修改的內容主要集中在 config 文件里邊,修改后的config文件 resnet18_b32x8_dog_cat_cls.py 如下:
- 修改類別:將 1000 類改為 2 類
- 修改數據路徑:data
- 如果數據前處理需要修改的話,也可以在config里邊修改
- 因為config是最高級的,所以在這里修改后會覆蓋模型從mmcls庫中讀出來的參數
三、模型訓練
python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py四、模型效果可視化
python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls使用 gradcam 可視化:
python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_door.py ./models/epoch_99.pth --s ave-path visual_path/4.jpg五、如何分別計算每個類別的精確率和召回率
先進行測試,得到 result.pkl 文件,然后運行下面的程序即可:
python tools/cal_precision.py configs/resnet/resnet18_b32x8_imagenet.py import mmcv import argparse from mmcls.datasets import build_dataset from mmcls.core.evaluation import calculate_confusion_matrix from sklearn.metrics import confusion_matrixdef parse_args():parser = argparse.ArgumentParser(description='calculate precision and recall for each class')parser.add_argument('config', help='test config file path')args = parser.parse_args()return argsdef main():args = parse_args()cfg = mmcv.Config.fromfile(args.config)dataset = build_dataset(cfg.data.test)pred = mmcv.load("./result.pkl")['pred_label']matrix = confusion_matrix(pred, dataset.get_gt_labels())print('confusion_matrix:', matrix)cat_recall = matrix[0,0]/(matrix[0,0]+matrix[1,0])dog_recall = matrix[1,1]/(matrix[0,1]+matrix[1,1])cat_precision = matrix[0,0]/sum(matrix[0])dog_precision = matrix[1,1]/sum(matrix[1])print(' cat_precision:{} \n dog_precison:{} \n cat_recall:{} \n dog_recall:{}'.format(cat_precision, dog_precison, cat_recall, dog_recall))if __name__ == '__main__':main()總結
以上是生活随笔為你收集整理的【图像分类】如何使用 mmclassification 训练自己的分类模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【Transformer】AdaViT:
- 下一篇: 【Attention】Visual At