Commit b5e0d790 by Siju Committed by Tianqi Chen

[RELAY]sch and compute for reduce ops (#2091)

parent b9038343
...@@ -15,5 +15,6 @@ _reg.register_schedule("argmax", _schedule_reduce) ...@@ -15,5 +15,6 @@ _reg.register_schedule("argmax", _schedule_reduce)
_reg.register_schedule("argmin", _schedule_reduce) _reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce) _reg.register_schedule("sum", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce) _reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce)
_reg.register_schedule("mean", _schedule_reduce) _reg.register_schedule("mean", _schedule_reduce)
...@@ -106,8 +106,11 @@ def test_where(): ...@@ -106,8 +106,11 @@ def test_where():
assert zz.checked_type == relay.TensorType((3, 4), "float32") assert zz.checked_type == relay.TensorType((3, 4), "float32")
def verify_reduce(test_func, data, axis, keepdims, exclude, output): def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
x = relay.var("x", relay.TensorType(data, "float32")) test_func = funcs[0]
ref_func = funcs[1]
x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude) z = test_func(x, axis, keepdims, exclude)
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
if axis: if axis:
...@@ -116,25 +119,60 @@ def verify_reduce(test_func, data, axis, keepdims, exclude, output): ...@@ -116,25 +119,60 @@ def verify_reduce(test_func, data, axis, keepdims, exclude, output):
assert "keepdims=" in z.astext() assert "keepdims=" in z.astext()
if exclude: if exclude:
assert "exclude=" in z.astext() assert "exclude=" in z.astext()
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32" out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
assert zz.checked_type == relay.ty.TensorType(output, out_type) assert zz.checked_type == relay.ty.TensorType(output, out_type)
if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0:
return
func = relay.Function([x], z)
x_data = np.random.uniform(size=data).astype(dtype)
if ref_func in [np.sum]:
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
elif ref_func in [np.max, np.min, np.mean, np.prod]:
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
else: #argmin/argmax
if axis and len(axis) > 1:
return
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def test_reduce_functions(): def test_reduce_functions():
def _with_keepdims(func):
def _wrapper(data, axis=None, keepdims=False):
if not keepdims:
return func(data, axis=axis)
else:
if axis is not None:
axis = axis[0]
out_shape = list(data.shape)
out_shape[axis] = 1
else:
out_shape = [1 for _ in range(len(data.shape))]
return func(data, axis=axis).reshape(out_shape)
return _wrapper
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
for func in [relay.sum, for func in [[relay.sum, np.sum],
relay.max, [relay.max, np.max],
relay.min, [relay.min, np.min],
relay.mean, [relay.mean, np.mean],
relay.prod, [relay.prod, np.prod],
relay.argmin, [relay.argmin, _with_keepdims(np.argmin)],
relay.argmax]: [relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4)) verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3)) verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1)) verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3)) verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ()) verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, True, False, (1, 1, 1))
verify_reduce(func, (4, 4, 3), None, False, True, ()) verify_reduce(func, (4, 4, 3), None, False, True, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,)) verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,)) verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment