PyTorch中nn.Module类中__call__方法介绍
? ? ? 在PyTorch源碼的torch/nn/modules/module.py文件中,有一條__call__語句和一條forward語句,如下:
__call__ : Callable[…, Any] = _call_impl
forward: Callable[…, Any] = _forward_unimplemented
? ? ? 在PyTorch中nn.Module類是所有神經網絡模塊的基類,你的網絡也應該繼承這個類,需要重載__init__和forward函數。以下是仿照PyTorch中Module和AlexNet類實現寫的假的實現的測試代碼:
from typing import Callable, Any, Listdef _forward_unimplemented(self, *input: Any) -> None:"Should be overridden by all subclasses"print("_forward_unimplemented")raise NotImplementedErrorclass Module:def __init__(self):print("Module.__init__")forward: Callable[..., Any] = _forward_unimplementeddef _call_impl(self, *input, **kwargs):print("Module._call_impl")result = self.forward(*input, **kwargs)return result__call__: Callable[..., Any] = _call_impldef cpu(self):print("Module.cpu")class AlexNet(Module):def __init__(self):print("AlexNet.__init__")super(AlexNet, self).__init__()def forward(self, x):print("AlexNet.forward")return xmodel = AlexNet()
x: List[int] = [1, 2, 3, 4]
print("result:", model(x))model.cpu()print("test finish")
? ? ? 執行model(x)語句時,會調用AlexNet的forward函數,是因為AlexNet的父類Module中的__call__函數:首先Module中有__call__方法,因此model(x)這條語句可以正常執行。Module中并沒有直接給出__call__的實現體,而是__call__后緊跟冒號,此冒號表示類型注解;后面的Callable和Any是typing模塊中的,Callable表示可調用類型,即等號右邊應該是一個可調用類型,此處指的是_call_impl;Any是一種特殊的類型,它與所有類型兼容;Callable[…, Any]表示_call_impl可接受任意數量的參數并返回Any。這里__call__實際指向了_call_impl函數,因此調用__call__實際是調用_call_impl。
? ? ? typing模塊的介紹參考:https://blog.csdn.net/fengbingchun/article/details/122288737
? ? ? _call_impl函數體內會調用forward,Module中的forward的實現方式與__call__相同,但是_forward_unimplemented函數并沒有實現體,調用它會觸發Error即NotImplementedError。因此在子類AlexNet中一定要給出forward的具體實現,否則調用的將是_forward_unimplemented。
? ? ? 測試代碼執行結果如下:
? ? ? 如果注釋掉AlexNet中的forward,則執行結果如下:
? ? ? GitHub:?https://github.com/fengbingchun/PyTorch_Test
總結
以上是生活随笔為你收集整理的PyTorch中nn.Module类中__call__方法介绍的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python3中__call__方法介绍
- 下一篇: windows上通过cmake-gui生