Commit 04499665 by mbaret, committed by GitHub

[RELAY] Fixes to MergeCompilerRegions (#5195)

* [RELAY] Fixed issues with MergeCompilerRegions

This PR addresses a few outstanding issues with the
implementation of MergeCompilerRegions. In particular,
TupleGetItem nodes are now handled properly, and several
minor bugs related to region merging have been fixed.

Change-Id: I07783afc56183a6f798a510209f23b0a5f252255
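
For context, the TupleGetItem handling matters whenever a region produces a tuple, e.g. a multi-output op consumed through projections. A minimal sketch using the public relay API (illustrative only, not part of this patch):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 16))
    begin = relay.annotation.compiler_begin(x, "test")         # region entry
    parts = relay.split(begin, indices_or_sections=2, axis=1)  # multi-output op
    first = parts[0]  # sugar for relay.TupleGetItem(parts.astuple(), 0)
    end = relay.annotation.compiler_end(first, "test")         # region exit
    print(relay.Function([x], end))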

* Fixed issue using pre-merged regions

Change-Id: I0a844ac59bda1089ae0c67cef52f0b0c7ab2cbd7

* Removed some debugging logic

Change-Id: Ib6f2eede6f38bbb270073eb8d4c4dc19f60832c6

* Remove default annotations

Change-Id: I9b7696a51c95871491cbea33c40f92ec327e417f

* Annotate default 'if's

Change-Id: I0098bd1bf6788dd6366810dcefa84f1ebbffaab0

* Clang format

Change-Id: I944365cd3080a97a9261f643a8f1efa5a63cf82b

* Use src/dest in merge

Change-Id: Ie43113492bda8f1ce63eaf9615cb645bb9e2ee86

* Fixed partition test

Change-Id: I46f9e349b1a813a9140f7e4f8a2241687e2df73b

* Removed comments

Change-Id: I309afdd1951d7e796e41d13788aa487707e0ac4c
parent 2f41a396
@@ -70,12 +70,12 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
   regions_.erase(src);
 }
 
-void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) {
-  auto region2 = GetRegion(expr);
-  if (region2.defined()) {
-    MergeRegions(region, region2);
+void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) {
+  auto src = GetRegion(expr);
+  if (src.defined()) {
+    MergeRegions(src, dest);
   } else {
-    region->nodes.insert(expr);
+    dest->nodes.insert(expr);
   }
 }
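The direction of the merge matters here: the region that already contains `expr` (src) is folded into the region being grown (dest), and src is then erased. A plain-Python analogue (illustrative only, not the TVM C++ API), with regions as sets of nodes:

    def add_to_region(regions, dest, expr):
        # Find a pre-existing region that already contains the expression.
        src = next((r for r in regions if expr in r), None)
        if src is not None and src is not dest:
            dest |= src          # fold src's nodes into dest ...
            regions.remove(src)  # ... then drop src, mirroring regions_.erase(src)
        else:
            dest.add(expr)

    regions = [{"a", "b"}, {"c"}]
    add_to_region(regions, regions[1], "a")
    print(regions)  # one merged region containing 'a', 'b' and 'c'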
@@ -178,10 +178,10 @@ class AnnotatedRegionSetNode : public Object {
   /*!
    * \brief Add an expression to a region.
    *
-   * \param region The region to add the expression to.
+   * \param dest The region to add the expression to.
    * \param expr The expression.
    */
-  void AddToRegion(AnnotatedRegion region, const Expr& expr);
+  void AddToRegion(AnnotatedRegion dest, const Expr& expr);
 
   /*!
    * \brief Make a new region.
@@ -32,6 +32,9 @@ namespace tvm {
 namespace relay {
 namespace annotate_target {
 
+// Cache compiler_begin op for equivalence check.
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+
 // A helper class to insert annotation boundaries for a program region that will
 // be handled by a specific compiler.
 class AnnotateTargetWrapper : public ExprMutator {
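Caching the Op handle works because registered ops are interned: looking the same name up twice yields the same object, so an identity comparison suffices. The same behaviour is visible from the Python API (a sketch, assuming the usual relay.op.get helper):

    import tvm
    from tvm import relay

    op_a = relay.op.get("annotation.compiler_begin")
    op_b = relay.op.get("annotation.compiler_begin")
    assert op_a.same_as(op_b)  # same interned Op object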
@@ -52,6 +55,13 @@ class AnnotateTargetWrapper : public ExprMutator {
         return fannotate[op](call->attrs, call->args);
       }
     }
+    if (expr->IsInstance<TupleGetItemNode>()) {
+      TupleGetItem get = Downcast<TupleGetItem>(expr);
+      if (get->tuple->IsInstance<CallNode>() &&
+          get->tuple.as<CallNode>()->op == compiler_begin_op) {
+        return true;
+      }
+    }
     return false;
   }
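The new branch treats a TupleGetItem as supported when its tuple is itself a compiler_begin call. A Python sketch of an expression matching that pattern (tuple_value is the Python view of the C++ tuple field):

    import tvm
    from tvm import relay

    a, b = relay.var("a"), relay.var("b")
    begin = relay.annotation.compiler_begin(relay.Tuple([a, b]), "test")
    get = relay.TupleGetItem(begin, 0)

    matches = (isinstance(get.tuple_value, relay.Call) and
               get.tuple_value.op.same_as(relay.op.get("annotation.compiler_begin")))
    print(matches)  # True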
@@ -110,9 +120,14 @@ class AnnotateTargetWrapper : public ExprMutator {
     auto new_e = ExprMutator::VisitExpr_(op);
 
     auto get = Downcast<TupleGetItem>(new_e);
-    return TupleGetItem(
-      InsertEnd(get->tuple),
-      get->index);
+    if (IsSupported(get->tuple)) {
+      const auto* begin_op =
+        runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+      CHECK(begin_op);
+      return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index);
+    } else {
+      return TupleGetItem(InsertEnd(get->tuple), get->index);
+    }
   }
 
   Expr VisitExpr_(const FunctionNode* op) {
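The mutator fetches the packed function that constructs compiler_begin from the global registry; the CHECK guards against a missing registration. The same function is reachable from Python (a sketch; the registry name is the one used in the patch):

    import tvm
    from tvm import relay

    begin_fn = tvm.get_global_func("relay.op.annotation._make.compiler_begin")
    x = relay.var("x")
    print(begin_fn(x, "test"))  # the same compiler_begin(expr, target) call the pass builds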
@@ -113,7 +113,8 @@ def test_extern_dnnl():
                                              padding=(1, 1),
                                              groups=32)
     end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
-    begin2 = relay.annotation.compiler_begin(end0, "dnnl")
+    end1 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
+    begin2 = relay.annotation.compiler_begin(end1, "dnnl")
     begin3 = relay.annotation.compiler_begin(end0, "dnnl")
     begin4 = relay.annotation.compiler_begin(weight1, "dnnl")
     depthwise_conv2d_2 = relay.nn.conv2d(begin3,
@@ -121,11 +122,11 @@ def test_extern_dnnl():
                                          kernel_size=(3, 3),
                                          padding=(1, 1),
                                          groups=32)
-    end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
-    begin5 = relay.annotation.compiler_begin(end1, "dnnl")
+    end2 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
+    begin5 = relay.annotation.compiler_begin(end2, "dnnl")
     out = relay.add(begin2, begin5)
-    end2 = relay.annotation.compiler_end(out, "dnnl")
-    f = relay.Function([data, weight1], end2)
+    end3 = relay.annotation.compiler_end(out, "dnnl")
+    f = relay.Function([data, weight1], end3)
     mod = tvm.IRModule.from_expr(f)
     return mod
@@ -137,7 +138,7 @@ def test_extern_dnnl():
     mod = annotated(dtype, ishape, w1shape)
     mod = transform.AnnotateTarget("dnnl")(mod)
     ref_mod = expected(dtype, ishape, w1shape)
-    # tvm.ir.assert_structural_equal(mod, ref_mod)
+    tvm.ir.assert_structural_equal(mod, ref_mod)
 
 def test_run():
     if not tvm.get_global_func("relay.ext.dnnl", True):
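Uncommenting the structural check makes the test fail loudly again on any annotation mismatch: assert_structural_equal raises when the two IR objects differ. A minimal sketch of the behaviour:

    import tvm
    from tvm import relay

    f1 = relay.Function([], relay.const(1))
    f2 = relay.Function([], relay.const(1))
    tvm.ir.assert_structural_equal(f1, f2)  # passes; raises on any structural difference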
@@ -66,13 +66,10 @@ def test_diamond_graph_fanouts():
         O_2 = relay.nn.relu(O_1)
         ce_3 = compiler_end(O_2, "test")
 
-        cb_x = compiler_begin(ce_2, "default")
-        X = relay.tanh(cb_x)
-        ce_x1 = compiler_end(X, "default")
-        ce_x2 = compiler_end(X, "default")
+        X = relay.tanh(ce_2)
 
         cb_3 = compiler_begin(ce_3, "test")
-        cb_4 = compiler_begin(ce_x1, "test")
+        cb_4 = compiler_begin(X, "test")
         O_3 = relay.add(cb_3, cb_4)
         ce_4 = compiler_end(O_3, "test")
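The expected graph now leaves the 'default' region bare: only the 'test' regions keep their begin/end pairs, matching the "Remove default annotations" commit. A sketch of that shape (illustrative ops, same import style as the test):

    import tvm
    from tvm import relay
    from tvm.relay.op.annotation import compiler_begin, compiler_end

    d = relay.var("d")
    x = relay.tanh(d)                          # default region: no annotations
    cb = compiler_begin(x, "test")             # only the "test" region is delimited
    out = compiler_end(relay.abs(cb), "test")
    print(relay.Function([d], out))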
@@ -162,36 +159,28 @@ def test_example_graph():
         node1 = relay.add(begin2, begin3)
         node2 = relay.add(node0, node1)
 
-        begin4 = compiler_begin(in_5, "default")
-        begin5 = compiler_begin(in_6, "default")
-        begin6 = compiler_begin(in_7, "default")
-        node3 = relay.subtract(begin4, begin5)
-        node4 = relay.subtract(begin6, node3)
-        end0 = compiler_end(node4, "default")
-        begin7 = compiler_begin(end0, "test")
-        begin8 = compiler_begin(in_9, "test")
-        node5 = relay.add(node2, begin7)
+        node3 = relay.subtract(in_5, in_6)
+        node4 = relay.subtract(in_7, node3)
+        begin4 = compiler_begin(node4, "test")
+        begin5 = compiler_begin(in_9, "test")
+        node5 = relay.add(node2, begin4)
         end1 = compiler_end(node5, "test")
-        begin9 = compiler_begin(end1, "default")
-        begin10 = compiler_begin(in_8, "default")
-        node6 = relay.subtract(begin10, begin9)
-        end2 = compiler_end(node6, "default")
-        node7 = relay.add(begin8, node5)
-        end3 = compiler_end(node7, "test")
-        begin11 = compiler_begin(end3, "test")
-        begin12 = compiler_begin(end2, "test")
-        node8 = relay.add(begin12, begin11)
-        begin13 = compiler_begin(in_10, "test")
-        node9 = relay.add(begin13, node8)
-        end4 = compiler_end(node9, "test")
+        node6 = relay.subtract(in_8, end1)
+        node7 = relay.add(begin5, node5)
+        end2 = compiler_end(node7, "test")
+        begin6 = compiler_begin(end2, "test")
+        begin7 = compiler_begin(node6, "test")
+        node8 = relay.add(begin7, begin6)
+        begin8 = compiler_begin(in_10, "test")
+        node9 = relay.add(begin8, node8)
+        end3 = compiler_end(node9, "test")
 
-        f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4)
+        f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end3)
         mod = tvm.IRModule.from_expr(f)
         return mod
@@ -725,12 +725,12 @@ def test_multiple_outputs():
         mod = tvm.IRModule()
 
         # function 0
-        data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
-        weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
-        bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
-        bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
-        bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
+        data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32"))
+        bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32"))
+        bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32"))
         conv_o = relay.nn.conv2d(
             data=data,
@@ -743,7 +743,7 @@ def test_multiple_outputs():
                                 bn_var)
         relu_o = relay.nn.relu(bn_o[0])
-        tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
+        tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o))
         func0 = relay.Function([data, weight, bn_gamma, bn_beta,
                                 bn_mean, bn_var], tuple_o)
@@ -752,8 +752,8 @@ def test_multiple_outputs():
         func0 = func0.with_attr("Compiler",
                                 tvm.tir.StringImm("test_target"))
         func0 = func0.with_attr("ExternalSymbol",
-                                tvm.tir.StringImm("test_target_0"))
-        gv0 = relay.GlobalVar("test_target_0")
+                                tvm.tir.StringImm("test_target_2"))
+        gv0 = relay.GlobalVar("test_target_2")
         mod[gv0] = func0
 
         # body
@@ -765,9 +765,9 @@ def test_multiple_outputs():
         bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
         f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
-        f0_relu_o = relay.TupleGetItem(f0_o, 0)
+        f0_relu_o = relay.TupleGetItem(f0_o, 2)
         f0_mean_o = relay.TupleGetItem(f0_o, 1)
-        f0_var_o = relay.TupleGetItem(f0_o, 2)
+        f0_var_o = relay.TupleGetItem(f0_o, 0)
         f0_mean_abs = relay.abs(f0_mean_o)
         f0_var_abs = relay.abs(f0_var_o)
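Because the expected partitioned function now returns its tuple fields in a different order, the TupleGetItem indices at the call site are remapped so each consumer still reads the same tensor. A sketch of the index remap in isolation:

    import tvm
    from tvm import relay

    a, b, c = relay.var("a"), relay.var("b"), relay.var("c")
    old = relay.Tuple([a, b, c])           # (relu, mean, var) ordering
    new = relay.Tuple([c, b, a])           # (var, mean, relu) ordering
    relu_old = relay.TupleGetItem(old, 0)  # relu at index 0 before ...
    relu_new = relay.TupleGetItem(new, 2)  # ... and at index 2 after the reorder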