Unverified Commit 28ee806d by Zhi Committed by GitHub

[relay][external codegen] outline and inline lifted functions for external codegen (#4996)

* outline and inline lifted functions for external codegen

* add batch_norm test

* test batch_norm inline
parent fcf8420a
...@@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode {
// Fuse the operations if it is needed. // Fuse the operations if it is needed.
relay_module = transform::FuseOps()(relay_module); relay_module = transform::FuseOps()(relay_module);
relay_module = transform::InferType()(relay_module); relay_module = transform::InferType()(relay_module);
// Inline the functions that have been lifted by the module scope.
//
// TODO(@zhiics) Note that we need to be careful about the subgraphs with
// global function calls. We should make sure that these callees are also
// inline functions. However, this should be very unlikely for accelerators
// and vendor-provided libraries. So we don't handle for now.
relay_module = transform::Inline()(relay_module);
CHECK(relay_module.defined()); CHECK(relay_module.defined());
return relay_module; return relay_module;
......
...@@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe ...@@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::InlinePrimitives());
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
// and we use these ops to invoke the symbols in the module generated by
// external codegen.
pass_seqs.push_back(transform::Inline());
// Manifest the allocations. // Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation. // Compute away possibly introduced constant computation.
......
...@@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator { ...@@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first; auto global = pair.first;
auto base_func = pair.second; auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) { if (auto* n = base_func.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n); auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global DLOG(INFO) << "Before inlining primitives: " << global
......
...@@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator { ...@@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions; auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) { for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) { if (auto* n = pair.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n); auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params, func = FunctionNode::make(func->params,
VisitExpr(func->body), VisitExpr(func->body),
......
...@@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor { ...@@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor {
*/ */
class Partitioner : public ExprMutator { class Partitioner : public ExprMutator {
public: public:
explicit Partitioner(const IRModule& module) : module_(module) {}
std::shared_ptr<Subgraph> GetSubgraph(const Expr node) { std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
for (auto candidate : this->subgraphs_) { for (auto candidate : this->subgraphs_) {
if (candidate->nodes.find(node) != candidate->nodes.end()) { if (candidate->nodes.find(node) != candidate->nodes.end()) {
...@@ -163,8 +165,10 @@ class Partitioner : public ExprMutator { ...@@ -163,8 +165,10 @@ class Partitioner : public ExprMutator {
// Replace the begin annotation with an external call input variable. // Replace the begin annotation with an external call input variable.
auto compiler_attrs = call->attrs.as<CompilerAttrs>(); auto compiler_attrs = call->attrs.as<CompilerAttrs>();
// The type of the created variable is the same as the compiler_begin
// node.
auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
input_expr->checked_type_); call->checked_type_);
// Find the corresponding subgraph and add the argument. // Find the corresponding subgraph and add the argument.
auto subgraph = GetSubgraph(GetRef<Call>(call)); auto subgraph = GetSubgraph(GetRef<Call>(call));
...@@ -182,7 +186,7 @@ class Partitioner : public ExprMutator { ...@@ -182,7 +186,7 @@ class Partitioner : public ExprMutator {
auto compiler_attrs = call->attrs.as<CompilerAttrs>(); auto compiler_attrs = call->attrs.as<CompilerAttrs>();
// Check if the argument already belongs to an exist subgraph // Check if the argument already belongs to an existing subgraph
auto subgraph = GetSubgraph(call->args[0]); auto subgraph = GetSubgraph(call->args[0]);
if (!subgraph) { if (!subgraph) {
auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>()); auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
...@@ -207,16 +211,28 @@ class Partitioner : public ExprMutator { ...@@ -207,16 +211,28 @@ class Partitioner : public ExprMutator {
} }
auto subgraph_func = auto subgraph_func =
FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs()); FunctionNode::make(params, input, call->checked_type_, {}, Attrs());
Expr arg0 = call->args[0];
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func = subgraph_func =
FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name)); FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler, subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler)); tvm::tir::StringImmNode::make(compiler_attrs->compiler));
return CallNode::make(subgraph_func, args); subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
CHECK(!module_->ContainGlobalVar(name))
<< "Global function " << name << " already exists";
// Create a global function and add it to the IRModule for the subgraph.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent relay
// function level passes (i.e. simplify inference and fusion) optimizing it.
GlobalVar glob_func(name);
module_->Add(glob_func, subgraph_func);
// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = CallNode::make(glob_func, args);
ret->checked_type_ = call->checked_type_;
return std::move(ret);
} }
} }
...@@ -330,50 +346,39 @@ class Partitioner : public ExprMutator { ...@@ -330,50 +346,39 @@ class Partitioner : public ExprMutator {
} }
} }
IRModule Partition() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
private: private:
int var_id_{0}; int var_id_{0};
int subgraph_id_{0}; int subgraph_id_{0};
std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_; std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
IRModule module_;
}; };
/*!
* \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
* the same codegen backend. This reduces rounds trips between TVM and external
* backends. Likely we can borrow some ideas from operator fusion.
*
* For example, sg1 and sg2 should be combined if they belong to the same
* codegen tool in the following case.
*
* op1
* / \
* sg1 sg2
*
* |
* \|/
*
* op1
* |
* sg1_sg2
*
* where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
* inputs that obtained from the tuple.
*/
Expr PartitionGraph(const Expr& expr) {
Partitioner part;
return part.Mutate(expr);
}
} // namespace partitioning } // namespace partitioning
namespace transform { namespace transform {
Pass PartitionGraph() { Pass PartitionGraph() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func = runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) { [=](IRModule m, PassContext pc) {
return Downcast<Function>(partitioning::PartitionGraph(f)); return partitioning::Partitioner(m).Partition();
}; };
auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {}); auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()}); return Sequential({partitioned, InferType()});
} }
......
...@@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) { ...@@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) {
auto funcs = m->functions; auto funcs = m->functions;
for (const auto& it : funcs) { for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0); CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
}
Expr ret = Expr ret =
TransformF([&](const Expr& e) { TransformF([&](const Expr& e) {
return ToANormalFormAux(e); return ToANormalFormAux(e);
......
...@@ -18,14 +18,12 @@ ...@@ -18,14 +18,12 @@
import os import os
import sys import sys
import numpy as np import numpy as np
import pytest
import tvm import tvm
from tvm import te
import tvm.relay.testing import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
...@@ -189,7 +187,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -189,7 +187,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
return lib return lib
def check_vm_result(): def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params) exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save() code, lib = exe.save()
lib = update_lib(lib) lib = update_lib(lib)
...@@ -200,7 +198,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -200,7 +198,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result(): def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params) json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib) lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
...@@ -297,6 +295,7 @@ def test_extern_ccompiler_single_op(): ...@@ -297,6 +295,7 @@ def test_extern_ccompiler_single_op():
def test_extern_ccompiler_default_ops(): def test_extern_ccompiler_default_ops():
def expected(): def expected():
mod = tvm.IRModule()
x = relay.var("x", shape=(8, 8)) x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8)) y = relay.var("y", shape=(8, 8))
x0 = relay.var("x0", shape=(8, 8)) x0 = relay.var("x0", shape=(8, 8))
...@@ -305,11 +304,14 @@ def test_extern_ccompiler_default_ops(): ...@@ -305,11 +304,14 @@ def test_extern_ccompiler_default_ops():
# Function that uses C compiler # Function that uses C compiler
func = relay.Function([x0, y0], add) func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", func = func.set_attribute("Compiler",
tvm.tir.StringImm("ccompiler")) tvm.tir.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol", func = func.set_attribute("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0")) tvm.tir.StringImm("ccompiler_0"))
add_call = relay.Call(func, [x, y]) glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
# Function that uses default compiler. Ops are fused in this function. # Function that uses default compiler. Ops are fused in this function.
p0 = relay.var("p0", shape=(8, 8)) p0 = relay.var("p0", shape=(8, 8))
log = relay.log(p0) log = relay.log(p0)
...@@ -320,7 +322,6 @@ def test_extern_ccompiler_default_ops(): ...@@ -320,7 +322,6 @@ def test_extern_ccompiler_default_ops():
tvm.tir.IntImm("int32", 1)) tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call]) fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call) main = relay.Function([x, y], fused_call)
mod = tvm.IRModule()
mod["main"] = main mod["main"] = main
return mod return mod
...@@ -371,8 +372,43 @@ def test_extern_dnnl(): ...@@ -371,8 +372,43 @@ def test_extern_dnnl():
dtype = 'float32' dtype = 'float32'
ishape = (1, 32, 14, 14) ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3) w1shape = (32, 1, 3, 3)
data = relay.var('data', shape=(ishape), dtype=dtype)
weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype) def expected():
data0 = relay.var("data", shape=(ishape), dtype=dtype)
input0 = relay.var("input0", shape=(w1shape), dtype=dtype)
input1 = relay.var("input1", shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data0,
input0,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
input1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
func = relay.Function([data0, input0, input1], out)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl"))
func = func.set_attribute("ExternalSymbol",
tvm.tir.StringImm("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func
data = relay.var("data", shape=(ishape), dtype=dtype)
weight = relay.var("input", shape=(w1shape), dtype=dtype)
main_f = relay.Function([data, weight], glb_var(data, weight, weight))
mod["main"] = main_f
return mod
def get_func():
data = relay.var("data", shape=(ishape), dtype=dtype)
weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data, depthwise_conv2d_1 = relay.nn.conv2d(data,
weight1, weight1,
kernel_size=(3, 3), kernel_size=(3, 3),
...@@ -385,14 +421,16 @@ def test_extern_dnnl(): ...@@ -385,14 +421,16 @@ def test_extern_dnnl():
groups=32) groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
f = relay.Function([data, weight1], out) return relay.Function([data, weight1], out)
mod = tvm.IRModule() mod = tvm.IRModule()
mod['main'] = WholeGraphAnnotator('dnnl').visit(f) mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
mod = transform.PartitionGraph()(mod) mod = transform.PartitionGraph()(mod)
assert relay.alpha_equal(mod, expected())
ref_mod = tvm.IRModule() ref_mod = tvm.IRModule()
ref_mod['main'] = f ref_mod["main"] = get_func()
i_data = np.random.uniform(0, 1, ishape).astype(dtype) i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
...@@ -427,6 +465,175 @@ def test_extern_dnnl_mobilenet(): ...@@ -427,6 +465,175 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
def test_function_lifting():
def partition():
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
conv = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(3, 3),
channels=16,
padding=(1, 1))
bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean,
bn_mvar)
func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mmean,
bn_mvar], bn_output.astuple())
mod = tvm.IRModule()
mod["main"] = func
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
opt_pass = transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.AlterOpLayout(),
])
with relay.build_config(opt_level=3):
mod = opt_pass(mod)
return mod
def expected():
# function for batch_norm
data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
"float32"))
mod = tvm.IRModule()
bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0
# function for conv2d
data1 = relay.var("data1", relay.TensorType((1, 3, 224, 224), "float32"))
weight1 = relay.var("weight1", relay.TensorType((16, 3, 3, 3), "float32"))
conv = relay.nn.conv2d(
data=data1,
weight=weight1,
kernel_size=(3, 3),
channels=16,
padding=(1, 1))
func1 = relay.Function([data1, weight1], conv)
func1 = func1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Compiler",
tvm.tir.StringImm("test_compiler"))
func1 = func1.set_attribute("ExternalSymbol",
tvm.tir.StringImm("test_compiler_1"))
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1
# main function
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
call1 = gv1(data, weight)
call0 = gv0(call1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
mod["main"] = relay.Function([data, weight, bn_gamma0, bn_beta0, bn_mmean0,
bn_mvar0], call0)
mod = transform.InferType()(mod)
return mod
partitioned = partition()
ref_mod = expected()
assert relay.analysis.alpha_equal(partitioned, ref_mod)
def test_function_lifting_inline():
def partition():
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean,
bn_mvar)
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean,
bn_mvar], bn_output.astuple())
mod = tvm.IRModule()
mod["main"] = func
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
opt_pass = transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.AlterOpLayout(),
transform.Inline(),
])
with relay.build_config(opt_level=3):
mod = opt_pass(mod)
return mod
def expected():
# function for batch_norm
data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
"float32"))
mod = tvm.IRModule()
bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
# main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
call0 = func0(data, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
mod["main"] = relay.Function([data, bn_gamma0, bn_beta0, bn_mmean0,
bn_mvar0], call0)
mod = transform.InferType()(mod)
return mod
partitioned = partition()
ref_mod = expected()
assert relay.analysis.alpha_equal(partitioned, ref_mod)
if __name__ == "__main__": if __name__ == "__main__":
test_multi_node_compiler() test_multi_node_compiler()
test_extern_ccompiler_single_op() test_extern_ccompiler_single_op()
...@@ -434,3 +641,5 @@ if __name__ == "__main__": ...@@ -434,3 +641,5 @@ if __name__ == "__main__":
test_extern_ccompiler() test_extern_ccompiler()
test_extern_dnnl() test_extern_dnnl()
test_extern_dnnl_mobilenet() test_extern_dnnl_mobilenet()
test_function_lifting()
test_function_lifting_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment