import tvm
import numpy as np
from tvm.contrib import rocblas


def test_matmul_add():
    """Check rocBLAS matmul against NumPy's ``np.dot`` on a ROCm device.

    Builds ``C = rocblas.matmul(A, B)`` where ``A`` has shape ``(n, l)`` and
    ``B`` has shape ``(l, m)``. It then runs the compiled function on
    ``tvm.rocm(0)`` with random inputs and compares the result with
    ``np.dot`` of the same inputs (relative tolerance 1e-5).

    The check is skipped in two cases: when the ROCm target is not enabled,
    or when the ``tvm.contrib.rocblas.matmul`` global function cannot be
    found. A skip prints a message and returns; it does not fail the test.

    NOTE(review): despite its name, this test performs no addition. The name
    is kept unchanged so existing test selection keeps working.
    """
    n = 1024
    l = 128
    m = 235
    A = tvm.placeholder((n, l), name='A')
    B = tvm.placeholder((l, m), name='B')
    C = rocblas.matmul(A, B)
    s = tvm.create_schedule(C.op)

    def verify(target="rocm"):
        """Build the schedule for ``target``, run it, and check the output."""
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        # Second argument is presumably TVM's ``allow_missing`` flag, so an
        # unregistered extern yields a falsy result instead of raising --
        # confirm against the TVM version in use.
        if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.rocm(0)
        f = tvm.build(s, [A, B, C], target)
        # Local seeded generator: makes any numerical mismatch reproducible
        # without touching NumPy's global random state.
        rng = np.random.RandomState(0)
        a = tvm.nd.array(rng.uniform(size=(n, l)).astype(A.dtype), ctx)
        b = tvm.nd.array(rng.uniform(size=(l, m)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
        f(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)

    verify()


if __name__ == "__main__":
    test_matmul_add()