Commit b5e0d790 by Siju Committed by Tianqi Chen

[RELAY]sch and compute for reduce ops (#2091)

parent b9038343
......@@ -15,5 +15,6 @@ _reg.register_schedule("argmax", _schedule_reduce)
_reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
_reg.register_schedule("mean", _schedule_reduce)
......@@ -106,8 +106,11 @@ def test_where():
assert zz.checked_type == relay.TensorType((3, 4), "float32")
def verify_reduce(test_func, data, axis, keepdims, exclude, output):
x = relay.var("x", relay.TensorType(data, "float32"))
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]
x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
zz = relay.ir_pass.infer_type(z)
if axis:
......@@ -116,25 +119,60 @@ def verify_reduce(test_func, data, axis, keepdims, exclude, output):
assert "keepdims=" in z.astext()
if exclude:
assert "exclude=" in z.astext()
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32"
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
assert zz.checked_type == relay.ty.TensorType(output, out_type)
if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0:
return
func = relay.Function([x], z)
x_data = np.random.uniform(size=data).astype(dtype)
if ref_func in [np.sum]:
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
elif ref_func in [np.max, np.min, np.mean, np.prod]:
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
else: #argmin/argmax
if axis and len(axis) > 1:
return
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def test_reduce_functions():
def _with_keepdims(func):
def _wrapper(data, axis=None, keepdims=False):
if not keepdims:
return func(data, axis=axis)
else:
if axis is not None:
axis = axis[0]
out_shape = list(data.shape)
out_shape[axis] = 1
else:
out_shape = [1 for _ in range(len(data.shape))]
return func(data, axis=axis).reshape(out_shape)
return _wrapper
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
for func in [relay.sum,
relay.max,
relay.min,
relay.mean,
relay.prod,
relay.argmin,
relay.argmax]:
for func in [[relay.sum, np.sum],
[relay.max, np.max],
[relay.min, np.min],
[relay.mean, np.mean],
[relay.prod, np.prod],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, True, False, (1, 1, 1))
verify_reduce(func, (4, 4, 3), None, False, True, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment