将编译器pass添加到Relay
將編譯器pass添加到Relay
編譯器pass是擴展Relay功能集和對Relay程序執行優化的主要接口。通過編寫編譯器pass,可以修改AST或收集有關AST的信息,具體取決于目標。事實上,Relay的一些最重要的內置功能(如autodiff和類型推斷),只不過是“標準”編譯器pass。
在高層次上,寫pass有兩個關鍵組成部分:
創建一個或多個遍歷程序的C++類
將遍歷實現及元數據包裝在pass manager API中,以便可以與pass基礎結構完整交互。
首先,將概述編寫編譯器pass的關鍵機制。然后,將介紹一個Relay中常量折疊pass的具體示例。
AST遍歷器
用于遍歷Relay程序的基類是ExprFunctor。提供的公共接口是一個VisitExpr方法,接受一個表達式和零個或多個參數,返回某種類型的實例。擴展此類時,可以通過為每種類型的表達式重寫VisitExpr_ f的實現,定義AST遍歷模式。
VisitExpr和VisitExpr_間的關系與調度有關。每個VisitExpr_定義都針對特定類型的表達式,但不總是知道要訪問的節點類型。為了解決這個問題,ExprFunctor提供了一個VisitExpr函數,該函數從給定的表達式路由到處理VisitExpr_案例。盡管C++已經提供了動態調度,ExpPrimor還是定義了VisteExPR使用的VTe表。通過定義vtable,可以更好地控制調度。例如,如果想定義一個PrintVisitor遍歷器,在每次訪問前打印“Here”,可以覆蓋VisitExpr:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << “Here” << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor本身是一個非常通用的類,這就是為什么經常會擴展ExprVisitor或ExprMutator。這些類擴展了ExprFunctor,提供了VisitExpr_的默認實現,該實現獲取每種表達式類型的公共遍歷模式。擁有這些默認實現,不同行為的表達式類型提供覆蓋實現。在下面的介紹中,將單獨描述每個子類。
ExprVisitor
ExprVisitor用于不修改程序,執行程序分析和收集信息的過程。在這個類中,VisitExpr和私有對應項不返回任何內容。此類提供的VisitExpr_實現,只需訪問表達式的所有字段即可。IfNode的默認實現如下所示。
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
在這里調用的是VisitExpr,不是VisitExpr,可以使用vtable in ExprFunctor進行路由。
現在,如果想編寫一個類調用檢查器,檢查程序中是否出現任何函數調用,只需要擴展ExprVisitor,定義以下VisitExpr_方法:
void VisitExpr_(const CallNode* n) final {
result_ = true;
}
其中result_是一個字段。在這種情況下,不需要在CallNode的字段上進一步遞歸,因為result_已經為true,原始表達式包含一個調用。為了使visitor可用,將提供以下公共方法:
bool Check(const Expr& expr) final {
result_ = false;
VisitExpr(expr);
return result_;
}
這就是所需要的。在調用頂級遞歸前,定義一個公共接口,執行一些bookkeeping記錄是非常常見的。當然,可以通過創建一個獨立的pass,進一步包裝API,該pass創建一個CallChecker實例調用Check,只花了很少的努力就實現了目標。
Expression Mutators
ExprMutator用于以某種方式轉換程序的pass。使用該類,VisitExpr及私有對應項返回Expr。此類提供的默認VisitExpr_,實現訪問表達式的所有字段,這些字段都是表達式,將這些字段設置為訪問結果。TupleGetItemNode的默認實現如下所示。
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef(g);
} else {
return TupleGetItem(t, g->index);
}
}
這里有幾件事需要注意。首先,Mutate是ExprMutator中VisitExpr的別名。其次,如果Mutate調用修改了tuple字段,只返回一個新節點。這種更新方法稱為功能更新,這樣做可以避免不必要的分配。
ExprMutator的一個特性是ExprVisitor沒有的,一個用于緩存結果的內置備注字段。ExprMutator有一個memoizer是有道理的,知道正在緩存哪些類型的結果(即Expr),ExprVisitor的訪問方法不返回任何內容。通常,當想要將結果緩存在ExprVisitor的子類中時,需要定義緩存。
現在,如果想編寫一個類IfCollapser,用真正分支替換每個if語句,將覆蓋IfNode的VisitExpr_:
Expr ExprMutator::VisitExpr_(const IfNode* op) {
return this->Mutate(op->true_branch);
}
返回的表達式不一定是IfNode,因為返回類型是Expr?,F在,創建公共接口:
Expr CollapseIfs(const Expr& expr) final {
return this->Mutate(expr);
}
有了這個mutator,不需要做任何記錄,但仍然希望遵循使用描述性方法,作為接口的慣例。
示例:常量折疊
為了更好地理解編寫pass,將以常量折疊pass(見src/relay/transforms/fold_constant.cc)為指導,因為是一個相對簡單的過程,包含了兩種類型的遍歷。
常量折疊涉及計算程序中,只涉及常量值的表達式,然后用計算結果替換這些表達式。此pass的目標是預先加載所有可以進行的計算。為了實現這一點,常量折疊pass使用訪客(ConstantChecker)和變異子(ConstantFolder)。
ConstantChecker Visitor
此訪問者用于檢查表達式是否為常量。在Relay中,如果表達式是常量節點或只有常量字段的元組節點,將定義為常量。
使用一個memo_字段,從節點映射是否為常量,緩存這些結果。以下是ConstantChecker中的VisitExpr_定義。
void VisitExpr_(const ConstantNode* n) final {
memo_[GetRef(n)] = true;
}
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef(n)] = result;
}
用于協調這些定義的記錄是一個檢查方法,返回給定表達式是否被視為常量。
bool Check(const Expr& expr) {
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr];
}
不會為遇到的每個節點修改memo_;相反,只在遇到的節點可能是常量時修改memo_。然后,當memo_不包含expr時,依賴于默認值為false。
ConstantFolder Mutator常量折疊變異體
該mutator變異器執行大部分常量折疊pass,在內部使用ConstantChecker。在Relay中,常量折疊涉及三種節點類型:LetNode、TupleItemGetNode和CallNode。在下面的段落中,將解釋pass中每個角色的作用。
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef(op);
} else {
return Let(var, value, body);
}
}
}
在LetNode的情況下,首先嘗試對表達式中綁定的值進行常量折疊。填充memo_,返回訪問主體的結果,將綁定值傳播到主體中的使用站點。如果不能將綁定值常量化,將模擬默認實現。
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as();
if (const auto* tuple = op->tuple.as()) {
return tuple->fields[op->index];
} else {
return res;
}
}
在TupleItemGetNode的情況下,檢查op->tuple字段是否是TupleNode。用op->index指向的元組字段替換元組get。需要檢查的原因是op->tuple可能計算為一個tuple,本身不是tuple。
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttrMap(“TOpIsStateful”);
Expr res = ExprMutator::VisitExpr_(call);
call = res.as();
// We don’t constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
在CallNode的情況下,首先使用ExprMutator的VisitExpr_訪問調用,將調用的所有字段折疊起來。使用ExprMutator::VisitExpr_uu而不是VisitExpr,因為希望繞過vtable(避免無限循環),使用ExprMutator提供的默認實現。然后,僅在所有參數都是常量時(使用ConstantChecker)計算調用。對調用求值會產生一個值,因此使用help方法ValueToExpr,將求值表達式放回AST中。
現在,為常量文件夾構造一個更方便的接口FoldConstant。FoldConstant是ConstantFolder類外的一個獨立函數,接受一個表達式,在內部創建和使用ConstantFolder實例(完整定義可在src/relay/transforms/fold_constant.cc中找到)。
向pass管理器注冊pass
參閱:ref:pass infra上的文檔,了解有關此主題的更多詳細信息。
編寫AST遍歷器后,可以使用以下代碼,將pass注冊為TVM API端點:
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, “FoldConstant”, {});
}
} // namespace transform
如果將上述代碼生成的Pass對象,提供給Pass基礎設施,將確保將AST遍歷應用于給定Relay模塊中的每個函數,這是常量折疊Pass的預期行為(它應盡可能折疊所有常數)。
函數CreateFunctionPass允許注冊pass的優化級別(在本例中為2),該級別可用于根據pass的通用工具、pass名稱以及pass的任何依賴項,將pas分組。pass的依賴項以任何pass的列表的形式給出,這些pass的結果是運行當前pass所必需的。FoldConstant沒有任何依賴項,但許多Relay pass確實依賴于類型信息,因此InferType是一個常見的依賴項;另一些可能依賴于程序,通過ToANormalForm pass處于A-normal形式。
注意,PassContext對象包含pass用于錯誤報告和配置選項的信息;FoldConstant不需要此信息,但其它pass可能會引用PassContext對象。
現在可以通過pass基礎設施調用pass,不過最好為pass添加一個Python綁定,如下面的代碼片段所示:
TVM_REGISTER_GLOBAL(“relay._transform.FoldConstant”)
.set_body_typed(FoldConstant);
一旦以上述方式定義了Pass對象,就可以使用Pass基礎設施的順序構造調用,該構造獲取一個Pass列表,按順序應用于Relay模塊,從而獲得轉換后的模塊。例如,下面的代碼將FoldConstant和ToANormalForm pass(一個接一個),應用于mod中的每個函數,獲得一個新模塊。
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
有關注冊的更多詳細信息,可以在TVM Runtime系統中找到,有關pass manager接口的更多信息可以在pass基礎設施中找到。Relay的標準pass在include/tvm/Relay/transform.h中列出,在src/Relay/transforms/中實現。
參考鏈接:
https://tvm.apache.org/docs/dev/how_to/relay_add_pass.html
總結
以上是生活随笔為你收集整理的将编译器pass添加到Relay的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何将算子添加到Relay
- 下一篇: TVM cmake示例展示