Commit e3eff20d by Yong Wu Committed by Yao Wang

[Relay] shape func for zeros, zeros_like, ones, ones_like (#4448)

parent f931fe1f
...@@ -122,6 +122,20 @@ def cast_shape_func(attrs, inputs, out_ndims): ...@@ -122,6 +122,20 @@ def cast_shape_func(attrs, inputs, out_ndims):
# shape func # shape func
@script @script
def _full_shape_func(x):
out_ndim = len(x)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = x[i]
return out
def full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros, zeros_like, ones, ones_like.
"""
return [_full_shape_func(*inputs)]
@script
def _broadcast_shape_func(x, y, ndim): def _broadcast_shape_func(x, y, ndim):
out = output_tensor((ndim,), "int64") out = output_tensor((ndim,), "int64")
if len(x.shape) == 0: if len(x.shape) == 0:
...@@ -162,6 +176,10 @@ def elemwise_shape_func(attrs, inputs, _): ...@@ -162,6 +176,10 @@ def elemwise_shape_func(attrs, inputs, _):
return [topi.math.identity(inputs[0])] return [topi.math.identity(inputs[0])]
register_shape_func("cast", False, cast_shape_func) register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, full_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones_like", False, full_shape_func)
register_shape_func("add", False, broadcast_shape_func) register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func)
......
...@@ -41,10 +41,11 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): ...@@ -41,10 +41,11 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
mod["main"] = relay.Function([x, y], op(x, y)) mod["main"] = relay.Function([x, y], op(x, y))
x_np = np.random.uniform(size=x_np_shape).astype(dtype) x_np = np.random.uniform(size=x_np_shape).astype(dtype)
y_np = np.random.uniform(size=y_np_shape).astype(dtype) y_np = np.random.uniform(size=y_np_shape).astype(dtype)
res_np = np_op(x_np, y_np)
for kind in ["debug", "vm"]: for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np) result = ex.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np)) tvm.testing.assert_allclose(result.asnumpy(), res_np)
def test_any_broadcast(): def test_any_broadcast():
# Test broadcast with 1s # Test broadcast with 1s
...@@ -77,6 +78,32 @@ def test_any_broadcast_fail(): ...@@ -77,6 +78,32 @@ def test_any_broadcast_fail():
check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add) check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add)
def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
x = relay.var('x', shape=x_shape, dtype=dtype)
mod = relay.module.Module()
mod['main'] = relay.Function([x], relay.zeros_like(x))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
res_np = np.zeros_like(x_np)
for kind in ['debug', 'vm']:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm')
result = ex.evaluate()(x_np).asnumpy()
tvm.testing.assert_allclose(result, res_np)
def test_any_full():
# zeros, zeros_like, ones, ones_like
verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32")
def test_any_concat(): def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32") x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
y = relay.var('y', shape=(1, 2), dtype="float32") y = relay.var('y', shape=(1, 2), dtype="float32")
...@@ -85,10 +112,10 @@ def test_any_concat(): ...@@ -85,10 +112,10 @@ def test_any_concat():
mod["main"] = relay.Function([x, y], z) mod["main"] = relay.Function([x, y], z)
x_np = np.random.uniform(size=(3, 2)).astype('float32') x_np = np.random.uniform(size=(3, 2)).astype('float32')
y_np = np.random.uniform(size=(1, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32')
ref = np.concatenate([x_np, y_np], axis=0)
for kind in ["debug", "vm"]: for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np) result = ex.evaluate()(x_np, y_np)
ref = np.concatenate([x_np, y_np], axis=0)
tvm.testing.assert_allclose(result.asnumpy(), ref) tvm.testing.assert_allclose(result.asnumpy(), ref)
def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape): def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
...@@ -116,10 +143,10 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): ...@@ -116,10 +143,10 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
mod = relay.module.Module() mod = relay.module.Module()
mod["main"] = relay.Function([x], y) mod["main"] = relay.Function([x], y)
data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
expected = np.argwhere(data)
for kind in ["debug", "vm"]: for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data).asnumpy() result = ex.evaluate()(data).asnumpy()
expected = np.argwhere(data)
assert result.shape == expected.shape assert result.shape == expected.shape
tvm.testing.assert_allclose(result.flatten(), expected.flatten()) tvm.testing.assert_allclose(result.flatten(), expected.flatten())
...@@ -412,10 +439,10 @@ def verify_any_pad(data_shape, pad_width, static_data_shape): ...@@ -412,10 +439,10 @@ def verify_any_pad(data_shape, pad_width, static_data_shape):
y = relay.nn.pad(data, pad_width) y = relay.nn.pad(data, pad_width)
mod["main"] = relay.Function([data], y) mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype) data_np = np.random.uniform(size=static_data_shape).astype(dtype)
ref_out = np.pad(data_np, pad_width)
for kind in ["debug", "vm"]: for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np) result = ex.evaluate()(data_np)
ref_out = np.pad(data_np, pad_width)
tvm.testing.assert_allclose(result.asnumpy(), ref_out) tvm.testing.assert_allclose(result.asnumpy(), ref_out)
def test_any_pad(): def test_any_pad():
...@@ -497,12 +524,12 @@ def test_recursive_concat(): ...@@ -497,12 +524,12 @@ def test_recursive_concat():
mod = relay.module.Module() mod = relay.module.Module()
mod["main"] = func mod["main"] = func
data = np.array(0.0, dtype='int32') data = np.array(0.0, dtype='int32')
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
# TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
# so currently we cannot run this test case on VM # so currently we cannot run this test case on VM
for kind in ["debug"]: for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data) result = ex.evaluate()(data)
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
np.testing.assert_allclose(result.asnumpy(), ref) np.testing.assert_allclose(result.asnumpy(), ref)
def test_recursive_concat_with_wrong_annotation(): def test_recursive_concat_with_wrong_annotation():
...@@ -553,6 +580,7 @@ def test_recursive_concat_with_wrong_annotation(): ...@@ -553,6 +580,7 @@ def test_recursive_concat_with_wrong_annotation():
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
if __name__ == "__main__": if __name__ == "__main__":
test_any_full()
test_any_broadcast() test_any_broadcast()
test_any_broadcast_fail() test_any_broadcast_fail()
test_any_concat() test_any_concat()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment