Commit e3eff20d by Yong Wu Committed by Yao Wang

[Relay] shape func for zeros, zeros_like, ones, ones_like (#4448)

parent f931fe1f
......@@ -122,6 +122,20 @@ def cast_shape_func(attrs, inputs, out_ndims):
# shape func
@script
def _full_shape_func(x):
out_ndim = len(x)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = x[i]
return out
def full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros, zeros_like, ones, ones_like.
"""
return [_full_shape_func(*inputs)]
@script
def _broadcast_shape_func(x, y, ndim):
out = output_tensor((ndim,), "int64")
if len(x.shape) == 0:
......@@ -162,6 +176,10 @@ def elemwise_shape_func(attrs, inputs, _):
return [topi.math.identity(inputs[0])]
register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, full_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones_like", False, full_shape_func)
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
......
......@@ -41,10 +41,11 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
mod["main"] = relay.Function([x, y], op(x, y))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
y_np = np.random.uniform(size=y_np_shape).astype(dtype)
res_np = np_op(x_np, y_np)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))
tvm.testing.assert_allclose(result.asnumpy(), res_np)
def test_any_broadcast():
# Test broadcast with 1s
......@@ -77,6 +78,32 @@ def test_any_broadcast_fail():
check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add)
def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
x = relay.var('x', shape=x_shape, dtype=dtype)
mod = relay.module.Module()
mod['main'] = relay.Function([x], relay.zeros_like(x))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
res_np = np.zeros_like(x_np)
for kind in ['debug', 'vm']:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm')
result = ex.evaluate()(x_np).asnumpy()
tvm.testing.assert_allclose(result, res_np)
def test_any_full():
# zeros, zeros_like, ones, ones_like
verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32")
def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
y = relay.var('y', shape=(1, 2), dtype="float32")
......@@ -85,10 +112,10 @@ def test_any_concat():
mod["main"] = relay.Function([x, y], z)
x_np = np.random.uniform(size=(3, 2)).astype('float32')
y_np = np.random.uniform(size=(1, 2)).astype('float32')
ref = np.concatenate([x_np, y_np], axis=0)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
ref = np.concatenate([x_np, y_np], axis=0)
tvm.testing.assert_allclose(result.asnumpy(), ref)
def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
......@@ -116,10 +143,10 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
expected = np.argwhere(data)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data).asnumpy()
expected = np.argwhere(data)
assert result.shape == expected.shape
tvm.testing.assert_allclose(result.flatten(), expected.flatten())
......@@ -412,10 +439,10 @@ def verify_any_pad(data_shape, pad_width, static_data_shape):
y = relay.nn.pad(data, pad_width)
mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
ref_out = np.pad(data_np, pad_width)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np)
ref_out = np.pad(data_np, pad_width)
tvm.testing.assert_allclose(result.asnumpy(), ref_out)
def test_any_pad():
......@@ -497,12 +524,12 @@ def test_recursive_concat():
mod = relay.module.Module()
mod["main"] = func
data = np.array(0.0, dtype='int32')
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
# TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
# so currently we cannot run this test case on VM
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
np.testing.assert_allclose(result.asnumpy(), ref)
def test_recursive_concat_with_wrong_annotation():
......@@ -553,6 +580,7 @@ def test_recursive_concat_with_wrong_annotation():
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
if __name__ == "__main__":
test_any_full()
test_any_broadcast()
test_any_broadcast_fail()
test_any_concat()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment