Commit 82e868a4 by masahi, committed by Tianqi Chen

[Relay] Add support for TupleGetItem in op fusion (#2914)

parent a0537ecb
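For context, the fusion pass previously treated TupleGetItem as opaque, cutting fusion groups at every tuple access. With this change, TupleGetItem is marked injective, so patterns such as a split followed by elementwise ops can be fused. A minimal sketch against the relay.ir_pass API of this era (names are illustrative), mirroring the test_fuse_tuple_get_elemwise test added below:

from tvm import relay

# dense -> split -> elementwise, the pattern this commit lets the pass fuse
dim = 10
X = relay.var("X", shape=(1, dim))
W = relay.var("W", shape=(3 * dim, dim))
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
func = relay.Function([X, W], out)

func = relay.ir_pass.infer_type(func)
fused = relay.ir_pass.infer_type(relay.ir_pass.fuse_ops(func, opt_level=2))
# With this change, split, the TupleGetItem accesses, and the following
# elementwise ops end up in one fused function (dense is fused separately);
# previously the kOpaque handling of TupleGetItem broke the group at every
# tuple access.
print(fused.astext())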
...@@ -261,9 +261,30 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
    CHECK(tuple_type);
    // If this tuple contains a reference type and we fuse TupleGetItem with
    // the reference, the fused function will have a tuple containing a reference
    // in its parameters. But when TVM lowers a fused function, it expects all
    // arguments to be a Tensor or a tuple containing only Tensors.
    // To avoid modifying the codegen logic, we do not allow fusing through a reference.
    // The reference itself will be recursively visited via the call to
    // ExprVisitor::VisitExpr_(op) below and the corresponding visitor methods.
    bool has_reference = false;
    for (auto ty : tuple_type->fields) {
      if (ty.as<RefTypeNode>()) {
        has_reference = true;
        break;
      }
    }
    if (has_reference) {
      this->Update(op->tuple, nullptr, kOpaque);
    } else {
      CHECK(graph_.node_map.count(op));
      Node* node = graph_.node_map.at(op);
      node->pattern = kInjective;
      this->Update(op->tuple, node, kInjective);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
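To illustrate the reference restriction described in the comment above, here is a hedged sketch, assuming relay.RefCreate is exposed in this version's Python API; the program is illustrative and not taken from the commit. A tuple with a reference field is marked opaque, so the TupleGetItem is not fused with it, and lowered functions keep receiving only tensors (or tuples of tensors):

from tvm import relay

# Hypothetical example; assumes relay.RefCreate is available here.
shape = (1, 16)
x = relay.var("x", shape=shape)
y = relay.var("y", shape=shape)
t = relay.Tuple([relay.RefCreate(x), y])   # tuple with a reference field
out = relay.exp(relay.TupleGetItem(t, 1))
func = relay.ir_pass.infer_type(relay.Function([x, y], out))

fused = relay.ir_pass.fuse_ops(func, opt_level=2)
# The ref-containing tuple is treated as opaque, so the TupleGetItem is not
# pulled into a fused function together with it.
print(relay.ir_pass.infer_type(fused).astext())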
...@@ -809,6 +830,23 @@ class FuseMutator : private ExprMutator {
    return TupleNode::make(new_fields);
  }

  Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
    auto* ret_group = gmap_.at(tuple_get)->FindRoot();
    auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
    auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
    if (ret_group == gmap_.at(tuple_get)) {
      if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
        // Isolated. This case occurs when the tuple is created by an Opaque op,
        // e.g. multibox_transform_loc
        return ExprMutator::VisitExpr_(tuple_get);
      }
      // A new function whose output is a tuple field access
      return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
    }
    // This is an intermediate node in the group
    return new_node;
  }

  Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
    const GroupInfo& ginfo = ginfo_[group];
    auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
...
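The mutator above distinguishes three outcomes for a TupleGetItem: it is left un-fused when its tuple comes from an opaque group, it becomes the body of a new fused function when it is the root of its group, and it is simply rewritten when it is an interior node of a group. A hedged sketch of the root-of-group case (illustrative names), mirroring the test_tuple_get_root test added below:

from tvm import relay

# split whose first field feeds dense; the tuple access ends a fusion group
dim = 10
X = relay.var("X", shape=(1, 3 * dim))
W = relay.var("W", shape=(dim, dim))
splitted = relay.split(X, indices_or_sections=3, axis=1)
out = relay.nn.dense(splitted[0], W)
func = relay.ir_pass.infer_type(relay.Function([X, W], out))

fused = relay.ir_pass.infer_type(relay.ir_pass.fuse_ops(func, opt_level=2))
# split and its TupleGetItem form one group whose root is the tuple access, so
# the first fused function's body is a tuple field access; dense ends up in a
# second fused function, matching the expected() graph in the test below.
print(fused.astext())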
...@@ -7,6 +7,7 @@ from tvm.relay.ir_pass import infer_type
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.module import Module
from tvm.relay.testing.config import ctx_list

# @tq, @jr should we put this in testing ns?
def check_rts(expr, args, expected_result, mod=None):
...@@ -127,9 +128,47 @@ def test_plan_memory():
    assert len(device_types) == 1
def test_gru_like():
    def unit(rnn_dim):
        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def unit_numpy(X, W):
        prod = np.dot(X, W.transpose())
        splits = np.split(prod, indices_or_sections=3, axis=1)
        return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])

    dtype = "float32"
    rnn_dim = 1000
    x = np.random.rand(1, rnn_dim).astype(dtype)
    y = np.random.rand(3 * rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
    out_shape = (1, rnn_dim)
    z = unit(rnn_dim)

    for target, ctx in ctx_list():
        with relay.build_config(opt_level=2):
            graph, lib, params = relay.build(z, target)
            m = graph_runtime.create(graph, lib, ctx)
            m.set_input("X", tvm.nd.array(x.astype(dtype)))
            m.set_input("y", tvm.nd.array(y.astype(dtype)))
            m.set_input(**params)
            m.run()
            out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
            ref = unit_numpy(x, y)
            tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
if __name__ == "__main__":
    test_plan_memory()
    test_with_params()
    test_add_op_scalar()
    test_add_op_tensor()
    test_add_op_broadcast()
    test_gru_like()
...@@ -217,7 +217,6 @@ def test_tuple_strided_slice():
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected(dshape))
    assert relay.ir_pass.alpha_equal(zz, after)
    print(zz.astext())

def test_stop_fusion():
...@@ -287,6 +286,81 @@ def test_fuse_myia_regression():
    assert relay.ir_pass.alpha_equal(f, after)
def test_fuse_tuple_get_elemwise():
    def before(dim):
        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, dim))
        p1 = relay.var("p1", shape=(3 * dim, dim))
        matmul = relay.nn.dense(p0, p1)
        f0 = relay.Function([p0, p1], matmul)

        p01 = relay.var("p01", shape=(1, 3 * dim))
        splitted = relay.split(p01, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        f1 = relay.Function([p01], out)

        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        y = relay.Call(f0, [X, W])
        z = relay.Call(f1, [y])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
    assert not relay.ir_pass.free_vars(zz)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    zz = relay.ir_pass.infer_type(zz)
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected(dim))
    assert relay.ir_pass.alpha_equal(zz, after)


def test_tuple_get_root():
    def before(dim):
        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        splitted = relay.split(X, indices_or_sections=3, axis=1)
        out = relay.nn.dense(splitted[0], W)
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, 3 * dim))
        splitted = relay.split(p0, indices_or_sections=3, axis=1)
        out = splitted[0]
        f0 = relay.Function([p0], out)

        p01 = relay.var("p01", shape=(1, dim))
        p1 = relay.var("p1", shape=(dim, dim))
        out = relay.nn.dense(p01, p1)
        f1 = relay.Function([p01, p1], out)

        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        y = relay.Call(f0, [X])
        z = relay.Call(f1, [y, W])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
    assert not relay.ir_pass.free_vars(zz)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    zz = relay.ir_pass.infer_type(zz)
    assert not relay.ir_pass.free_vars(zz)
    after = relay.ir_pass.infer_type(expected(dim))
    assert relay.ir_pass.alpha_equal(zz, after)
if __name__ == "__main__":
    test_fuse_simple()
    test_conv2d_fuse()
...@@ -295,3 +369,5 @@ if __name__ == "__main__":
    test_tuple_strided_slice()
    test_stop_fusion()
    test_fuse_myia_regression()
    test_fuse_tuple_get_elemwise()
    test_tuple_get_root()