"""Test code for broadcasting operators.""" import os import numpy as np import tvm import topi def verify_broadcast_to_ele(in_shape, out_shape): # Build the logic and compile the function A = tvm.placeholder(shape=in_shape, name="A") B = topi.broadcast_to(A, out_shape) s = topi.cuda.schedule_broadcast_to(B) def check_device(device): if not tvm.module.enabled(device): print("Skip because %s is not enabled" % device) return ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) foo = tvm.build(s, [A, B], device, name="broadcast_to") data_npy = np.random.uniform(size=in_shape).astype(A.dtype) out_npy = np.broadcast_to(data_npy, out_shape) data_nd = tvm.nd.array(data_npy, ctx) out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx) for _ in range(1): foo(data_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy) check_device("opencl") check_device("cuda") check_device("metal") def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"): # Build the logic and compile the function A = tvm.placeholder(shape=lhs_shape, name="A") B = tvm.placeholder(shape=rhs_shape, name="B") if typ == "add": C = topi.broadcast_add(A, B) elif typ == "sub": C = topi.broadcast_sub(A, B) elif typ == "div": C = topi.broadcast_div(A, B) elif typ == "mul": C = topi.broadcast_mul(A, B) elif typ == "maximum": C = topi.broadcast_maximum(A, B) elif typ == "minimum": C = topi.broadcast_minimum(A, B) else: raise NotImplementedError s = topi.cuda.schedule_broadcast_binary_op(C) def check_device(device): if not tvm.module.enabled(device): print("Skip because %s is not enabled" % device) return ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ) lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype) rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype) if typ == "add": out_npy = lhs_npy + rhs_npy elif typ == "sub": out_npy = lhs_npy - rhs_npy elif typ == "div": rhs_npy = np.abs(rhs_npy) + 0.001 out_npy = lhs_npy / rhs_npy elif typ == "mul": out_npy = lhs_npy * rhs_npy elif typ == "maximum": out_npy = np.maximum(lhs_npy, rhs_npy) elif typ == "minimum": out_npy = np.minimum(lhs_npy, rhs_npy) else: raise NotImplementedError lhs_nd = tvm.nd.array(lhs_npy, ctx) rhs_nd = tvm.nd.array(rhs_npy, ctx) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx) for _ in range(1): foo(lhs_nd, rhs_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) check_device("opencl") check_device("cuda") check_device("metal") def test_broadcast_to(): verify_broadcast_to_ele((1,), (10,)) verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)) verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32)) def test_broadcast_binary(): verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add") verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul") verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div") verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub") verify_broadcast_binary_ele((32,), (64, 32), typ="maximum") verify_broadcast_binary_ele((1, 2, 2, 1, 32), (64, 32), typ="minimum") if __name__ == "__main__": test_broadcast_to() test_broadcast_binary()