TVM实战
TVM實(shí)戰(zhàn)
問(wèn)題的由來(lái)
最近客戶反饋我們的backend導(dǎo)入Pytorch模型會(huì)出錯(cuò),而TFLite模型是OK的。
打印模型的IR后,我們發(fā)現(xiàn):
這是Pytorch模型的IR片段:
%0 = qnn.quantize(%input, 0.0186579f, 114, out_dtype="uint8", axis=1);%1 = nn.pad(%0, 114f, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]);%2 = qnn.quantize(%features.0.0_weight, 0.00288958f, 0, out_dtype="int8", axis=0);%3 = qnn.conv2d(%1, %2, 114, 0, 0.0186579f, 0.00288958f, strides=[2, 2], padding=[0, 0, 0, 0], channels=32, kernel_size=[3, 3], out_dtype="int32");%4 = qnn.quantize(%features.0.0_bias, 5.39136e-05f, 0, out_dtype="int32", axis=0);%5 = nn.bias_add(%3, %4);%6 = qnn.requantize(%5, 5.39136e-05f, 0, 0.0150183f, 0, axis=1, out_dtype="int32");%7 = clip(%6, a_min=0f, a_max=255f);%8 = cast(%7, dtype="uint8");這是TFLite模型的IR片段:
%0 = qnn.quantize(%input, 0.0186579f /* ty=float32 */, 114 /* ty=int32 */, out_dtype="uint8", axis=1) /* ty=Tensor[(1, 3, 224, 224), uint8] */;%1 = nn.pad(%0, 114f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 3, 226, 226), uint8] */;%2 = qnn.conv2d(%1, meta[relay.Constant][0] /* ty=Tensor[(32, 3, 3, 3), int8] */, 114 /* ty=int32 */, 0 /* ty=int32 */, 0.0186579f /* ty=float32 */, 0.00288958f /* ty=float32 */, strides=[2, 2], padding=[0, 0, 0, 0], channels=32, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 32, 112, 112), int32] */;%3 = nn.bias_add(%2, meta[relay.Constant][1] /* ty=Tensor[(32), int32] */) /* ty=Tensor[(1, 32, 112, 112), int32] */;%4 = qnn.requantize(%3, 5.39136e-05f /* ty=float32 */, 0 /* ty=int32 */, 0.0150183f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="uint8") /* ty=Tensor[(1, 32, 112, 112), uint8] */;可以看出同一個(gè)QnnConv2D兩者的展開(kāi)形式存在一定的差異,但是語(yǔ)義上卻基本是一致的。
Relay IR
在繼續(xù)主題之前,我們首先來(lái)看看這個(gè)展開(kāi)形式的差異是如何造成的。
TFLite/Pytorch模型導(dǎo)入之后,有一個(gè)轉(zhuǎn)換成Relay IR的過(guò)程,也就是所謂的frontend。
相關(guān)代碼在:
TFLite: python/tvm/relay/frontend/tflite.py
Pytorch: python/tvm/relay/frontend/pytorch.py和python/tvm/relay/frontend/qnn_torch.py
可以看出Pytorch的量化模型的weight是浮點(diǎn)數(shù),而TFLite的量化模型的weight是整數(shù)。
有frontend自然就有backend:
python/tvm/relay/op/contrib/ethosn.py
和切割計(jì)算圖有關(guān)的代碼主要是:
seq = tvm.transform.Sequential([......transform.MergeComposite(pattern_table()),transform.AnnotateTarget("ethos-n"),transform.MergeCompilerRegions(),transform.PartitionGraph(),] ) seq(mod)這里的Pass基本看名字就知道功能了,唯一需要關(guān)注的是MergeComposite。
這是pattern_table的其中一個(gè)表項(xiàng):
def qnn_conv_pattern():pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()pattern = is_op("qnn.conv2d")(pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant())pattern = is_op("nn.bias_add")(pattern, is_constant())pattern = is_op("qnn.requantize")(pattern, is_constant(), is_constant(), is_constant(), is_constant())return pattern一般來(lái)說(shuō)IR會(huì)定的比較細(xì)粒度,而AI硬件更喜歡粗粒度的op。所以往往若干個(gè)IR op組合起來(lái),才能得到一個(gè)AI硬件op。這時(shí)就需要進(jìn)行模板匹配。
一般來(lái)說(shuō),數(shù)據(jù)可以分為變量和常量(is_constant()),如果既可能是變量,也可能是常量的話,就用wildcard()匹配。
MergeComposite會(huì)將這個(gè)模板打包成一個(gè)函數(shù),變量會(huì)寫(xiě)為函數(shù)的參數(shù),而常量放到全局的meta data里。
添加Pass
下面我們來(lái)看看如何添加Pass進(jìn)行這樣的轉(zhuǎn)換。
class QnnQuantizeConstFold : public DFPatternRewrite {public:QnnQuantizeConstFold() {data_pat_ = IsConstant();pattern_ = IsOp("qnn.quantize")({data_pat_, IsConstant(), IsConstant()});}Expr Callback(const Expr& pre, const Expr& post,const Map<DFPattern, Array<Expr>>& node_map) const override {......if (output->dtype == DataType::Int(8)) {return QuantizeData<int8_t>(......);} else if (output->dtype == DataType::Int(32)) {return QuantizeData<int32_t>(......);}return post;}protected:DFPattern data_pat_; };Rewrite_是Pass最重要的函數(shù)。其中最簡(jiǎn)單的當(dāng)屬DFPatternRewrite。
Callback的參數(shù)中,pre表示原始Expr,post表示替換后的Expr,返回值也是替換后的Expr。(方便使用鏈?zhǔn)秸{(diào)用?)
如果Pass什么都不做的話,直接返回post就可以了。
PS:雖然pre和post在運(yùn)行之初是相同的,但是一旦替換開(kāi)始,兩者就有差異了,所以如果是寫(xiě)入的內(nèi)容,一定要引用post里的那份。
Python版的Pass
Pass不僅能用C++寫(xiě),也可用python寫(xiě)。
class QnnQuantizeConstFold(tvm.relay.dataflow_pattern.DFPatternCallback):def __init__(self, require_type=False):super().__init__(require_type)self.pattern = is_op("qnn.quantize")(is_constant(), is_constant(), is_constant())def callback(self, pre, post, node_map):......if (dtype == "int8"):return tvm.relay.Constant(tvm.nd.array(data.astype(np.int8)))if (dtype == "int32"):return tvm.relay.Constant(tvm.nd.array(data.astype(np.int32)))return post可以看出,寫(xiě)法也是大同小異,只是更便于集成,也不用寫(xiě)FFI接口了。
使用方法:
func = mod["main"] func = tvm.relay.Function(func.params, func.body, None, func.type_params, func.attrs) func = tvm.relay.dataflow_pattern.rewrite(RemoveClipAfterRequantize(), func) func = tvm.relay.dataflow_pattern.rewrite(QnnQuantizeConstFold(), func) mod = tvm.IRModule.from_expr(func)這里展示了DFPatternCallback的使用方法,還有IRModule和Expr的相互轉(zhuǎn)換方法。
參考:
https://tvm.apache.org/docs/reference/langref/relay_pattern.html
Pattern Matching in Relay
tvmc
tvmc是tvm提供的一個(gè)命令行接口。
export TARGET="bnns, llvm -device=arm_cpu -mtriple=aarch64-linux-gnu" python3 -m tvm.driver.tvmc compile ./mobilenet_v1_0.25_224_quant.tflite --target "$TARGET" -o tvmc.tar --cross-compiler "$CC" --cross-compiler-options "$CC_OPTIONS"要點(diǎn)如下:
注意TARGET的寫(xiě)法,這展示了如何將一個(gè)包含空格的字符串作為一個(gè)bash變量傳遞給命令行參數(shù)的做法。定義變量時(shí)的雙引號(hào)和使用時(shí)的雙引號(hào)都是必不可少的。
有些backend需要partation,將能執(zhí)行的op分配到該設(shè)備上,因此TARGET也包含了兩部分(用逗號(hào)分隔),其中第一部分用于partation。
相關(guān)代碼在:
python/tvm/driver/tvmc/composite_target.py
總結(jié)
- 上一篇: TVM概述
- 下一篇: LINUX 下编译 ffmpeg