关于torch.fx的使用
torch.fx
前言
最近在學(xué)習(xí)一些AI編譯器,推理框架的知識,恰好看到了torch.fx這個部分。這個其實在1.10就已經(jīng)出來了,但是一直不知道,所以花了一點時間學(xué)習(xí)了這部分的內(nèi)容。
以下所有的代碼基于Mac M1 pytorch 1.13,其他的os/版本沒有進(jìn)行測試
1.什么是torch.fx
首先去查看官網(wǎng)docTORCH.FX
FX is a toolkit for developers to use to transform nn.Module instances. 這句話很好的定義了FX的本質(zhì):用來改變module實例的一種工具。包括了三個主要的組件:symbolic tracer intermediate representation python code generation
符號追蹤可以捕獲模塊的語義進(jìn)行解析;中間表示也就是IR記錄了中間的操作,比如輸入輸出和調(diào)用的函數(shù)等;代碼生成這個比較有意思,因為這是一個python-to-python的轉(zhuǎn)換工具,這就從本質(zhì)上區(qū)別了FX與一些AI編譯器,推理庫的區(qū)別。從流程上看,FX與推理庫都是解析模型生成IR,然后融合算子呀優(yōu)化等等,但是FX只是為了優(yōu)化改變模型的功能,最終落腳點還是在python上;而其他的庫都是經(jīng)過一系列優(yōu)化后可以脫離python依賴部署到c++等邊緣環(huán)境上。
2. torch.fx有什么用
既然使用fx可以改變module,那么具體可以有哪些應(yīng)用場景呢?我總結(jié)了下面幾個主要的
- 追蹤模型圖,改變模型部分結(jié)構(gòu),替換某些算子
- 在python代碼的層面對模型進(jìn)行優(yōu)化
- 根據(jù)trace得到的結(jié)果更好的可視化模型
- 對模型進(jìn)行量化
2.1 模型算子替換
首先來看看官網(wǎng)給出的例子
import torch from torch import nn from torch import fx from torch.fx import symbolic_traceclass MyModel(nn.Module):def __init__(self):super().__init__()self.param=nn.Parameter(torch.Tensor([1,2,3,4]))def forward(self,x):return (x+self.param).clamp(min=0.0,max=1.0)model=MyModel()symbolic_traced=symbolic_trace(model) print(symbolic_traced.graph) print(symbolic_traced.code) symbolic_traced.graph.print_tabular()從圖里我們可以清楚地看到模型進(jìn)行的操作以及IR,它也很好的定義了算子的分類(這個對下面部分內(nèi)容很有用)。然后我們?nèi)绻胗胹igmoid替換clamp,如果按照官網(wǎng)以及大多數(shù)已有文章的例子是有錯誤的
# 將clamp轉(zhuǎn)為sigmoid def transform(m):gm=fx.Tracer().trace(m)for node in gm.nodes:if node.op=='call_method':if node.target=="clamp":print(node.target)node.target=torch.sigmoidgm.lint()return fx.GraphModule(m,gm)trans_model=transform(model) print(trans_model.graph) print(trans_model.code) trans_model.graph.print_tabular()很明顯可以看到node.target必須是字符串,所以這樣替換是不對的。而原示例給出的是torch.mul替換torch.add,如果測試那個代碼,node.target==torch.add這個根本不會成立(target是str),所以這里我才將target條件更正。
那怎么替換clamp呢,而且還要驗證替換后模型的結(jié)果無誤差
# 將clamp轉(zhuǎn)為sigmoid def transform(m):gm=fx.Tracer().trace(m)for node in gm.nodes:if node.op=='call_method':if node.name=="clamp":print(node.target)node.target="sigmoid"node.name="sigmoid"node.kwargs={}gm.lint()return fx.GraphModule(m,gm)trans_model=transform(model) print(trans_model.graph) print(trans_model.code) trans_model.graph.print_tabular()從模型打印結(jié)果來看替換是成功的,但是還要經(jīng)過輸出檢驗
class MyModel1(nn.Module):def __init__(self):super().__init__()self.param=nn.Parameter(torch.Tensor([1,2,3,4]))#self.linear=torch.nn.Linear(4,5)def forward(self,x):return (x+self.param).sigmoid()test=MyModel1()inputs = torch.randn(1,4) torch.testing.assert_close(test(inputs),trans_model(inputs))這里沒有任何輸出,證明輸出與gt一致。當(dāng)然不止一種實現(xiàn),下面給出其他兩種
# 將clamp轉(zhuǎn)為sigmoid def transform(m):gm=symbolic_trace(m)for node in gm.graph.nodes:if node.op=='call_method':if node.name=="clamp":print(node.target)node.target="sigmoid"node.name="sigmoid"node.kwargs={}gm.recompile()return gmtrans_model=transform(model) print(trans_model.graph) print(trans_model.code) torch.testing.assert_close(test(inputs),trans_model(inputs))# 將clamp轉(zhuǎn)為sigmoid from torch.fx import replace_patterndef pattern(x):return x.clamp(min=0.0,max=1.0)def replacement(x):return x.sigmoid()replace_pattern(symbolic_traced,pattern,replacement) print(symbolic_traced.graph) print(symbolic_traced.code) torch.testing.assert_close(test(inputs),symbolic_traced(inputs))2.2 算子融合
在做推理部署的時候最常用的就是算子融合,也就是將多個算子的計算在數(shù)學(xué)上進(jìn)行等效替換,從而減少了算子數(shù)量以及整體的計算量,加速了推理時間。torch.fx也給了我們很好的算子融合替換幫助,因為上面說了有了trace我們可以很輕松地對模型算子進(jìn)行替換,例如最常見的conv+bn融合,丟棄dropout
這部分代碼可以參考官方樣例/torch/fx/experimental/optimization.py,我這里直接白嫖過來演示一下
from torch.nn.utils.fusion import fuse_conv_bn_eval from torch.fx.node import Argument, Target from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast import copydef _parent_name(target : str) -> Tuple[str, str]:"""Splits a qualname into parent path and last atom.For example, `foo.bar.baz` -> (`foo.bar`, `baz`)"""*parent, name = target.rsplit('.', 1)return parent[0] if parent else '', namedef matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):if len(node.args) == 0:return Falsenodes: Tuple[Any, fx.Node] = (node.args[0], node)for expected_type, current_node in zip(pattern, nodes):if not isinstance(current_node, fx.Node):return Falseif current_node.op != 'call_module':return Falseif not isinstance(current_node.target, str):return Falseif current_node.target not in modules:return Falseif type(modules[current_node.target]) is not expected_type:return Falsereturn Truedef replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):assert(isinstance(node.target, str))parent_name, name = _parent_name(node.target)modules[node.target] = new_modulesetattr(modules[parent_name], name, new_module)def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:"""Fuses convolution/BN layers for inference purposes. Will deepcopy yourmodel by default, but can modify the model inplace as well."""patterns = [(nn.Conv1d, nn.BatchNorm1d),(nn.Conv2d, nn.BatchNorm2d),(nn.Conv3d, nn.BatchNorm3d)]if not inplace:model = copy.deepcopy(model)fx_model = fx.symbolic_trace(model)modules = dict(fx_model.named_modules())new_graph = copy.deepcopy(fx_model.graph)for pattern in patterns:for node in new_graph.nodes:if matches_module_pattern(pattern, node, modules):if len(node.args[0].users) > 1: # Output of conv is used by other nodescontinueconv = modules[node.args[0].target]bn = modules[node.target]fused_conv = fuse_conv_bn_eval(conv, bn)replace_node_module(node.args[0], modules, fused_conv)node.replace_all_uses_with(node.args[0])new_graph.erase_node(node)return fx.GraphModule(fx_model, new_graph)def remove_dropout(model: nn.Module) -> nn.Module:"""Removes all dropout layers from the module."""fx_model = fx.symbolic_trace(model)class DropoutRemover(torch.fx.Transformer):def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:if isinstance(self.submodules[target], nn.Dropout):assert len(args) == 1return args[0]else:return super().call_module(target, args, kwargs)return DropoutRemover(fx_model).transform()class TestConv2d(nn.Module):def __init__(self,in_channels,out_channels,**kwargs):super(TestConv2d,self).__init__()self.conv=nn.Conv2d(in_channels,out_channels,**kwargs)self.bn=nn.BatchNorm2d(out_channels)self.relu=nn.ReLU(True)def forward(self,x):x=self.conv(x)x=self.bn(x)x=self.relu(x)return xclass TestModel(nn.Module):def __init__(self):super().__init__()self.conv1=TestConv2d(3,32,kernel_size=3)self.conv2=TestConv2d(32,64,kernel_size=3)self.dropout=nn.Dropout(0.3)def forward(self,x):x=self.conv1(x)x=self.conv2(x)x=self.dropout(x)return xdef show(string,count):print(f"{'='*count}{string}{'='*count}")test_model=TestModel()# 在eval下進(jìn)行融合,丟棄 test_model.eval()### origin origin_model=symbolic_trace(test_model) show("origin result",20) print(origin_model.graph) print(origin_model.code)### fusefuse_model=fuse(test_model) fuse_model=remove_dropout(fuse_model) show("fuse result",20) print(fuse_model.graph) print(fuse_model.code)可以看到經(jīng)過算子融合與丟棄,模型沒有了bn dropout十分簡潔。有人會說為什么不把relu也融進(jìn)conv,這在量化中可以實現(xiàn)截斷但是如果是全精度也就是FP32下如果scale和zeropoint不一致沒法量化回來,所以這里并沒有進(jìn)行融合。
2.3 模型可視化
不知道多少人用過torchviz對模型進(jìn)行過可視化,不能說不好只能說根本不直觀。這里我恰好看到了一篇講利用fx進(jìn)行模型結(jié)構(gòu)可視化的博客,可惜博主代碼沒有全部給出來。不過根據(jù)他的文章也算是給了我一種很好的思路,既然我們都有模型的DAG,IR,那我們應(yīng)該可以更加直觀的實現(xiàn)模型結(jié)構(gòu)的可視化。所以這部分就算是完成博主沒有給出來的代碼,模型定義就用博主博客中的模型
利用torch.fx提取PyTorch網(wǎng)絡(luò)結(jié)構(gòu)信息繪制網(wǎng)絡(luò)結(jié)構(gòu)圖 - wrong.wang,大家可以先去看看博主的這篇文章我不過多講重復(fù)內(nèi)容。另外如果想實現(xiàn)功能,還得去研究一下fx解釋器的源碼torch.fx.interpreter — PyTorch 1.13 documentation
from torchviz import make_dot import graphviz import torch.nn.functional as Fclass TestModel(nn.Module):def __init__(self):super(TestModel, self).__init__()self.bias = nn.Parameter(torch.randn(1))self.main = nn.Sequential(nn.Conv2d(3, 4, 1), nn.ReLU(True))self.skip = nn.Conv2d(2, 4, 3, stride=1, padding=1)def forward(self, x, y):x = self.main(x)y = (self.skip(y)+self.bias).clamp(0, 1)x_size = x.size()[-2:]y = F.interpolate(y, x_size, mode="bilinear", align_corners=False)return torch.sigmoid(x) + yx=torch.randn(1,3,16,16) y=torch.randn(1,2,8,8) test_model=TestModel() z=test_model(x,y) g=make_dot(z,params=dict(test_model.named_parameters())) g.render(directory="test",format='svg',view=False)首先用torchviz繪制一下模型
看著這張圖,似懂非懂的樣子,并不能直觀的看到模型的結(jié)構(gòu)。然后開始實現(xiàn)博主的代碼
import tracebackclass Get_IR(torch.fx.Interpreter):def run_node(self,n):try:result=super().run_node(n)except Exception:traceback.print_exc()raise RuntimeError(f"Error while run node:{n.format_node()}")is_find=Falsedef extract_meta(t):if isinstance(t,torch.Tensor):nonlocal is_findis_find=Truereturn _extra_meta(t)else:return tdef _extra_meta(t):if n.op=="call_module":submod=self.fetch_attr(n.target)return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape,'target':n.target,'kw':n.kwargs,'mod':submod}elif n.op=="call_method":return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape,'target':n.target,'kw':n.kwargs}elif n.op=="call_function":return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape,'target':n.target,'kw':n.kwargs}else:return {'name':n.name,'op':n.op,'args':n.args,'shape':t.shape}n.meta["result"]=torch.fx.node.map_aggregate(result,extract_meta)n.meta["find"]=is_findreturn result traced=symbolic_trace(test_model)args=(x,y) kwargs={} _=Get_IR(traced).run(*args,**kwargs) print(traced.graph.print_tabular()) for node in traced.graph.nodes:print(node.meta)其實這部分就是利用解釋器會遍歷圖中的每個節(jié)點,所以我們只需要自定義一下run_node(),在里面加入解析網(wǎng)絡(luò)結(jié)構(gòu),輸入輸出的功能就可以了。
可以看到meta里面已經(jīng)有了模型結(jié)構(gòu)所需要的一切,但是這里雖然打印出來size和getitem是存在的,但是實際上并沒有在條件中解析到,目前還沒找到原因。
def create_str(node):if node.op=="call_module":return f"<<TABLE><TR><TD COLSPAN='2'>{node.meta['result']['mod']}</TD></TR><TR><TD>{node.meta['result']['name']}</TD><TD>{node.meta['result']['shape']}</TD></TR></TABLE>>"elif node.meta['find']:return f"<<TABLE><TR><TD>{node.meta['result']['name']}</TD></TR><TR><TD>{node.meta['result']['shape']}</TD></TR></TABLE>>"else:return f"<<TABLE><TR><TD>{node.meta['result']}</TD></TR></TABLE>>"def single_node(model: torch.nn.Module, graph: graphviz.Digraph, node: torch.fx.Node):node_label = create_str(node) # 生成當(dāng)前節(jié)點的labelnode_kwargs = dict(shape='plaintext',align='center',fontname='monospace')graph.node(node.name, label=node_label, **node_kwargs) # 在Graphviz圖中添加當(dāng)前節(jié)點# 遍歷當(dāng)前節(jié)點的所有輸入節(jié)點,添加Graphviz圖中的邊for in_node in node.all_input_nodes:edge_kwargs = dict()if (not node.meta["find"]or not in_node.meta["find"]):# 如果當(dāng)前節(jié)點的輸入和輸出中都沒有Tensor,就把當(dāng)前邊置為淺灰色虛線,弱化顯示edge_kwargs.update(dict(style="dashed", color="lightgrey"))# 添加當(dāng)前邊graph.edge(in_node.name, node.name, **edge_kwargs)def model_graph(model: torch.nn.Module, *args, **kwargs) -> graphviz.Digraph:# 將nn.Module轉(zhuǎn)換為torch.fx.GraphModule,獲取計算圖symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(model)# 執(zhí)行一下網(wǎng)絡(luò),以此獲取每個節(jié)點輸入輸出的具體信息Get_IR(symbolic_traced).run(*args, **kwargs)# 定義一個Graphviz網(wǎng)絡(luò)graph = graphviz.Digraph("model", format="svg", node_attr={"shape": "plaintext"})for node in symbolic_traced.graph.nodes: # 遍歷所有節(jié)點single_node(model, graph, node)return graphmodel = TestModel()graph = model_graph(model, torch.randn(1, 3, 16, 16), torch.randn(1, 2, 8, 8)) graph.render(directory="test", view=False)這樣來看模型結(jié)果就清晰許多,也和博主給出的結(jié)果高度還原。當(dāng)時就是因為看到了這個結(jié)構(gòu)圖所以讓我好好看了一遍解釋器部分的源碼來實現(xiàn)這個效果,如果未來自己做推理框架希望也能很清晰直觀地給出模型結(jié)構(gòu)圖這和簡單易用一樣都是最基本的。
2.4 量化
在不大幅度減小模型精度的情況下,對已有訓(xùn)練好的模型以低精度執(zhí)行計算這就是量化。一般對于pytorch就是從FP32(FP16如果有amp)轉(zhuǎn)到INT8
可以參考torch的官方文檔https://pytorch.org/docs/master/quantization.html#prototype-fx-graph-mode-quantization
利用fx可以輕松的插入量化節(jié)點,并進(jìn)行校準(zhǔn)。不過量化需要已知數(shù)據(jù)分布,所以下面的步驟就是
這里我就用resnet18在cifar10上訓(xùn)練得到模型為例,訓(xùn)練部分的代碼網(wǎng)上很多這里就不再給出
model=resnet18(pretrained=True) model.fc=nn.Linear(model.fc.in_features,10)if not os.path.exists("raw.pth"):train_model(model,train_loader,test_loader,10,torch.device("mps:0"))torch.save(model.state_dict(),"raw.pth")這里說個坑哈,千萬別用mac訓(xùn)練太慢了。如果用cuda估計幾分鐘以內(nèi)就算完了,但是因為用服務(wù)器不能多屏還是覺得不好所以忍著在mac上訓(xùn)練(順便摸摸魚)
然后開始量化,參考https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic.html#post-training-dynamic-quantization進(jìn)行后訓(xùn)練動態(tài)量化
print(torch.backends.quantized.supported_engines)這個很重要,得知道使用的平臺支持的engine
import os import time import copyimport torch from torch import nn from torch import optim import torch.nn.functional as Fimport torchvision from torchvision import transforms from torchvision.models.resnet import resnet18from torch.quantization.quantize_fx import prepare_fx,convert_fx from torch.ao.quantization import get_default_qconfig_mapping from torch.ao.quantization.fx.graph_module import ObservedGraphModulemodel=resnet18(pretrained=True) model.fc=nn.Linear(model.fc.in_features,10) model.load_state_dict(torch.load("raw.pth",map_location='cpu')) model.to(torch.device("cpu")) model.eval()torch.backends.quantized.engine = 'qnnpack' qconfig_mapping=get_default_qconfig_mapping("qnnpack")model_to_quantize=copy.deepcopy(model) prepared_model=prepare_fx(model_to_quantize,qconfig_mapping,example_inputs=torch.randn([1,3,224,224])) print(f"prepared model {prepared_model.graph.print_tabular()}")quantized_model=convert_fx(prepared_model) print(f"{'='*100}") print(f"quantized model {quantized_model.graph.print_tabular()}")這里就載入訓(xùn)練好的模型,然后進(jìn)行量化。根據(jù)官網(wǎng)的例子找到核心內(nèi)容仿照就好
可以看到圖中轉(zhuǎn)為了torch.quint8,模型的大小肯定也縮小了很多
def print_size_of_model(model):torch.save(model.state_dict(),"tmp.pt")print(f"The model size:{os.path.getsize('tmp.pt')/1e6}MB")os.remove("tmp.pt")print_size_of_model(prepared_model) print_size_of_model(quantized_model)模型大小差不多變成了原來的1/4,但是光變小不行還得看精度
# 測試一下精度 train_loader,test_loader=prepare_dataloader() example_data=torch.randn([1,3,224,224]) out1=model(example_data) out2=quantized_model(example_data)print(torch.allclose(out1,out2,1e-3))out1 out2evaluate_model(model,test_loader,device='cpu') evaluate_model(quantized_model,test_loader,device='cpu')直接G了,這什么鬼呀雖然推理時間差不多少了一半但是這準(zhǔn)確率跟瞎猜差不多了,這可不行!!!所以還需要進(jìn)行量化的重要一步:校準(zhǔn)
我們需要已知數(shù)據(jù)分布的情況下對模型進(jìn)行量化才能使量化后的模型依然保持準(zhǔn)確率,所以下面就進(jìn)行量化校準(zhǔn)
# 校準(zhǔn)恢復(fù)精度 model_to_quantize=copy.deepcopy(model) prepared_model=prepare_fx(model_to_quantize,qconfig_mapping,example_inputs=torch.randn([1,3,224,224])) prepared_model.eval() with torch.inference_mode():for inputs,labels in test_loader:prepared_model(inputs)quantized_recover_model=convert_fx(prepared_model)out3=quantized_recover_model(example_data)print(torch.allclose(out1,out3,1e-3)) out3 evaluate_model(quantized_recover_model,test_loader,device='cpu')雖然這里精度并沒有對齊,但是準(zhǔn)確率還是恢復(fù)上來了。對于邊緣,移動端的部署來說,這么一點點微小的準(zhǔn)確率損失可以換來存儲占用小75%,推理速度提高一倍,這是誰都能接受的。
最后
看了AI編譯器,推理框架后再來看fx,總感覺相似但是又不同。就像之前說的本質(zhì)上二者就不同,fx只存在于python而不考慮硬件部署上,但是如果我們首先利用fx在python端盡力優(yōu)化好然后再去推理框架上微調(diào)一下結(jié)構(gòu),那會比反復(fù)調(diào)整推理框架適應(yīng)所有可能的算子輕松很多,畢竟python還是比c++寫起來坑少很多的,而且這樣的話推理框架就可以很自然的附帶出python的推理api,希望以后有時間我可以根據(jù)這個思路早點寫出來。
總結(jié)
以上是生活随笔為你收集整理的关于torch.fx的使用的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 汇聚“地表最强”云原生战队 云原生技术实
- 下一篇: 服务器上的SDR传感器状态,ipmito