Commit 4d8ecb37 by Wuwei Lin Committed by Tianqi Chen

[RELAY] Copy subfunction arguments to output tuple field (#2537)

parent b63182ea
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include "./pattern_util.h"
#include "../../common/arena.h" #include "../../common/arena.h"
...@@ -738,7 +739,6 @@ class FuseMutator : private ExprMutator { ...@@ -738,7 +739,6 @@ class FuseMutator : private ExprMutator {
Expr VisitExpr_(const TupleNode* tuple) { Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot(); auto* ret_group = gmap_.at(tuple)->FindRoot();
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group); Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
Tuple new_tuple = TupleNode::make(new_fields);
if (ret_group == gmap_.at(tuple)) { if (ret_group == gmap_.at(tuple)) {
// This tuple is the root of its group. Check if all fields come from other groups. // This tuple is the root of its group. Check if all fields come from other groups.
bool isolated = new_fields.size() == ginfo_[ret_group].params.size(); bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
...@@ -750,10 +750,18 @@ class FuseMutator : private ExprMutator { ...@@ -750,10 +750,18 @@ class FuseMutator : private ExprMutator {
return ExprMutator::VisitExpr_(tuple); return ExprMutator::VisitExpr_(tuple);
} }
// This tuple has been fused with other ops before it // This tuple has been fused with other ops before it
return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple); for (size_t i = 0; i < new_fields.size(); i++) {
// Copy function arguments to tuple field of the output because currently graph memory
// planer doesn't support inplace operations
if (new_fields[i].as<VarNode>()) {
auto copy = Copy(new_fields[i]);
new_fields.Set(i, copy);
}
}
return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
} }
// This tuple is an intermediate node in the group // This tuple is an intermediate node in the group
return new_tuple; return TupleNode::make(new_fields);
} }
Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
......
...@@ -318,6 +318,13 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { ...@@ -318,6 +318,13 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
} }
inline Expr Copy(Expr data) {
static const Op& op = Op::Get("copy");
return CallNode::make(op, {data}, Attrs(), {});
}
Expr MakeConcatenate(Expr data, int axis); Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides); Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
......
...@@ -43,6 +43,18 @@ def test_compile_placeholder_bypass(): ...@@ -43,6 +43,18 @@ def test_compile_placeholder_bypass():
with relay.build_config(opt_level=0): with relay.build_config(opt_level=0):
graph, lib, params = relay.build(func, 'llvm') graph, lib, params = relay.build(func, 'llvm')
def test_compile_injective_with_tuple():
x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3))
x_transpose = relay.transpose(x)
output = relay.Tuple([x_transpose, y])
func = relay.Function([x, y], output)
relay.build(func, 'llvm')
if __name__ == "__main__": if __name__ == "__main__":
test_compile_engine() test_compile_engine()
test_compile_placeholder_bypass() test_compile_placeholder_bypass()
test_compile_injective_with_tuple()
...@@ -161,8 +161,9 @@ def test_tuple_root(): ...@@ -161,8 +161,9 @@ def test_tuple_root():
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3])) p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
p1_copy = relay.copy(p1)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW") upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, p1)) out = relay.Tuple((upsampled, p1_copy))
f1 = relay.Function([p0, p1], out) f1 = relay.Function([p0, p1], out)
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment