Unverified commit 28ee806d by Zhi (committed by GitHub)

[relay][external codegen] outline and inline lifted functions for external codegen (#4996)

* outline and inline lifted functions for external codegen

* add batch_norm test

* test batch_norm inline
parent fcf8420a
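
In short, "outline" means the graph partitioner now hoists each external subgraph into its own global function in the IRModule, and "inline" means the build pipelines fold those globals back into their call sites right before code generation. A minimal Python sketch of the annotation step that feeds this flow (hedged: the backend name "ccompiler" and the shapes are illustrative, not taken from this commit):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    # Mark a single add for a hypothetical external backend "ccompiler".
    lhs = relay.op.annotation.compiler_begin(x, "ccompiler")
    rhs = relay.op.annotation.compiler_begin(y, "ccompiler")
    out = relay.op.annotation.compiler_end(relay.add(lhs, rhs), "ccompiler")
    mod = tvm.IRModule.from_expr(relay.Function([x, y], out))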
@@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode {
   // Fuse the operations if it is needed.
   relay_module = transform::FuseOps()(relay_module);
   relay_module = transform::InferType()(relay_module);
+  // Inline the functions that have been lifted to the module scope.
+  //
+  // TODO(@zhiics) Note that we need to be careful about subgraphs with
+  // global function calls. We should make sure that these callees are also
+  // inline functions. However, this should be very unlikely for accelerators
+  // and vendor-provided libraries, so we don't handle it for now.
+  relay_module = transform::Inline()(relay_module);
   CHECK(relay_module.defined());
   return relay_module;
...
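
The Inline pass added to the pipeline here is also exposed in Python. A hedged sketch of the same step, assuming the partitioned `mod` from the sketch above:

    # Fold the outlined external functions (marked Inline=1 by the
    # partitioner) back into their call sites before codegen runs.
    mod = relay.transform.Inline()(mod)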
@@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
+  // Inline the functions that are lifted to the module scope. We perform this
+  // pass after all other optimization passes, but before the memory allocation
+  // pass, because the memory allocation pass will insert `invoke_tvm_op` and
+  // we use these ops to invoke the symbols in the module generated by
+  // external codegen.
+  pass_seqs.push_back(transform::Inline());
   // Manifest the allocations.
   pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
   // Compute away possibly introduced constant computation.
...
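
A rough Python mirror of the ordering constraint described in the comment; the real VM pipeline is assembled in C++, and which passes are reachable from Python may vary by TVM version:

    # Inline must come after lambda lifting but before memory planning,
    # since ManifestAlloc introduces the invoke_tvm_op calls that carry
    # the external symbols.
    seq = tvm.transform.Sequential([
        relay.transform.LambdaLift(),
        relay.transform.Inline(),
    ])
    mod = seq(mod)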
@@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator {
       auto global = pair.first;
       auto base_func = pair.second;
       if (auto* n = base_func.as<FunctionNode>()) {
+        if (!n->UseDefaultCompiler()) continue;
        auto func = GetRef<Function>(n);
        DLOG(INFO) << "Before inlining primitives: " << global
...
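
Both this pass and the lambda lifter below now skip functions destined for an external compiler. A plain-Python analogue of the UseDefaultCompiler() guard, with assumed semantics (the attribute key "Compiler" matches attr::kCompiler used later in this diff):

    def uses_default_compiler(func_attrs):
        # Assumed semantics: a function belongs to the default (TVM)
        # compiler unless its attributes name an external backend.
        compiler = (func_attrs or {}).get("Compiler")
        return compiler is None or compiler == "default"

    assert uses_default_compiler({}) is True
    assert uses_default_compiler({"Compiler": "dnnl"}) is False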
@@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator {
     auto glob_funcs = module_->functions;
     for (auto pair : glob_funcs) {
       if (auto* n = pair.second.as<FunctionNode>()) {
+        if (!n->UseDefaultCompiler()) continue;
        auto func = GetRef<Function>(n);
        func = FunctionNode::make(func->params,
                                  VisitExpr(func->body),
...
@@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor {
  */
 class Partitioner : public ExprMutator {
  public:
+  explicit Partitioner(const IRModule& module) : module_(module) {}
+
   std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
     for (auto candidate : this->subgraphs_) {
       if (candidate->nodes.find(node) != candidate->nodes.end()) {
@@ -163,8 +165,10 @@ class Partitioner : public ExprMutator {
       // Replace the begin annotation with an external call input variable.
       auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      // The type of the created variable is the same as the type of the
+      // compiler_begin node.
       auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
-                               input_expr->checked_type_);
+                               call->checked_type_);
       // Find the corresponding subgraph and add the argument.
       auto subgraph = GetSubgraph(GetRef<Call>(call));
@@ -182,7 +186,7 @@ class Partitioner : public ExprMutator {
       auto compiler_attrs = call->attrs.as<CompilerAttrs>();
-      // Check if the argument already belongs to an exist subgraph
+      // Check if the argument already belongs to an existing subgraph
       auto subgraph = GetSubgraph(call->args[0]);
       if (!subgraph) {
         auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
@@ -207,16 +211,28 @@ class Partitioner : public ExprMutator {
       }
       auto subgraph_func =
-          FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
+          FunctionNode::make(params, input, call->checked_type_, {}, Attrs());

-      Expr arg0 = call->args[0];
       std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
       subgraph_func =
           FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
       subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
       subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
                                       tvm::tir::StringImmNode::make(compiler_attrs->compiler));
-      return CallNode::make(subgraph_func, args);
+      subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
+      CHECK(!module_->ContainGlobalVar(name))
+          << "Global function " << name << " already exists";
+      // Create a global function for the subgraph and add it to the IRModule.
+      // This lifts the functions that should be handled by external codegen
+      // to the module scope, and relies on the pass manager to keep Relay
+      // function-level passes (e.g. simplify inference and fusion) from
+      // optimizing them.
+      GlobalVar glob_func(name);
+      module_->Add(glob_func, subgraph_func);
+      // The return type of the call node is the same as the type of the
+      // compiler_end node.
+      auto ret = CallNode::make(glob_func, args);
+      ret->checked_type_ = call->checked_type_;
+      return std::move(ret);
     }
   }
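
Running the partitioner after this hunk yields the outlined form. A hedged illustration, continuing the earlier sketch (the printed IR below is approximate):

    mod = relay.transform.PartitionGraph()(mod)
    print(mod)
    # Roughly:
    #   def @ccompiler_0(%ccompiler_input0, %ccompiler_input1)
    #       /* ExternalSymbol="ccompiler_0", Primitive=1,
    #          Compiler="ccompiler", Inline=1 */ { add(...) }
    #   def @main(%x, %y) { @ccompiler_0(%x, %y) }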
@@ -330,50 +346,39 @@ class Partitioner : public ExprMutator {
     }
   }

+  IRModule Partition() {
+    auto glob_funcs = module_->functions;
+    for (const auto& pair : glob_funcs) {
+      if (auto* fn = pair.second.as<FunctionNode>()) {
+        auto func = GetRef<Function>(fn);
+        func = FunctionNode::make(func->params,
+                                  VisitExpr(func->body),
+                                  func->ret_type,
+                                  func->type_params,
+                                  func->attrs);
+        module_->Update(pair.first, func);
+      }
+    }
+    return module_;
+  }
+
  private:
   int var_id_{0};
   int subgraph_id_{0};
   std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
+  IRModule module_;
 };
-/*!
- * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
- * the same codegen backend. This reduces round trips between TVM and
- * external backends. Likely we can borrow some ideas from operator fusion.
- *
- * For example, sg1 and sg2 should be combined if they belong to the same
- * codegen tool in the following case:
- *
- *      op1
- *     /   \
- *   sg1   sg2
- *
- *       |
- *      \|/
- *
- *      op1
- *       |
- *    sg1_sg2
- *
- * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has
- * two inputs that are obtained from the tuple.
- */
-Expr PartitionGraph(const Expr& expr) {
-  Partitioner part;
-  return part.Mutate(expr);
-}
 } // namespace partitioning

 namespace transform {

 Pass PartitionGraph() {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
-      [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(partitioning::PartitionGraph(f));
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
+      [=](IRModule m, PassContext pc) {
+        return partitioning::Partitioner(m).Partition();
       };
-  auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
+  auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
   return Sequential({partitioned, InferType()});
 }
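
For reference, CreateModulePass has a Python counterpart; a hedged analogue showing the new module-in/module-out signature that replaces the old per-Function one (the pass name is illustrative):

    @tvm.transform.module_pass(opt_level=0, name="MyPartition")
    def my_partition(mod, ctx):
        # A module pass may add globals, just as the partitioner does
        # via module_->Add(glob_func, subgraph_func) above.
        return mod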
...
@@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) {
   auto funcs = m->functions;
   for (const auto& it : funcs) {
     CHECK_EQ(FreeVars(it.second).size(), 0);
+    if (const auto* n = it.second.as<FunctionNode>()) {
+      if (!n->UseDefaultCompiler()) continue;
+    }
     Expr ret =
         TransformF([&](const Expr& e) {
           return ToANormalFormAux(e);
...
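
With this guard, A-normal-form conversion also leaves external functions untouched. A short hedged sketch of applying it after partitioning:

    # Only functions for the default compiler are rewritten into ANF;
    # outlined external subgraphs pass through unchanged.
    mod = relay.transform.ToANormalForm()(mod)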