Unverified Commit 04499665 by mbaret Committed by GitHub

[RELAY] Fixes to MergeCompilerRegions (#5195)

* [RELAY] Fixed issues with MergeCompilerRegions

This PR addresses several outstanding issues in the
implementation of MergeCompilerRegions. In particular,
TupleGetItem nodes are now handled correctly, and a
number of minor bugs in the region-merging logic have
been fixed.
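
For context, a minimal sketch (not part of this commit) of the pass pipeline these changes affect, using the public Relay API. The "dnnl" target is only an example and assumes its annotation rules are registered:

    import tvm
    from tvm import relay
    from tvm.relay import transform

    # A multi-output op (batch_norm) consumed through TupleGetItem --
    # the case whose handling this PR fixes.
    data = relay.var("data", shape=(1, 16, 224, 224))
    gamma = relay.var("gamma", shape=(16,))
    beta = relay.var("beta", shape=(16,))
    mean = relay.var("mean", shape=(16,))
    var = relay.var("var", shape=(16,))
    bn = relay.nn.batch_norm(data, gamma, beta, mean, var)
    relu = relay.nn.relu(bn[0])  # bn[0] lowers to a TupleGetItem node

    mod = tvm.IRModule.from_expr(
        relay.Function([data, gamma, beta, mean, var], relu))
    mod = transform.AnnotateTarget("dnnl")(mod)    # insert compiler_begin/end
    mod = transform.MergeCompilerRegions()(mod)    # merge adjacent regions
    mod = transform.PartitionGraph()(mod)          # outline "dnnl" functions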

Change-Id: I07783afc56183a6f798a510209f23b0a5f252255

* Fixed issue using pre-merged regions

Change-Id: I0a844ac59bda1089ae0c67cef52f0b0c7ab2cbd7

* Removed some debugging logic

Change-Id: Ib6f2eede6f38bbb270073eb8d4c4dc19f60832c6

* Remove default annotations

Change-Id: I9b7696a51c95871491cbea33c40f92ec327e417f

* Annotate default 'if's

Change-Id: I0098bd1bf6788dd6366810dcefa84f1ebbffaab0

* Clang format

Change-Id: I944365cd3080a97a9261f643a8f1efa5a63cf82b

* Use src/dest in merge

Change-Id: Ie43113492bda8f1ce63eaf9615cb645bb9e2ee86

* Fixed partition test

Change-Id: I46f9e349b1a813a9140f7e4f8a2241687e2df73b

* Removed comments

Change-Id: I309afdd1951d7e796e41d13788aa487707e0ac4c
parent 2f41a396
@@ -70,12 +70,12 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
   regions_.erase(src);
 }
 
-void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) {
-  auto region2 = GetRegion(expr);
-  if (region2.defined()) {
-    MergeRegions(region, region2);
+void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) {
+  auto src = GetRegion(expr);
+  if (src.defined()) {
+    MergeRegions(src, dest);
   } else {
-    region->nodes.insert(expr);
+    dest->nodes.insert(expr);
   }
 }
@@ -178,10 +178,10 @@ class AnnotatedRegionSetNode : public Object {
   /*!
    * \brief Add an expression to a region.
    *
-   * \param region The region to add the expression to.
+   * \param dest The region to add the expression to.
    * \param expr The expression.
    */
-  void AddToRegion(AnnotatedRegion region, const Expr& expr);
+  void AddToRegion(AnnotatedRegion dest, const Expr& expr);
 
   /*!
    * \brief Make a new region.
@@ -32,6 +32,9 @@ namespace tvm {
 namespace relay {
 namespace annotate_target {
 
+// Cache compiler_begin op for equivalence check.
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+
 // A helper class to insert annotation boundaries for a program region that will
 // be handled by a specific compiler.
 class AnnotateTargetWrapper : public ExprMutator {

@@ -52,6 +55,13 @@ class AnnotateTargetWrapper : public ExprMutator {
         return fannotate[op](call->attrs, call->args);
       }
     }
+    if (expr->IsInstance<TupleGetItemNode>()) {
+      TupleGetItem get = Downcast<TupleGetItem>(expr);
+      if (get->tuple->IsInstance<CallNode>() &&
+          get->tuple.as<CallNode>()->op == compiler_begin_op) {
+        return true;
+      }
+    }
     return false;
   }
@@ -110,9 +120,14 @@ class AnnotateTargetWrapper : public ExprMutator {
     auto new_e = ExprMutator::VisitExpr_(op);
 
     auto get = Downcast<TupleGetItem>(new_e);
-    return TupleGetItem(
-        InsertEnd(get->tuple),
-        get->index);
+    if (IsSupported(get->tuple)) {
+      const auto* begin_op =
+          runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+      CHECK(begin_op);
+      return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index);
+    } else {
+      return TupleGetItem(InsertEnd(get->tuple), get->index);
+    }
   }
 
   Expr VisitExpr_(const FunctionNode* op) {
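
For a supported tuple, the rewrite above yields a begin/end pair around the tuple with the index taken inside the new region. Roughly, as a sketch under assumed names (not code from this commit):

    from tvm import relay

    tup = relay.Tuple([relay.const(1.0), relay.const(2.0)])
    end = relay.annotation.compiler_end(tup, "test")      # InsertEnd(get->tuple)
    begin = relay.annotation.compiler_begin(end, "test")  # (*begin_op)(..., target_)
    get = relay.TupleGetItem(begin, 0)                    # index stays in the region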
@@ -113,7 +113,8 @@ def test_extern_dnnl():
                                              padding=(1, 1),
                                              groups=32)
        end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
-       begin2 = relay.annotation.compiler_begin(end0, "dnnl")
+       end1 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
+       begin2 = relay.annotation.compiler_begin(end1, "dnnl")
        begin3 = relay.annotation.compiler_begin(end0, "dnnl")
        begin4 = relay.annotation.compiler_begin(weight1, "dnnl")
        depthwise_conv2d_2 = relay.nn.conv2d(begin3,
@@ -121,11 +122,11 @@ def test_extern_dnnl():
                                             kernel_size=(3, 3),
                                             padding=(1, 1),
                                             groups=32)
-       end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
-       begin5 = relay.annotation.compiler_begin(end1, "dnnl")
+       end2 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
+       begin5 = relay.annotation.compiler_begin(end2, "dnnl")
        out = relay.add(begin2, begin5)
-       end2 = relay.annotation.compiler_end(out, "dnnl")
-       f = relay.Function([data, weight1], end2)
+       end3 = relay.annotation.compiler_end(out, "dnnl")
+       f = relay.Function([data, weight1], end3)
        mod = tvm.IRModule.from_expr(f)
        return mod
@@ -137,7 +138,7 @@ def test_extern_dnnl():
    mod = annotated(dtype, ishape, w1shape)
    mod = transform.AnnotateTarget("dnnl")(mod)
    ref_mod = expected(dtype, ishape, w1shape)
-   # tvm.ir.assert_structural_equal(mod, ref_mod)
+   tvm.ir.assert_structural_equal(mod, ref_mod)
 
 def test_run():
     if not tvm.get_global_func("relay.ext.dnnl", True):
@@ -66,13 +66,10 @@ def test_diamond_graph_fanouts():
        O_2 = relay.nn.relu(O_1)
        ce_3 = compiler_end(O_2, "test")
 
-       cb_x = compiler_begin(ce_2, "default")
-       X = relay.tanh(cb_x)
-       ce_x1 = compiler_end(X, "default")
-       ce_x2 = compiler_end(X, "default")
+       X = relay.tanh(ce_2)
 
        cb_3 = compiler_begin(ce_3, "test")
-       cb_4 = compiler_begin(ce_x1, "test")
+       cb_4 = compiler_begin(X, "test")
        O_3 = relay.add(cb_3, cb_4)
        ce_4 = compiler_end(O_3, "test")
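
A self-contained sketch (illustrative only, with assumed names) of the revised expectation: operators without "test" support, such as tanh here, are now left bare instead of being wrapped in a "default" region:

    from tvm import relay
    from tvm.relay.op.annotation import compiler_begin, compiler_end

    inp = relay.var("inp", shape=(10, 10))
    ce_2 = compiler_end(relay.abs(compiler_begin(inp, "test")), "test")
    X = relay.tanh(ce_2)              # no "default" compiler_begin/compiler_end
    cb_4 = compiler_begin(X, "test")  # the next "test" region starts at X directly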
@@ -162,36 +159,28 @@ def test_example_graph():
        node1 = relay.add(begin2, begin3)
        node2 = relay.add(node0, node1)
 
-       begin4 = compiler_begin(in_5, "default")
-       begin5 = compiler_begin(in_6, "default")
-       begin6 = compiler_begin(in_7, "default")
-       node3 = relay.subtract(begin4, begin5)
-       node4 = relay.subtract(begin6, node3)
-       end0 = compiler_end(node4, "default")
-       begin7 = compiler_begin(end0, "test")
-       begin8 = compiler_begin(in_9, "test")
-       node5 = relay.add(node2, begin7)
+       node3 = relay.subtract(in_5, in_6)
+       node4 = relay.subtract(in_7, node3)
+       begin4 = compiler_begin(node4, "test")
+       begin5 = compiler_begin(in_9, "test")
+       node5 = relay.add(node2, begin4)
        end1 = compiler_end(node5, "test")
-       begin9 = compiler_begin(end1, "default")
-       begin10 = compiler_begin(in_8, "default")
-       node6 = relay.subtract(begin10, begin9)
-       end2 = compiler_end(node6, "default")
-       node7 = relay.add(begin8, node5)
-       end3 = compiler_end(node7, "test")
-       begin11 = compiler_begin(end3, "test")
-       begin12 = compiler_begin(end2, "test")
-       node8 = relay.add(begin12, begin11)
-       begin13 = compiler_begin(in_10, "test")
-       node9 = relay.add(begin13, node8)
-       end4 = compiler_end(node9, "test")
-       f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4)
+       node6 = relay.subtract(in_8, end1)
+       node7 = relay.add(begin5, node5)
+       end2 = compiler_end(node7, "test")
+       begin6 = compiler_begin(end2, "test")
+       begin7 = compiler_begin(node6, "test")
+       node8 = relay.add(begin7, begin6)
+       begin8 = compiler_begin(in_10, "test")
+       node9 = relay.add(begin8, node8)
+       end3 = compiler_end(node9, "test")
+       f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end3)
        mod = tvm.IRModule.from_expr(f)
        return mod
@@ -725,12 +725,12 @@ def test_multiple_outputs():
        mod = tvm.IRModule()
 
        # function 0
-       data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
-       weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
-       bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
-       bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
-       bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
-       bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
+       data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32"))
+       weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32"))
+       bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32"))
+       bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32"))
+       bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32"))
+       bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32"))
 
        conv_o = relay.nn.conv2d(
            data=data,
@@ -743,7 +743,7 @@ def test_multiple_outputs():
                           bn_var)
        relu_o = relay.nn.relu(bn_o[0])
-       tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
+       tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o))
 
        func0 = relay.Function([data, weight, bn_gamma, bn_beta,
                                bn_mean, bn_var], tuple_o)
@@ -752,8 +752,8 @@ def test_multiple_outputs():
        func0 = func0.with_attr("Compiler",
                                tvm.tir.StringImm("test_target"))
        func0 = func0.with_attr("ExternalSymbol",
-                               tvm.tir.StringImm("test_target_0"))
-       gv0 = relay.GlobalVar("test_target_0")
+                               tvm.tir.StringImm("test_target_2"))
+       gv0 = relay.GlobalVar("test_target_2")
        mod[gv0] = func0
 
        # body
@@ -765,9 +765,9 @@ def test_multiple_outputs():
        bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
 
        f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
-       f0_relu_o = relay.TupleGetItem(f0_o, 0)
+       f0_relu_o = relay.TupleGetItem(f0_o, 2)
        f0_mean_o = relay.TupleGetItem(f0_o, 1)
-       f0_var_o = relay.TupleGetItem(f0_o, 2)
+       f0_var_o = relay.TupleGetItem(f0_o, 0)
        f0_mean_abs = relay.abs(f0_mean_o)
        f0_var_abs = relay.abs(f0_var_o)