Commit 04f7a183 by masahi Committed by Tianqi Chen

[Relay] Add support for tuple node in operator fusion (#2187)

parent 2f1d709f
......@@ -232,8 +232,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}
void VisitExpr_(const TupleNode* op) {
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
for (const Expr& field : op->fields) {
this->Update(field, nullptr, kOpaque);
this->Update(field, tuple_node, kInjective);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
......@@ -712,32 +715,15 @@ class FuseMutator : private ExprMutator {
// then we must have a group assignment for it already.
CHECK(gmap_.count(call));
auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> new_args;
for (auto arg : call->args) {
auto type = arg->checked_type();
CHECK(gmap_.count(arg.get()))
<< "cannot find group of " << arg;
auto* arg_group = gmap_.at(arg.get())->FindRoot();
Expr new_arg = this->Mutate(arg);
Array<Expr> new_args = GetNewArguments(call->args, ret_group);
if (ret_group != arg_group) {
Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
}
auto new_call = CallNode::make(
call->op, new_args, call->attrs, call->type_args);
if (ret_group->root_ref == call) {
// This is the root of the group
// create the new call node.
const GroupInfo& ginfo = ginfo_[ret_group];
auto func = FunctionNode::make(
ginfo.params, new_call, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, ginfo.arguments, Attrs());
return MakeNewFunction(ret_group, call->checked_type(), new_call);
} else {
// This is an intermediate node of a fused function
// simply return the new call.
......@@ -747,6 +733,51 @@ class FuseMutator : private ExprMutator {
return ExprMutator::VisitExpr_(call);
}
}
Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
Tuple new_tuple = TupleNode::make(new_fields);
if (ret_group == gmap_.at(tuple)) {
bool isolated = true;
for (size_t i = 0; i < new_fields.size(); ++i) {
isolated &= (new_fields[i].same_as(ginfo_[ret_group].params[i]));
}
if (isolated) {
// Do not put a isolated tuple into a function
return ExprMutator::VisitExpr_(tuple);
}
// This tuple has been fused with other ops before it
return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple);
}
// This tuple is an intermediate node in the group
return new_tuple;
}
Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
return CallNode::make(func, ginfo.arguments, Attrs());
}
Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
GraphPartitioner::Group* current_group) {
Array<Expr> new_args;
for (auto arg : args) {
auto* arg_group = gmap_.at(arg.get())->FindRoot();
auto type = arg->checked_type();
Expr new_arg = this->Mutate(arg);
if (current_group != arg_group) {
Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
}
return new_args;
}
// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
......
......@@ -28,8 +28,6 @@ def test_fuse_simple():
assert relay.ir_pass.alpha_equal(zz, after)
def test_conv2d_fuse():
"""Test fusion case of conv2d"""
def before(dshape):
......@@ -106,7 +104,86 @@ def test_conv2d_fuse():
assert relay.ir_pass.alpha_equal(zz, after)
def test_concatenate():
"""Test fusion case involving concat op and Tuple node"""
def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
concat = relay.concatenate((upsampled, x), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
return relay.Function(relay.ir_pass.free_vars(out), out)
def expected(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled)
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=dshape)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
concat = relay.concatenate((upsampled, p1), axis=1)
out = relay.add(concat, relay.const(1, "float32"))
f1 = relay.Function([p0, p1], out)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)
dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
def test_tuple_root():
"""Test fusion case where Tuple node is the root in its group"""
def before(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
upsampled = relay.nn.upsampling(pooled, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, x))
return relay.Function(relay.ir_pass.free_vars(out), out)
def expected(dshape):
x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled)
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, p1))
f1 = relay.Function([p0, p1], out)
x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x])
z = relay.Call(f1, [y, x])
return relay.Function([x], z)
dshape = (1, 16, 64, 64)
z = before(dshape)
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected(dshape))
assert relay.ir_pass.alpha_equal(zz, after)
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_concatenate()
test_tuple_root()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment