Commit 79abd2c3 by Jared Roesch, committed by Tianqi Chen

Fix fusion bug when calling a symbol that is not an operator. (#2630)

parent aaad5f98
@@ -208,11 +208,22 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
    Node* node = graph_.node_map.at(call);
    static auto fpattern =
        Op::GetAttr<TOpPattern>("TOpPattern");
    // Now we set the pattern of this call.
    //
    // If we see a call mentioning an operator, we should mark it with its
    // annotated pattern.
    //
    // If the pattern is not annotated, we default to opaque.
    //
    // Finally, if the operator position does not hold an operator node, we
    // need to call Update, as it may be an arbitrary expression.
    OpPatternKind op_pattern = kOpaque;
    if (const OpNode* opnode = call->op.as<OpNode>()) {
      op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
    } else {
      this->Update(call->op, node, kOpaque);
    }
    node->pattern = op_pattern;
    const auto* rtype = call->checked_type().as<TensorTypeNode>();
    // pass the message back to all the children it references.
...
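For context: the new else-branch covers calls whose callee is an arbitrary expression rather than an Op, so it carries no "TOpPattern" attribute; before this change the graph creator never recorded a dataflow edge for such a callee. Below is a minimal sketch of that situation (illustrative only, not part of this commit), using the same relay.ir_pass-era API as the test further down:

from tvm import relay

# The callee here is a plain Relay Function, so inside the fusion
# pass call->op is a FunctionNode rather than an OpNode.
x = relay.var('x', shape=(), dtype='int64')
f = relay.Function([x], x)
prog = relay.Function([x], relay.Call(f, [x]))
prog = relay.ir_pass.infer_type(prog)
# With the fix, fuse_ops registers the non-operator callee as an
# opaque edge instead of ignoring it.
prog = relay.ir_pass.fuse_ops(prog)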
@@ -251,6 +251,42 @@ def test_stop_fusion():
    assert relay.ir_pass.alpha_equal(z, after)
def test_fuse_myia_regression():
    # Regression test: fuse a program whose call target is a function
    # produced by an if-expression, i.e. a non-operator callee.
    def before(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.op.greater(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y],
                              relay.Call(sb.get(), []))

    def expected(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        p1 = relay.var('p1', shape=dshape, dtype=dtype)
        p2 = relay.var('p2', shape=dshape, dtype=dtype)
        fused_gt = relay.Function([p1, p2],
                                  relay.op.greater(p1, p2))
        with sb.if_scope(fused_gt(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y],
                              relay.Call(sb.get(), []))

    dshape = ()
    dtype = 'int64'
    f = before(dshape, dtype)
    f = relay.ir_pass.infer_type(f)
    f = relay.ir_pass.fuse_ops(f)
    after = relay.ir_pass.infer_type(expected(dshape, dtype))
    assert relay.ir_pass.alpha_equal(f, after)
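Note what the expected output asserts: after fusion, only the `greater` comparison is outlined into a separate fused function (`fused_gt`), while the call to the function returned by the if-expression remains a direct call. Since a non-operator callee is marked opaque, fusion stops at that boundary rather than crossing it.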
if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()
@@ -258,3 +294,4 @@ if __name__ == "__main__":
    test_tuple_root()
    test_tuple_strided_slice()
    test_stop_fusion()
    test_fuse_myia_regression()