timm 视觉库中的 create_model 函数详解
timm 視覺庫中的 create_model 函數詳解
最近一年 Vision Transformer 及其相關改進的工作層出不窮,在他們開源的代碼中,大部分都用到了這樣一個庫:timm。各位煉丹師應該已經想必已經對其無比熟悉了,本文將介紹其中最關鍵的函數之一:create_model 函數。
timm簡介
PyTorchImageModels,簡稱timm,是一個巨大的PyTorch代碼集合,包括了一系列:
- image models
- layers
- utilities
- optimizers
- schedulers
- data-loaders / augmentations
- training / validation scripts
旨在將各種 SOTA 模型、圖像實用工具、常用的優化器、訓練策略等視覺相關常用函數的整合在一起,并具有復現ImageNet訓練結果的能力。
源碼:https://github.com/rwightman/pytorch-image-models
文檔:https://fastai.github.io/timmdocs/
create_model 函數的使用及常用參數
本小節先介紹 create_model 函數,及常用的參數 **kwargs。
顧名思義,create_model 函數是用來創建一個網絡模型(如 ResNet、ViT 等),timm 庫本身可供直接調用的模型已有接近400個,用戶也可以自己實現一些模型并注冊進 timm (這一部分內容將在下一小節著重介紹),供自己調用。
model_name
我們首先來看最簡單地用法:直接傳入模型名稱 model_name
import timm # 創建 resnet-34 model = timm.create_model('resnet34') # 創建 efficientnet-b0 model = timm.create_model('efficientnet_b0')我們可以通過 list_models 函數來查看已經可以直接創建、有預訓練參數的模型列表:
all_pretrained_models_available = timm.list_models(pretrained=True) print(all_pretrained_models_available) print(len(all_pretrained_models_available))輸出:
[..., 'vit_large_patch16_384', 'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_small_patch16_224', 'wide_resnet50_2', 'wide_resnet101_2', 'xception', 'xception41', 'xception65', 'xception71'] 452如果沒有設置 pretrained=True 的話有將會輸出612,即有預訓練權重參數的模型有452個,沒有預訓練參數,只有模型結構的共有612個。
pretrained
如果我們傳入 pretrained=True,那么 timm 會從對應的 URL 下載模型權重參數并載入模型,只有當第一次(即本地還沒有對應模型參數時)會去下載,之后會直接從本地加載模型權重參數。
model = timm.create_model('resnet34', pretrained=True)輸出:
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/song/.cache/torch/hub/checkpoints/resnet34-43635321.pthfeatures_only、out_indices
create_mode 函數還支持 features_only=True 參數,此時函數將返回部分網絡,該網絡提取每一步最深一層的特征圖。還可以使用 out_indices=[…] 參數指定層的索引,以提取中間層特征。
# 創建一個 (1, 3, 224, 224) 形狀的張量 x = torch.randn(1, 3, 224, 224) model = timm.create_model('resnet34') preds = model(x) print('preds shape: {}'.format(preds.shape))all_feature_extractor = timm.create_model('resnet34', features_only=True) all_features = all_feature_extractor(x) print('All {} Features: '.format(len(all_features))) for i in range(len(all_features)):print('feature {} shape: {}'.format(i, all_features[i].shape))out_indices = [2, 3, 4] selected_feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=out_indices) selected_features = selected_feature_extractor(x) print('Selected Features: ') for i in range(len(out_indices)):print('feature {} shape: {}'.format(out_indices[i], selected_features[i].shape))我們以一個 (1, 3, 224, 224) 形狀的張量為輸入,在視覺任務中,圖像輸入張量總是類似的形狀。上面例程展示了,創建完整模型 model,創建完整特征提取器 all_feature_extractor,和創建某幾層特征提取器 selected_feature_extractor 的具體輸出。
可以結合下面 ResNet34 的結構圖來理解(圖中不同的顏色表示不同的 layer),根據下圖分析各層的卷積操作,計算各層最后一個卷積的輸入,并與上面例程的輸出(附在圖后)驗證是否一致。
輸出:
preds shape: torch.Size([1, 1000]) All 5 Features: feature 0 shape: torch.Size([1, 64, 112, 112]) feature 1 shape: torch.Size([1, 64, 56, 56]) feature 2 shape: torch.Size([1, 128, 28, 28]) feature 3 shape: torch.Size([1, 256, 14, 14]) feature 4 shape: torch.Size([1, 512, 7, 7]) Selected Features: feature 2 shape: torch.Size([1, 128, 28, 28]) feature 3 shape: torch.Size([1, 256, 14, 14]) feature 4 shape: torch.Size([1, 512, 7, 7])這樣,我們就可以通過 timm_model 函數及其 features_only 、out_indices 參數將預訓練模型方便地轉換為自己想要的特征提取器。
接下來我們來看一下這些特征提取器究竟是什么類型:
import timm feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[3])print('type:', type(feature_extractor)) print('len: ', len(feature_extractor)) for item in feature_extractor:print(item)輸出:
type: <class 'timm.models.features.FeatureListNet'> len: 7 conv1 bn1 act1 maxpool layer1 layer2 layer3可以看到,feature_extractor 其實也是一個神經網絡,在 timm 中稱為 FeatureListNet,而我們通過 out_indices 參數來指定截取到哪一層特征。
需要注意的是,ViT 模型并不支持 features_only 選項(0.4.12版本)。
extractor = timm.create_model('vit_base_patch16_224', features_only=True)輸出:
RuntimeError: features_only not implemented for Vision Transformer models.create_model 函數究竟做了什么
registry
在了解了 create_model 函數的基本使用之后,我們來深入探索一下 create_model 函數的源碼,看一下究竟是怎樣實現從模型到特征提取器的轉換的。
create_model 主體只有 50 行左右的代碼,因此所有這些神奇的事情是在其他地方完成的。我們知道 timm.list_models() 函數中的每一個模型名字(str)實際上都是一個函數。以下代碼可以測試這一點:
import timm import random from timm.models import registrym = timm.list_models()[-1] print(m) registry.is_model(m)輸出:
xception71 True實際上,在 timm 內部,有一個字典稱為 _model_entrypoints 包含了所有的模型名稱和他們各自的函數。比如說,我們可以通過 model_entrypoint 函數從 _model_entrypoints 內部得到 xception71 模型的構造函數。
constuctor_fn = registry.model_entrypoint(m) print(constuctor_fn)輸出:
<function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>也有可能輸出:
<function xception71 at 0x7fc0cba0eca0>一樣的。
如我們所見,在 timm.models.xception_aligned 模塊中有一個函數稱為 xception71 。類似的,timm 中的每一個模型都有著一個這樣的構造函數。事實上,內部的 _model_entrypoints 字典大概長這個樣子:
_model_entrypoints > > { 'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>, 'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>, 'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>, 'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>, 'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>, 'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>, 'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>, 'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>, 'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>, 'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>, 'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>, 'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>, 'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>, 'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,}所以說,在 timm 對應的模塊中,每個模型都有一個構造器。比如說 ResNets 系列模型被定義在 timm.models.resnet 模塊中。因此,實際上我們有兩種方式來創建一個 resnet34 模型:
import timm from timm.models.resnet import resnet34# 使用 create_model m = timm.create_model('resnet34')# 直接調用構造函數 m = resnet34()但使用上,我們無須調用構造函數。所用模型都可以通過 create_model 函數來將創建。
Register model
resnet34 構造函數的源碼如下:
@register_model def resnet34(pretrained=False, **kwargs):"""Constructs a ResNet-34 model."""model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)return _create_resnet('resnet34', pretrained, **model_args)我們會發現 timm 中的每個模型都有一個 register_model 裝飾器。最開始, _model_entrypoints 是一個空字典。我們是通過 register_model 裝飾器來不斷地像其中添加模型名稱和它對應的構造函數。該裝飾器的定義如下:
def register_model(fn):# lookup containing modulemod = sys.modules[fn.__module__]module_name_split = fn.__module__.split('.')module_name = module_name_split[-1] if len(module_name_split) else ''# add model to __all__ in modulemodel_name = fn.__name__if hasattr(mod, '__all__'):mod.__all__.append(model_name)else:mod.__all__ = [model_name]# add entries to registry dict/sets_model_entrypoints[model_name] = fn_model_to_module[model_name] = module_name_module_to_models[module_name].add(model_name)has_pretrained = False # check if model has a pretrained url to allow filtering on thisif hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:# this will catch all models that have entrypoint matching cfg key, but miss any aliasing# entrypoints or non-matching comboshas_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']if has_pretrained:_model_has_pretrained.add(model_name)return fn我們可以看到, register_model 函數完成了一些比較基礎的步驟,但這里需要指出的是這一句:
_model_entrypoints[model_name] = fn它將給定的 fn 添加到 _model_entrypoints 其鍵名為 fn.__name__。所以說 resnet34 函數上的裝飾器 @register_model 在 _model_entrypoints 中創建一個新的條目,像這樣:
{’resnet34’: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}我們同樣可以看到在 resnet34 構造函數的源碼中,在設置完一些 model_args 之后,它會隨后調用 _create_resnet 函數。讓我們再來看一下該函數的源碼:
def _create_resnet(variant, pretrained=False, **kwargs):return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)所以在 _create_resnet 函數之中,會再調用 build_model_with_cfg 函數并將一個構造器類 ResNet 、變量名 resnet34、一個 default_cfg 和一些 **kwargs 傳入其中。
default config
timm 中所有的模型都有一個默認的配置,包括指向它的預訓練權重參數的URL、類別數、輸入圖像尺寸、池化尺寸等。
resnet34 的默認配置如下:
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1', 'classifier': 'fc'}此默認配置與其他參數(如構造函數類和一些模型參數)一起傳遞給 build_model_with_cfg 函數。
build model with config
這個 build_model_with_cfg 函數負責:
看一下該函數的源碼:
def build_model_with_cfg(model_cls: Callable,variant: str,pretrained: bool,default_cfg: dict,model_cfg: dict = None,feature_cfg: dict = None,pretrained_strict: bool = True,pretrained_filter_fn: Callable = None,pretrained_custom_load: bool = False,**kwargs):pruned = kwargs.pop('pruned', False)features = Falsefeature_cfg = feature_cfg or {}if kwargs.pop('features_only', False):features = Truefeature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))if 'out_indices' in kwargs:feature_cfg['out_indices'] = kwargs.pop('out_indices')model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)model.default_cfg = deepcopy(default_cfg)if pruned:model = adapt_model_from_file(model, variant)# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for featsnum_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))if pretrained:if pretrained_custom_load:load_custom_pretrained(model)else:load_pretrained(model,num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),filter_fn=pretrained_filter_fn, strict=pretrained_strict)if features:feature_cls = FeatureListNetif 'feature_cls' in feature_cfg:feature_cls = feature_cfg.pop('feature_cls')if isinstance(feature_cls, str):feature_cls = feature_cls.lower()if 'hook' in feature_cls:feature_cls = FeatureHookNetelse:assert False, f'Unknown feature class {feature_cls}'model = feature_cls(model, **feature_cfg)model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfgreturn model我們可以看到,模型在這一步被創建出來:model = model_cls(**kwargs)。本文將不再深入到 pruned 和 adapt_model_from_file 內部查看。
總結
通過本文,我們已經完全了解了 create_model 函數,我們了解到:
- 每個模型有不同的構造函數,可以傳入不同的參數, _model_entrypoints 字典包括了所有的模型名稱及其對應的構造函數
- build_with_model_cfg 函數接收模型構造器類和其中的一些具體參數,真正地實例化一個模型
- load_pretrained 會加載預訓練參數
- FeatureListNet 類可以將模型轉換為特征提取器
Ref:
https://github.com/rwightman/pytorch-image-models
https://fastai.github.io/timmdocs/
https://fastai.github.io/timmdocs/create_model#Turn-any-model-into-a-feature-extractor
https://fastai.github.io/timmdocs/tutorial_feature_extractor
https://zhuanlan.zhihu.com/p/404107277
總結
以上是生活随笔為你收集整理的timm 视觉库中的 create_model 函数详解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 怎么引导大白菜启动不了系统安装 大白菜启
- 下一篇: 做系统就蓝屏怎么解决方法 系统蓝屏,怎么