Unverified Commit 28ee806d by Zhi Committed by GitHub

[relay][external codegen] outline and inline lifted functions for external codegen (#4996)

* outline and inline lifted functions for external codegen

* add batch_norm test

* test batch_norm inline
parent fcf8420a
......@@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode {
// Fuse the operations if it is needed.
relay_module = transform::FuseOps()(relay_module);
relay_module = transform::InferType()(relay_module);
// Inline the functions that have been lifted by the module scope.
//
// TODO(@zhiics) Note that we need to be careful about the subgraphs with
// global function calls. We should make sure that these callees are also
// inline functions. However, this should be very unlikely for accelerators
// and vendor-provided libraries. So we don't handle for now.
relay_module = transform::Inline()(relay_module);
CHECK(relay_module.defined());
return relay_module;
......
......@@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
// and we use these ops to invoke the symbols in the module generated by
// external codegen.
pass_seqs.push_back(transform::Inline());
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
......
......@@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
......
......@@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
......
......@@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor {
*/
class Partitioner : public ExprMutator {
public:
explicit Partitioner(const IRModule& module) : module_(module) {}
std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
for (auto candidate : this->subgraphs_) {
if (candidate->nodes.find(node) != candidate->nodes.end()) {
......@@ -163,8 +165,10 @@ class Partitioner : public ExprMutator {
// Replace the begin annotation with an external call input variable.
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
// The type of the created variable is the same as the compiler_begin
// node.
auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
input_expr->checked_type_);
call->checked_type_);
// Find the corresponding subgraph and add the argument.
auto subgraph = GetSubgraph(GetRef<Call>(call));
......@@ -182,7 +186,7 @@ class Partitioner : public ExprMutator {
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
// Check if the argument already belongs to an exist subgraph
// Check if the argument already belongs to an existing subgraph
auto subgraph = GetSubgraph(call->args[0]);
if (!subgraph) {
auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
......@@ -207,16 +211,28 @@ class Partitioner : public ExprMutator {
}
auto subgraph_func =
FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
FunctionNode::make(params, input, call->checked_type_, {}, Attrs());
Expr arg0 = call->args[0];
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func =
FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
return CallNode::make(subgraph_func, args);
subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
CHECK(!module_->ContainGlobalVar(name))
<< "Global function " << name << " already exists";
// Create a global function and add it to the IRModule for the subgraph.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent relay
// function level passes (i.e. simplify inference and fusion) optimizing it.
GlobalVar glob_func(name);
module_->Add(glob_func, subgraph_func);
// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = CallNode::make(glob_func, args);
ret->checked_type_ = call->checked_type_;
return std::move(ret);
}
}
......@@ -330,50 +346,39 @@ class Partitioner : public ExprMutator {
}
}
IRModule Partition() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}
private:
int var_id_{0};
int subgraph_id_{0};
std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
IRModule module_;
};
/*!
* \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
* the same codegen backend. This reduces rounds trips between TVM and external
* backends. Likely we can borrow some ideas from operator fusion.
*
* For example, sg1 and sg2 should be combined if they belong to the same
* codegen tool in the following case.
*
* op1
* / \
* sg1 sg2
*
* |
* \|/
*
* op1
* |
* sg1_sg2
*
* where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
* inputs that obtained from the tuple.
*/
Expr PartitionGraph(const Expr& expr) {
Partitioner part;
return part.Mutate(expr);
}
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(partitioning::PartitionGraph(f));
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) {
return partitioning::Partitioner(m).Partition();
};
auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
......
......@@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) {
auto funcs = m->functions;
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
return ToANormalFormAux(e);
......
......@@ -18,14 +18,12 @@
import os
import sys
import numpy as np
import pytest
import tvm
from tvm import te
import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator
......@@ -189,7 +187,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
return lib
def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
lib = update_lib(lib)
......@@ -200,7 +198,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
......@@ -297,6 +295,7 @@ def test_extern_ccompiler_single_op():
def test_extern_ccompiler_default_ops():
def expected():
mod = tvm.IRModule()
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
x0 = relay.var("x0", shape=(8, 8))
......@@ -305,11 +304,14 @@ def test_extern_ccompiler_default_ops():
# Function that uses C compiler
func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler",
tvm.tir.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
add_call = relay.Call(func, [x, y])
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
# Function that uses default compiler. Ops are fused in this function.
p0 = relay.var("p0", shape=(8, 8))
log = relay.log(p0)
......@@ -320,7 +322,6 @@ def test_extern_ccompiler_default_ops():
tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
mod = tvm.IRModule()
mod["main"] = main
return mod
......@@ -371,28 +372,65 @@ def test_extern_dnnl():
dtype = 'float32'
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
data = relay.var('data', shape=(ishape), dtype=dtype)
weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
f = relay.Function([data, weight1], out)
def expected():
data0 = relay.var("data", shape=(ishape), dtype=dtype)
input0 = relay.var("input0", shape=(w1shape), dtype=dtype)
input1 = relay.var("input1", shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data0,
input0,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
input1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
func = relay.Function([data0, input0, input1], out)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl"))
func = func.set_attribute("ExternalSymbol",
tvm.tir.StringImm("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func
data = relay.var("data", shape=(ishape), dtype=dtype)
weight = relay.var("input", shape=(w1shape), dtype=dtype)
main_f = relay.Function([data, weight], glb_var(data, weight, weight))
mod["main"] = main_f
return mod
def get_func():
data = relay.var("data", shape=(ishape), dtype=dtype)
weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
return relay.Function([data, weight1], out)
mod = tvm.IRModule()
mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
mod = transform.PartitionGraph()(mod)
assert relay.alpha_equal(mod, expected())
ref_mod = tvm.IRModule()
ref_mod['main'] = f
ref_mod["main"] = get_func()
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
......@@ -427,6 +465,175 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
def test_function_lifting():
def partition():
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
conv = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(3, 3),
channels=16,
padding=(1, 1))
bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean,
bn_mvar)
func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mmean,
bn_mvar], bn_output.astuple())
mod = tvm.IRModule()
mod["main"] = func
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
opt_pass = transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.AlterOpLayout(),
])
with relay.build_config(opt_level=3):
mod = opt_pass(mod)
return mod
def expected():
# function for batch_norm
data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
"float32"))
mod = tvm.IRModule()
bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0
# function for conv2d
data1 = relay.var("data1", relay.TensorType((1, 3, 224, 224), "float32"))
weight1 = relay.var("weight1", relay.TensorType((16, 3, 3, 3), "float32"))
conv = relay.nn.conv2d(
data=data1,
weight=weight1,
kernel_size=(3, 3),
channels=16,
padding=(1, 1))
func1 = relay.Function([data1, weight1], conv)
func1 = func1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Compiler",
tvm.tir.StringImm("test_compiler"))
func1 = func1.set_attribute("ExternalSymbol",
tvm.tir.StringImm("test_compiler_1"))
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1
# main function
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
call1 = gv1(data, weight)
call0 = gv0(call1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
mod["main"] = relay.Function([data, weight, bn_gamma0, bn_beta0, bn_mmean0,
bn_mvar0], call0)
mod = transform.InferType()(mod)
return mod
partitioned = partition()
ref_mod = expected()
assert relay.analysis.alpha_equal(partitioned, ref_mod)
def test_function_lifting_inline():
def partition():
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean,
bn_mvar)
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean,
bn_mvar], bn_output.astuple())
mod = tvm.IRModule()
mod["main"] = func
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
opt_pass = transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.AlterOpLayout(),
transform.Inline(),
])
with relay.build_config(opt_level=3):
mod = opt_pass(mod)
return mod
def expected():
# function for batch_norm
data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
"float32"))
mod = tvm.IRModule()
bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
# main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
call0 = func0(data, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
mod["main"] = relay.Function([data, bn_gamma0, bn_beta0, bn_mmean0,
bn_mvar0], call0)
mod = transform.InferType()(mod)
return mod
partitioned = partition()
ref_mod = expected()
assert relay.analysis.alpha_equal(partitioned, ref_mod)
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
......@@ -434,3 +641,5 @@ if __name__ == "__main__":
test_extern_ccompiler()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
test_function_lifting()
test_function_lifting_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment