Model Parameter Count (Params) / Model Size: Counting Model Parameters in PyTorch
The model's parameter count / size can be read off directly from the size of the saved checkpoint file.
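As a quick check, here is a minimal sketch of that idea; the toy `nn.Linear` model and the file name `checkpoint.pth` are only placeholders for illustration:

```python
import os
import torch
import torch.nn as nn

# Toy model just for illustration; substitute your own model.
model = nn.Linear(1024, 1024)

# Save only the weights (state_dict) and look at the file size on disk.
torch.save(model.state_dict(), 'checkpoint.pth')
size_bytes = os.path.getsize('checkpoint.pth')
print(f'checkpoint size: {size_bytes / (1024 * 1024):.2f} MB')
# For FP32 weights this is roughly (number of parameters) * 4 bytes,
# plus a small serialization overhead.
```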
The parameter count can also be computed directly in code:

```python
total_params = sum(p.numel() for p in model.parameters())
total_params += sum(p.numel() for p in model.buffers())
print(f'{total_params:,} total parameters.')
print(f'{total_params/(1024*1024):.2f}M total parameters.')

total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
print(f'{total_trainable_params/(1024*1024):.2f}M training parameters.')
```

Some sources multiply the parameter count by 4 when reporting model size, because parameters are usually stored in FP32; FP32 is single precision and takes 4 bytes per value. Check which definition is actually meant: a parameter count, or a size in bytes.
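To make that bytes-per-parameter conversion explicit, here is a small sketch (assuming `model` is already defined as above) that estimates the size in MB from each tensor's element size, which is 4 bytes for FP32 and 2 bytes for FP16:

```python
# Estimate the in-memory size of parameters and buffers in megabytes.
# p.element_size() returns the size of one element in bytes (4 for FP32, 2 for FP16).
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
size_mb = (param_bytes + buffer_bytes) / (1024 * 1024)
print(f'estimated model size: {size_mb:.2f} MB')
```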
To count the parameters of each sub-module separately, group them by the first component of the parameter name. Also consider whether buffers (e.g. BatchNorm running statistics) should be counted as well; see the sketch after the next code block.
```python
_dict = {}
for _, param in enumerate(model.named_parameters()):
    # param[0] is the parameter name, param[1] is the tensor.
    total_params = param[1].numel()
    # Group by the first component of the name, i.e. the top-level sub-module.
    k = param[0].split('.')[0]
    if k in _dict.keys():
        _dict[k] += total_params
    else:
        _dict[k] = total_params

for k, v in _dict.items():
    print(k)
    print(v)
    print("%3.3fM parameters" % (v / (1024 * 1024)))
    print('--------')
```
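If buffers should be included as well, a similar sketch (assuming the same `model`) can group `named_buffers()` the same way:

```python
from collections import defaultdict

# Count buffer elements (e.g. BatchNorm running_mean / running_var) per top-level sub-module.
buffer_counts = defaultdict(int)
for name, buf in model.named_buffers():
    buffer_counts[name.split('.')[0]] += buf.numel()

for k, v in buffer_counts.items():
    print(f'{k}: {v:,} buffer elements')
```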
Hung-yi Lee's deep-learning homework writes it like this:

```python
def count_parameters(model, only_trainable=False):
    if only_trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in model.parameters())
```
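For example, a quick usage sketch (the two-layer toy model and the frozen first layer are only for illustration):

```python
import torch.nn as nn

# Toy model: freeze the first layer to show the difference between the two counts.
toy = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5))
for p in toy[0].parameters():
    p.requires_grad = False

print(count_parameters(toy))                       # all parameters: 10*20+20 + 20*5+5 = 325
print(count_parameters(toy, only_trainable=True))  # only the unfrozen layer: 20*5+5 = 105
```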
Another approach, which also prints a per-network summary; when you use it, replace the dict keys (`encoder`, `decoder`, `stn_head` below) with the top-level module names of your own model:
```python
def print_architecture(model):
    arch_name = type(model).__name__
    result = '-------------------%s---------------------\n' % arch_name
    total_num_params = 0
    for i, (name, child) in enumerate(model.named_children()):
        # Parameters of each top-level child module.
        num_params = sum([p.numel() for p in child.parameters()])
        total_num_params += num_params
        for j, (sub_name, grandchild) in enumerate(child.named_children()):
            # Parameters of each second-level module (computed here but not printed).
            sub_num_params = sum([p.numel() for p in grandchild.parameters()])
        result += '[Network %s] Total number of parameters : %.3f M\n' % (name, num_params / (1024 * 1024))
    result += 'Total number of parameters : %.3f M\n' % (total_num_params / (1024 * 1024))
    result += '-----------------------------------------------\n'
    print(result)
    print(model)


total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')

_dict = {}
_dict['encoder'] = 0
_dict['decoder'] = 0
_dict['stn_head'] = 0
for _, param in enumerate(model.named_parameters()):
    print(param[0])
    total_params = param[1].numel()
    print(f'{total_params:,} total parameters.')
    k = param[0].split('.')[0]
    if k in _dict.keys():
        _dict[k] += total_params
    else:
        _dict[k] = total_params
    print('----------------')

for k, v in _dict.items():
    print(k)
    print(v)
    print("%3.3fM parameters\n" % (v / (1024 * 1024)))
    print('--------')

print_architecture(model)
```

Parameter counts of common models (the figure from the original post is not preserved here; the sketch below shows how to reproduce such numbers yourself).
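A minimal sketch, assuming `torchvision` is installed and using the `count_parameters` helper defined above, for checking a few common architectures:

```python
from torchvision import models

# Instantiate a few common architectures (randomly initialized, no pretrained weights needed)
# and print their parameter counts; e.g. resnet18 has about 11.7 million parameters
# and resnet50 about 25.6 million.
for name, builder in [('resnet18', models.resnet18),
                      ('resnet50', models.resnet50),
                      ('alexnet', models.alexnet)]:
    net = builder()
    print(f'{name}: {count_parameters(net):,} parameters')
```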