Commit 85c545c7 by masahi, committed by Tianqi Chen

Add rocm target to topi tests (#548)

* add masahi to contributors

* enable rocm target in topi tests
parent 74b0ca86
@@ -34,3 +34,4 @@ List of Contributors
 - To contributors: please add your name to the list.
 - [Qiao Zhang](https://github.com/zhangqiaorjc)
 - [Jian Weng](https://github.com/were)
+- [Masahiro Masuda](https://github.com/masahi)
@@ -13,7 +13,7 @@ def verify_broadcast_to_ele(in_shape, out_shape):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="broadcast_to")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = np.broadcast_to(data_npy, out_shape)
@@ -27,6 +27,7 @@ def verify_broadcast_to_ele(in_shape, out_shape):
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
@@ -52,7 +53,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
         lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
         rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
@@ -81,7 +82,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def test_broadcast_to():
     verify_broadcast_to_ele((1,), (10,))
...
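Note: the one-line change repeated through these tests is the context constructor. tvm.context(device, 0) resolves a target string to the matching device context, so the rocm target needs no per-backend special case. A minimal sketch of the equivalence (assuming TVM was built with each backend and a device 0 exists):

    import tvm

    for device in ["cuda", "opencl", "metal", "rocm"]:
        if tvm.module.enabled(device):
            ctx = tvm.context(device, 0)  # generic: works for any backend string
    # The replaced line handled only two backends explicitly:
    #     ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
    # tvm.rocm(0) is the rocm-specific constructor tvm.context selects here.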
@@ -34,14 +34,14 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         with tvm.build_config(auto_unroll_max_step=32,
                               auto_unroll_min_depth=0,
-                              unroll_explicit=False):
+                              unroll_explicit=device == 'rocm'):
             func1 = tvm.build(s1, [A, W, B], device)
             func2 = tvm.build(s2, [A, W, C], device)
             func1(a, w, b)
@@ -49,7 +49,7 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)
...
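Note: besides extending the target list, the conv2d tests make the unrolling strategy device-dependent. With unroll_explicit=True, TVM's lowering pass emits fully unrolled loops in the IR rather than leaving an unroll pragma for the code generator, presumably because the LLVM-based rocm backend does not see the source-level pragma the cuda path emits; all other targets keep the previous False. A sketch of the pattern, assuming the schedule s and tensors A, W, B come from the surrounding test:

    import tvm

    def build_conv(device, s, A, W, B):
        # Unroll in the IR for rocm; defer to the code generator elsewhere.
        with tvm.build_config(auto_unroll_max_step=32,
                              auto_unroll_min_depth=0,
                              unroll_explicit=(device == 'rocm')):
            return tvm.build(s, [A, W, B], device)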
@@ -35,14 +35,14 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         with tvm.build_config(auto_unroll_max_step=32,
                               auto_unroll_min_depth=0,
-                              unroll_explicit=False):
+                              unroll_explicit=device == 'rocm'):
             func1 = tvm.build(s1, [A, W, B], device)
             func2 = tvm.build(s2, [A, W, C], device)
             func1(a, w, b)
@@ -50,7 +50,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)
...
@@ -33,7 +33,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
@@ -42,7 +42,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
         f(a, b, c, d)
         np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)

 def test_dense():
...
@@ -87,7 +87,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     in_width = in_height
     filter_channel = in_channel
@@ -170,7 +171,7 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def test_depthwise_conv2d():
     print("testing nchw")
...
@@ -83,7 +83,7 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def test_topi_depthwise_conv2d_backward_input_nhwc():
     verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)
...
@@ -76,7 +76,7 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def test_topi_depthwise_conv2d_backward_weight_nhwc():
     verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
...
@@ -36,14 +36,14 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)

 def test_pool():
@@ -70,14 +70,14 @@ def verify_global_pool(n, c, h, w, pool_type):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         f = tvm.build(s, [A, B], device)
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)

 def test_global_pool():
...
@@ -50,7 +50,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="sum")
         # Test
         in_npy = np.random.uniform(size=in_shape).astype(np.float32)
@@ -76,7 +76,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def test_reduce_map():
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
...
@@ -17,14 +17,14 @@ def verify_relu(m, n):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         foo = tvm.build(s, [A, B], device, name="relu")
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)
...
@@ -21,14 +21,14 @@ def verify_softmax(m, n):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         foo = tvm.build(s, [A, B], device, name="softmax")
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)

 def test_softmax():
@@ -52,14 +52,14 @@ def verify_log_softmax(m, n):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         foo = tvm.build(s, [A, B], device, name="log_softmax")
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)

 def test_log_softmax():
...
@@ -11,7 +11,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="expand_dims")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = data_npy.reshape(out_shape)
@@ -23,6 +23,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
     check_device("opencl")
     check_device("cuda")
     check_device("metal")
+    check_device("rocm")

 def verify_tranpose(in_shape, axes):
@@ -33,7 +34,7 @@ def verify_tranpose(in_shape, axes):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="tranpose")
         data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
         out_npy = data_npy.transpose(axes)
@@ -45,7 +46,7 @@ def verify_tranpose(in_shape, axes):
     check_device("cuda")
     check_device("opencl")
     check_device("metal")
+    check_device("rocm")

 def verify_reshape(src_shape, dst_shape):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -55,7 +56,7 @@ def verify_reshape(src_shape, dst_shape):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="reshape")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.reshape(data_npy, newshape=dst_shape)
@@ -67,7 +68,7 @@ def verify_reshape(src_shape, dst_shape):
     check_device("cuda")
     check_device("opencl")
     check_device("metal")
+    check_device("rocm")

 def verify_squeeze(src_shape, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -77,7 +78,7 @@ def verify_squeeze(src_shape, axis):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="squeeze")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.squeeze(data_npy, axis=axis)
@@ -93,7 +94,7 @@ def verify_squeeze(src_shape, axis):
     check_device("cuda")
     check_device("opencl")
     check_device("metal")
+    check_device("rocm")

 def verify_concatenate(shapes, axis):
     tensor_l = []
@@ -105,7 +106,7 @@ def verify_concatenate(shapes, axis):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
         data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
         out_npy = np.concatenate(data_npys, axis=axis)
@@ -117,7 +118,7 @@ def verify_concatenate(shapes, axis):
     check_device("cuda")
     check_device("opencl")
     check_device("metal")
+    check_device("rocm")

 def verify_split(src_shape, indices_or_sections, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -127,7 +128,7 @@ def verify_split(src_shape, indices_or_sections, axis):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A] + tensor_l, device, name="split")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npys = np.split(data_npy, indices_or_sections, axis=axis)
@@ -140,7 +141,8 @@ def verify_split(src_shape, indices_or_sections, axis):
     check_device("cuda")
     check_device("opencl")
     check_device("metal")
+    check_device("rocm")

 def test_expand_dims():
     verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
     verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
...
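Note: every file in this commit follows the same skip-or-run closure, so once the rocm runtime is enabled in the TVM build, adding the target string is the whole change. A condensed, self-contained sketch of that shared pattern (an illustrative relu-like kernel, not lifted from any one test file):

    import numpy as np
    import tvm

    # A trivial elementwise kernel with a plain GPU schedule, standing in
    # for the topi operator each real test builds.
    m = 1024
    A = tvm.placeholder((m,), name="A")
    B = tvm.compute((m,), lambda i: tvm.max(A[i], tvm.const(0, A.dtype)), name="B")
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

    a_np = np.random.uniform(low=-1, high=1, size=m).astype(A.dtype)
    b_np = np.maximum(a_np, 0)

    def check_device(device):
        # Skip targets the local TVM build does not support.
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
        f = tvm.build(s, [A, B], device, name="relu_like")
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    # rocm simply joins the target list.
    for device in ['cuda', 'opencl', 'metal', 'rocm']:
        check_device(device)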