Commit c162e7d6 by masahi Committed by Tianqi Chen

[Relay, OpFusion] Fix handling TupleGetItem for nested tuples (#2929)

parent 39c116f0
......@@ -3,7 +3,7 @@
from __future__ import absolute_import
from ..expr import const
from .op import register_gradient
from .transform import collapse_sum_like, where
from .transform import collapse_sum_like, broadcast_to_like, where
from .tensor import exp, negative, power, less
from .tensor import zeros_like, ones_like
......@@ -77,3 +77,20 @@ def divide_grad(orig, grad):
x, y = orig.args
return [collapse_sum_like(grad / y, x),
collapse_sum_like(- (grad * orig / y), y)]
@register_gradient("zeros_like")
def zeros_like_grad(orig, grad):
"""Returns [0]"""
return [orig]
@register_gradient("ones_like")
def ones_like_grad(orig, grad):
"""Returns [0]"""
return [zeros_like(orig.args[0])]
@register_gradient("collapse_sum_like")
def collapse_sum_like_grad(orig, grad):
"""Returns [broadcast_to_like(grad, x), 0]"""
x, y = orig.args
return [broadcast_to_like(grad, x), zeros_like(y)]
......@@ -263,21 +263,19 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
void VisitExpr_(const TupleGetItemNode* op) final {
auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
CHECK(tuple_type);
// If this tuple contain a reference type, and we fuse TupleGetItem and
// the reference, a fused function will have a tuple containing a reference
// in its parameters. But when TVM lowers a fused function, it expects all
// arguments to be a Tensor or a tuple containing only Tensors.
// To avoid modifying codegen logic, we do not allow fusing through a reference.
// The reference itself will be recursively visited via call to ExprVisitor::VisitExpr_(op)
// below and corresponding visitor methods
bool has_reference = false;
// when TVM lowers a fused function, it expects all arguments to be a Tensor or
// a tuple containing only Tensors. But this tuple may contain a reference or
// another tuple. To avoid modifying codegen logic, we do not allow fusing through this node
// if the tuple contains such non Tensor fields. However, all fields will be recursively
// visited via call to ExprVisitor::VisitExpr_(op) below and corresponding visitor methods.
bool has_non_tensor = false;
for (auto ty : tuple_type->fields) {
if (ty.as<RefTypeNode>()) {
has_reference = true;
if (!ty.as<TensorTypeNode>()) {
has_non_tensor = true;
break;
}
}
if (has_reference) {
if (has_non_tensor) {
this->Update(op->tuple, nullptr, kOpaque);
} else {
CHECK(graph_.node_map.count(op));
......
......@@ -20,8 +20,8 @@ def test_id():
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
np.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
def test_add():
......@@ -35,8 +35,8 @@ def test_add():
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), 2 * x.asnumpy())
np.testing.assert_allclose(grad.asnumpy(), 2 * np.ones_like(x.asnumpy()))
tvm.testing.assert_allclose(forward.asnumpy(), 2 * x.asnumpy())
tvm.testing.assert_allclose(grad.asnumpy(), 2 * np.ones_like(x.asnumpy()))
def test_temp_add():
......@@ -51,8 +51,8 @@ def test_temp_add():
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
np.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
tvm.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
def test_sub():
......@@ -66,8 +66,8 @@ def test_sub():
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
np.testing.assert_allclose(forward.asnumpy(), np.zeros_like(x.asnumpy()))
np.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy()))
tvm.testing.assert_allclose(forward.asnumpy(), np.zeros_like(x.asnumpy()))
tvm.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy()))
def test_broadcast_add():
......@@ -90,11 +90,11 @@ def test_broadcast_add():
relay.TupleType([t1, t2])]))
ex = create_executor()
forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
np.testing.assert_allclose(forward.asnumpy(), expected_forward)
np.testing.assert_allclose(grad_x.asnumpy(),
np.ones_like(expected_forward).sum(axis=2, keepdims=True))
np.testing.assert_allclose(grad_y.asnumpy(),
np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
tvm.testing.assert_allclose(grad_x.asnumpy(),
np.ones_like(expected_forward).sum(axis=2, keepdims=True))
tvm.testing.assert_allclose(grad_y.asnumpy(),
np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
def test_broadcast_subtract():
......@@ -117,11 +117,11 @@ def test_broadcast_subtract():
relay.TupleType([t1, t2])]))
ex = create_executor()
forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
np.testing.assert_allclose(forward.asnumpy(), expected_forward)
np.testing.assert_allclose(grad_x.asnumpy(),
np.ones_like(expected_forward).sum(axis=2, keepdims=True))
np.testing.assert_allclose(grad_y.asnumpy(),
-np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
tvm.testing.assert_allclose(grad_x.asnumpy(),
np.ones_like(expected_forward).sum(axis=2, keepdims=True))
tvm.testing.assert_allclose(grad_y.asnumpy(),
-np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
def test_tuple():
......@@ -147,10 +147,10 @@ def test_tuple():
expected_forward = x_np + y_np - z_np
ex = create_executor()
forward, (grad_x, grad_y, grad_z) = ex.evaluate(back_func)(x_nd, y_nd, z_nd)
np.testing.assert_allclose(forward.asnumpy(), expected_forward)
np.testing.assert_allclose(grad_x.asnumpy(), np.ones_like(grad_x.asnumpy()))
np.testing.assert_allclose(grad_y.asnumpy(), np.ones_like(grad_y.asnumpy()))
np.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
tvm.testing.assert_allclose(grad_x.asnumpy(), np.ones_like(grad_x.asnumpy()))
tvm.testing.assert_allclose(grad_y.asnumpy(), np.ones_like(grad_y.asnumpy()))
tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
def test_pow():
......@@ -168,8 +168,9 @@ def test_pow():
i_nd = rand(dtype, *shape)
ex = create_executor(mod=mod)
forward, (grad_i,) = ex.evaluate(back_func)(i_nd)
np.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
np.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
tvm.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
def test_ref():
shape = (10, 10)
......@@ -187,8 +188,28 @@ def test_ref():
x_nd = rand(dtype, *shape)
ex = create_executor()
forward, (grad_x,) = ex.evaluate(back_func)(x_nd)
np.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
np.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
tvm.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
def test_square_second_order():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x * x)
back_func = relay.ir_pass.infer_type(gradient(func))
y = relay.var("y", t)
back_func_adjusted = relay.Function([y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0))
back_func_adjusted = relay.ir_pass.infer_type(back_func_adjusted)
back_back_func = relay.ir_pass.infer_type(gradient(back_func_adjusted))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
x_nd = rand(dtype, *shape)
ex = create_executor()
forward, (grad_x,) = ex.evaluate(back_back_func)(x_nd)
tvm.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
if __name__ == "__main__":
test_id()
......@@ -200,3 +221,4 @@ if __name__ == "__main__":
test_tuple()
test_pow()
test_ref()
test_square_second_order()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment