Commit 4d8ecb37 by Wuwei Lin Committed by Tianqi Chen

[RELAY] Copy subfunction arguments to output tuple field (#2537)

parent b63182ea
......@@ -10,6 +10,7 @@
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include "./pattern_util.h"
#include "../../common/arena.h"
......@@ -738,7 +739,6 @@ class FuseMutator : private ExprMutator {
Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
Tuple new_tuple = TupleNode::make(new_fields);
if (ret_group == gmap_.at(tuple)) {
// This tuple is the root of its group. Check if all fields come from other groups.
bool isolated = new_fields.size() == ginfo_[ret_group].params.size();
......@@ -750,10 +750,18 @@ class FuseMutator : private ExprMutator {
return ExprMutator::VisitExpr_(tuple);
}
// This tuple has been fused with other ops before it
return MakeNewFunction(ret_group, tuple->checked_type(), new_tuple);
for (size_t i = 0; i < new_fields.size(); i++) {
// Copy function arguments to tuple field of the output because currently graph memory
// planner doesn't support inplace operations
if (new_fields[i].as<VarNode>()) {
auto copy = Copy(new_fields[i]);
new_fields.Set(i, copy);
}
}
return MakeNewFunction(ret_group, tuple->checked_type(), TupleNode::make(new_fields));
}
// This tuple is an intermediate node in the group
return new_tuple;
return TupleNode::make(new_fields);
}
Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
......
......@@ -318,6 +318,13 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
// Wrap `data` in a call to the identity "copy" operator, producing a new
// Expr whose value equals `data` but is materialized as a separate tensor.
inline Expr Copy(Expr data) {
  // Look up the operator once and cache it for all subsequent calls.
  static const Op& copy_op = Op::Get("copy");
  return CallNode::make(copy_op, {data}, Attrs(), {});
}
Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
......
......@@ -43,6 +43,18 @@ def test_compile_placeholder_bypass():
with relay.build_config(opt_level=0):
graph, lib, params = relay.build(func, 'llvm')
def test_compile_injective_with_tuple():
    """Building a function whose output tuple directly contains a function
    argument must succeed (the argument gets copied into the tuple field)."""
    lhs = relay.var("x", shape=(2, 3))
    rhs = relay.var("y", shape=(2, 3))
    transposed = relay.transpose(lhs)
    result = relay.Tuple([transposed, rhs])
    fn = relay.Function([lhs, rhs], result)
    relay.build(fn, 'llvm')
if __name__ == "__main__":
    # Run every compile-engine test in this file when executed directly.
    test_compile_engine()
    test_compile_placeholder_bypass()
    test_compile_injective_with_tuple()
......@@ -161,8 +161,9 @@ def test_tuple_root():
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=(dshape[0], dshape[1], dshape[2], dshape[3]))
p1_copy = relay.copy(p1)
upsampled = relay.nn.upsampling(p0, scale=2, layout="NCHW")
out = relay.Tuple((upsampled, p1))
out = relay.Tuple((upsampled, p1_copy))
f1 = relay.Function([p0, p1], out)
x = relay.var("x", shape=dshape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment