python pytorch fft_看PyTorch源代码的心路历程
1. 起因
曾經碰到過別人的模型prelu在內部的推理引擎算出的結果與其在原始框架PyTorch中不一致的情況,雖然理論上大家實現的都是一個算法,但是從參數上看,因為經過了模型轉換,中間做了一些調整。為了確定究竟是初始參數傳遞就出了問題還是在后續傳遞過程中繼續做了更改、亦或者是最終算法實現方面有著細微差別導致最終輸出不同,就想著去看一看PyTorch一路下來是怎么做的。
但是代碼跟著跟著就跟丟了,才會發現,PyTorch真的是一個很復雜的項目,但就像舌尖里面說的,環境越是惡劣,回報越是豐厚。為了以后再想跟蹤的時候方便,因此決定以PReLU為例靜態梳理一下PyTorch的代碼結構。搗鼓的這些天,對如何構建一個帶有C/C++代碼的Python又有了新的了解,這也算是意外的收獲吧。
2. 歷程
首先,我們從PReLU的導入路徑torch.nn.PReLU中知道,他應在徑進torch\nn\之下,進入該路徑雖然沒看到,但是我們在該路徑下的__init__.py中知道,其實它就在torch\nn\modules\activation.py中。類PReLU最終調用了從torch\nn\functional.py導入的prelu方法。順騰摸瓜,找到prelu,它長下面這樣:
def prelu(input, weight):
# type: (Tensor, Tensor) -> Tensor
if not torch.jit.is_scripting():
if type(input) is not Tensor and has_torch_function((input,)):
return handle_torch_function(prelu, (input,), input, weight)
return torch.prelu(input, weight)
經過人腦對代碼的一番執行你會發現,第一個if條件滿足,而第二個if不滿足。因此,最終想看算法,得去看torch.prelu()。好吧,接著干……
一番搜尋之后你會發現,Python代碼中在torch這個包下面你是找不到prelu的定義的。但是絕望之際我們在torch包的__init__.py之中看到看下面幾行代碼:
# pytorch\torch\__init__.py
# 為了簡潔,省去不必要代碼,詳細代碼參見pytorch\torch\__init__.py
try:
# _initExtension is chosen (arbitrarily) as a sentinel.
from torch._C import _initExtension
__all__ += [name for name in dir(_C)
if name[0] != '_' and
not name.endswith('Base')]
if TYPE_CHECKING:
# Some type signatures pulled in from _VariableFunctions here clash with
# signatures already imported. For now these clashes are ignored; see
# PR #43339 for details.
from torch._C._VariableFunctions import * # type: ignore
for name in dir(_C._VariableFunctions):
if name.startswith('__'):
continue
globals()[name] = getattr(_C._VariableFunctions, name)
__all__.append(name)
這是全村最后的希望了。我們知道__all__中的名字其實就是該模塊有意暴露出去的API。
什么意思呢?也就是說雖然我們明文上已經看不到了prelu的定義,但是這幾行代碼表明有一大堆身份不明的API被暗搓搓的導入了,這其中就很有可能存在我們朝思暮想的prelu。
那么我們怎么憑借這么一點微弱的線索確定我們的猜測到底對不對呢?這里我們就用到了Python的一個關鍵知識:C/C++擴展。(戳這里《使用C語言編寫Python模塊-引子》《Python調用C++之PYBIND11簡介》了解更多)
我們知道Python C/C++擴展有著固定的格式,只要我們找到模塊初始化入口,就能順藤摸瓜找到該模塊暴露的給Python解釋器所有函數。Python 3中的初始化函數樣子為PyInit_,其中就是模塊的名字。例如在前面提到的from torch._C import *中,模塊torch下面必要有一個名字為_C的子模塊。因此它的初始化函數應該為PyInit__C,我們搜索該名字就能找到模塊入口。當然另外還有一種方法,就是查看setup.py文件中關于擴展的描述信息:
// pytorch\setup.py
main_sources = ["torch/csrc/stub.c"]
C = Extension("torch._C",
libraries=main_libraries,
sources=main_sources,
language='c',
extra_compile_args=main_compile_args + extra_compile_args,
include_dirs=[],
library_dirs=library_dirs,
extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
extensions.append(C)
不管是通過搜索還是查看setup.py,我們最終都成功定位到了位于pytorch\torch\csrc\stub.c下的模塊初始化函數PyInit__C(void),并進一步跟蹤其調用的函數initModule(),便可以知道具體都暴露了哪些API給Python解釋器。
// pytorch\torch\csrc\stub.c
PyMODINIT_FUNC PyInit__C(void)
{
return initModule();
}
// pytorch\torch\csrc\Module.cpp
initModule()
進入initModule()尋找一番,你會發現,模塊_C中依然沒有prelu的Python接口。怎么辦?莫慌,通過前面對torch.__init__.py的分析,我們知道我們還有希望——_C模塊下的子模塊_VariableFunctions,這真的是最后的希望了!沒了別的路可以走了,只能是硬著頭皮找。經過一番驚天地泣鬼神、艱苦卓絕的尋找,我們在initModule()的調用鏈initModule()->THPVariable_initModule(module)->torch::autograd::initTorchFunctions(module)中發現了_VariableFunctions的蹤影。Aha,simple!
void initTorchFunctions(PyObject* module) {
if (PyType_Ready(&THPVariableFunctions) < 0) {
throw python_error();
}
Py_INCREF(&THPVariableFunctions);
// Steals
Py_INCREF(&THPVariableFunctions);
if (PyModule_AddObject(module, "_VariableFunctionsClass", reinterpret_cast(&THPVariableFunctions)) < 0) {
throw python_error();
}
// PyType_GenericNew returns a new reference
THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
// PyModule_AddObject steals a reference
if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
throw python_error();
}
}
但是!!別高興太早!查看模塊_VariableFunctions中暴露的接口你會發現,根本就沒有我們想要的!如下面的代碼所示:
static PyMethodDef torch_functions[] = {
{"arange", castPyCFunctionWithKeywords(THPVariable_arange),
METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"dsmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
{"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"hsmm", castPyCFunctionWithKeywords(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
${py_method_defs}
{NULL}
};
上面的代碼中我們找不到prelu的任何身影。會不會prelu可以繞開C/C++擴展的方式直接被Python使用呢?所以不會出現在這里?答案是不會,自古華山一條路,程序是不會跟你講潛規則的。那么既然最終代碼已經跟丟了,作者一定是使用了黑魔法,作為麻瓜的我無計可施,本文也該結束了……
等等,上面的C代碼中好像混入了奇怪的東西——${py_method_defs}。這種語法好像C/C++語法里面是沒有的,反而是Shell這類腳本里面才會有,難道是新特性?費勁查找了一圈,并沒有發現C/C++中有這種語法,既然不是正經語法,那么混入C/C++中肯定會導致編譯失敗,但是它確實就在那里。那么真相只有一個:它就是個占位符,后面肯定會有真正的代碼替換它!
接下來怎么辦?搜索!使用py_method_defs作為關鍵字全局搜索,最終我們會發現,確實是有一個Python腳本對這個占位符進行了替換,而替換的結果就是我們一直尋找的prelu終于出現在了模塊_VariableFunctions之中。好,破案了。
但是就像警察破案,即便有單個證據,也要找到其他證據形成完整證據鏈才能使得證據具有說服力。雖然我們通過搜索得知了prelu會出現在模塊_VariableFunctions中,但是它究竟怎么來的目前還是很模糊:占位符在什么時候被誰調用的腳本進行了替換?
實際上,這一切都是有跡可循的。蹤跡依舊在setup.py中。進入setup.py的主函數,在調用setup函數之前會看到一個名為build_deps()的函數調用,此函數最終會調用指定平臺的CMake去按照根目錄下CMakeLists.txt中的腳本進行構建。根目錄下的CMakeLists.txt最終又會調用到caffe2目錄下的CMakeLists.txt(add_subdirectory(caffe2)),而caffe2/CMakeLists.txt中就會調用到進行代碼生成的Python腳本,如下所示:
代碼生成腳本起調過程示意圖
// pytorch\caffe2\CMakeLists.txt
add_custom_command( OUTPUT
${TORCH_GENERATED_CODE}
COMMAND
"${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
--declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
--native-functions-path "aten/src/ATen/native/native_functions.yaml"
--nn-path "aten/src"
$:--disable-autograd>
$:--selected-op-list-path="${SELECTED_OP_LIST}">
--force_schema_registration
進行代碼生成的主要流程如下面代碼塊所示,其大概流程是main()先解析傳遞給腳本的參數,之后將參數傳遞給generate_code()。結合caffe2/CMakeLists.txt中腳本調用時傳遞的參數可知,generate_code()中的是三個gen_*()函數都得到了調用,而在gen_autograd_python()會調用到一個名為create_python_bindings()的函數,這個函數就是真正執行代碼生成的地方。
代碼生成器調用流程示意圖
// tools/setup_helpers/generate_code.py
def generate_code(ninja_global=None,
declarations_path=None,
nn_path=None,
native_functions_path=None,
install_dir=None,
subset=None,
disable_autograd=False,
force_schema_registration=False,
operator_selector=None):
if subset == "pybindings" or not subset:
gen_autograd_python(
declarations_path or DECLARATIONS_PATH,
native_functions_path or NATIVE_FUNCTIONS_PATH,
autograd_gen_dir,
autograd_dir)
if operator_selector is None:
operator_selector = SelectiveBuilder.get_nop_selector()
if subset == "libtorch" or not subset:
gen_autograd(
declarations_path or DECLARATIONS_PATH,
native_functions_path or NATIVE_FUNCTIONS_PATH,
autograd_gen_dir,
autograd_dir,
disable_autograd=disable_autograd,
operator_selector=operator_selector,
)
if subset == "python" or not subset:
gen_annotated(
native_functions_path or NATIVE_FUNCTIONS_PATH,
python_install_dir,
autograd_dir)
def main():
parser = argparse.ArgumentParser(description='Autogenerate code')
parser.add_argument('--declarations-path')
parser.add_argument('--native-functions-path')
parser.add_argument('--nn-path')
parser.add_argument('--ninja-global')
parser.add_argument('--install_dir')
parser.add_argument(
'--subset',
help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
)
parser.add_argument(
'--disable-autograd',
default=False,
action='store_true',
help='It can skip generating autograd related code when the flag is set',
)
parser.add_argument(
'--selected-op-list-path',
help='Path to the YAML file that contains the list of operators to include for custom build.',
)
parser.add_argument(
'--operators_yaml_path',
help='Path to the model YAML file that contains the list of operators to include for custom build.',
)
parser.add_argument(
'--force_schema_registration',
action='store_true',
help='force it to generate schema-only registrations for ops that are not'
'listed on --selected-op-list'
)
options = parser.parse_args()
generate_code(
options.ninja_global,
options.declarations_path,
options.nn_path,
options.native_functions_path,
options.install_dir,
options.subset,
options.disable_autograd,
options.force_schema_registration,
# options.selected_op_list
operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path),
)
if __name__ == "__main__":
main()
// pytorch\tools\autograd\gen_autograd.py
def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
from .load_derivatives import load_derivatives
differentiability_infos = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
template_path = os.path.join(autograd_dir, 'templates')
# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_python
gen_autograd_functions_python(
out, differentiability_infos, template_path)
# Generate Python bindings
from . import gen_python_functions
deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
gen_python_functions.gen(
out, native_functions_path, deprecated_path, template_path)
// pytorch\tools\autograd\gen_python_functions.py
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
methods = load_signatures(native_yaml_path, deprecated_yaml_path, method=True)
create_python_bindings(
fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)
functions = load_signatures(native_yaml_path, deprecated_yaml_path, method=False)
create_python_bindings(
fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
def create_python_bindings(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
module: Optional[str],
filename: str,
*,
method: bool,
) -> None:
"""Generates Python bindings to ATen functions"""
py_methods: List[str] = []
py_method_defs: List[str] = []
py_forwards: List[str] = []
grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
for pair in pairs:
if pred(pair.function):
grouped[pair.function.func.name.name].append(pair)
for name in sorted(grouped.keys(), key=lambda x: str(x)):
overloads = grouped[name]
py_methods.append(method_impl(name, module, overloads, method=method))
py_method_defs.append(method_def(name, module, overloads, method=method))
py_forwards.extend(forward_decls(name, overloads, method=method))
fm.write_with_template(filename, filename, lambda: {
'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
'py_forwards': py_forwards,
'py_methods': py_methods,
'py_method_defs': py_method_defs,
})
最終通過查看native_functions.yaml的內容以及深入跟蹤加載native_functions.yaml的代碼發現,native_functions.yaml中的prelu最終會被寫到以python_torch_functions.cpp為模板的文件中,也就是調用
create_python_bindings(
fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
的時候被生成。整個生成的過程其實是很繁瑣的,一層層跟蹤后可以發現,最終生成的代碼可以實現將一個名為at::的函數暴露給Python。例如我們的prelu,暴露給Python的API最終會調用一個名為at::prelu()的函數來做真正的計算。那么這個at::(例如at::prelu())的定義又在哪里呢?
還是一樣,故技重施!仍然使用Python腳本根據native_functions.yaml文件中的內容去以pytorch\aten\src\ATen\templates目錄下的各種模板去生成對應的實際C++源文件。最終結果是得到at::,在這個函數中,它調用了Dispatcher這個類尋找到目標函數的句柄。通常情況下能夠使用的函數句柄都通過一個叫Library的類來管理。Python腳本以RegisterSchema.cpp為模板,生成了注冊這些目標函數的注冊代碼,并通過一個名為TORCH_LIBRARY的宏調用Library類來注冊管理。
#define TORCH_LIBRARY(ns, m) \
static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
torch::Library::DEF, \
&TORCH_LIBRARY_init_ ## ns, \
#ns, c10::nullopt, __FILE__, __LINE__ \
); \
void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
class TorchLibraryInit final {
private:
using InitFn = void(Library&);
Library lib_;
public:
TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<:dispatchkey> k, const char* file, uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
PyTorch組成示意圖
3. 總結
PyTorch雖然在使用上是非常的Pythonic,但實際上Python只不過是為了方便使用裹在C++代碼上的一層糖衣。用起來雖然好用,但是看起來實在是非常費勁,特別是如果靜態的梳理代碼,很多用于連接Python C/C++接口與實際邏輯代碼之間的C++代碼都是通過Python腳本生成的。至此,整個大的線索已經摸清了,剩下的就是去查看具體細節的實現。
說實話,人腦執行Python代碼之后再去理解C++代碼實在是費勁,也費頭發。因此我決定的讓電腦去生成C++代碼再接著看更具體的細節,比如究竟每一個算子是怎么注冊到Library之中的。
4. Bonus
我真心懷疑我們生活在一個虛擬機里,為什么呢?因為到處可見運用于計算機里面的空間和時間局部性原理的實例。就在我寫完這個博客的時候,意外的發現了一篇PyTorch工程師講解PyTorch內部原理的博文,這對后續讀代碼應該會有很大幫助。等不及就戳它吧 http://blog.ezyang.com/2019/05/pytorch-internals/
與50位技術專家面對面20年技術見證,附贈技術全景圖總結
以上是生活随笔為你收集整理的python pytorch fft_看PyTorch源代码的心路历程的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: TGA 2023 年度最佳游戏提名公布:
- 下一篇: 早报:vivo X100系列发布 iQO