import tvm
import numpy as np


def test_sum():
    # Describe the computation: B[i] = sum over k of A[i, k].
    n = tvm.Var('n')
    m = tvm.Var('m')
    A = tvm.placeholder((n, m), name='A')
    k = tvm.IterVar((0, m))
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
    # Create a schedule for the reduction.
    s = tvm.Schedule(B.op)
    # Create iteration variables and bind them to GPU thread tags.
    num_thread = 1
    block_x = tvm.IterVar(thread_tag="blockIdx.x")
    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
    _, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
    _, x = s[B].split(x, outer=thread_x)

    tvm.init_opencl()
    # Build the OpenCL kernel and capture the generated source for inspection.
    codes = []
    fsum = tvm.build(s,
                     args=[A, B],
                     target="opencl", name="mysum",
                     record_codes=codes)
    for c in codes:
        print(c)

    num_device = 1
    for i in range(num_device):
        ctx = tvm.opencl(i)
        if not ctx.enabled:
            continue
        # Launch the kernel with non-power-of-two shapes.
        n = 1028
        m = 129
        a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        fsum(a, b)
        # Verify against numpy's row-wise sum.
        np.testing.assert_allclose(
            b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)


if __name__ == "__main__":
    test_sum()