"""Test code for broadcasting operators.""" import numpy as np import tvm import topi def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): A = tvm.placeholder(shape=in_shape, name="A") B = topi.expand_dims(A, axis, num_newaxis) s = topi.cuda.schedule_broadcast(B) def check_device(device): if not tvm.module.enabled(device): print("Skip because %s is not enabled" % device) return ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) foo = tvm.build(s, [A, B], device, name="expand_dims") data_npy = np.random.uniform(size=in_shape).astype(A.dtype) out_npy = data_npy.reshape(out_shape) data_nd = tvm.nd.array(data_npy, ctx) out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx) foo(data_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy) check_device("opencl") check_device("cuda") check_device("metal") def verify_tranpose(in_shape, axes): A = tvm.placeholder(shape=in_shape, name="A") B = topi.transpose(A, axes) s = topi.cuda.schedule_injective(B) def check_device(device): if not tvm.module.enabled(device): print("Skip because %s is not enabled" % device) return ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) foo = tvm.build(s, [A, B], device, name="tranpose") data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype) out_npy = data_npy.transpose(axes) data_nd = tvm.nd.array(data_npy, ctx) out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype) foo(data_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy) check_device("cuda") check_device("opencl") check_device("metal") def test_expand_dims(): verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), (1, 3, 10), -3, 1) def test_tranpose(): verify_tranpose((3, 10, 2), (1, 0, 2)) verify_tranpose((3, 10, 5), (2, 0, 1)) verify_tranpose((3, 10), None) if __name__ == "__main__": test_tranpose() test_expand_dims()