# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the CBLAS-backed extern ops in ``tvm.contrib.cblas``."""
import tvm
import tvm.testing
from tvm import te
import numpy as np
import topi.testing
from tvm.contrib import cblas


def verify_matmul_add(m, l, n, transa=False, transb=False, dtype="float32"):
    """Check ``cblas.matmul`` followed by a scalar bias add against NumPy.

    The logical product is ``(n, l) x (l, m) -> (n, m)``; a scalar bias
    variable is then added to every element of the result.

    Parameters
    ----------
    m : int
        Number of columns of the result.
    l : int
        Reduction (inner) dimension.
    n : int
        Number of rows of the result.
    transa : bool
        If True, A is stored as ``(l, n)`` and passed with transa set.
    transb : bool
        If True, B is stored as ``(m, l)`` and passed with transb set.
    dtype : str
        Element type of A, B and the bias variable.

    The check is skipped (with a printed message) when the LLVM target or
    the ``tvm.contrib.cblas.matmul`` extern function is unavailable.
    """
    bias = te.var('bias', dtype=dtype)
    ashape = (l, n) if transa else (n, l)
    bshape = (m, l) if transb else (l, m)
    A = te.placeholder(ashape, name='A', dtype=dtype)
    B = te.placeholder(bshape, name='B', dtype=dtype)
    C = cblas.matmul(A, B, transa, transb)
    D = te.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
    s = te.create_schedule(D.op)

    def get_numpy(a, b, bb, transa, transb):
        # Undo the storage transposes so np.dot sees (n, l) x (l, m).
        if transa:
            a = a.transpose()
        if transb:
            b = b.transpose()
        return np.dot(a, b) + bb

    def verify(target="llvm"):
        if not tvm.runtime.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A, B, D, bias], target)
        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
        bb = 10.0
        f(a, b, d, bb)
        tvm.testing.assert_allclose(
            d.asnumpy(),
            get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb),
            rtol=1e-5)

    verify()


def test_matmul_add():
    """Exercise every transpose combination at a large and a tiny size."""
    verify_matmul_add(235, 128, 1024)
    verify_matmul_add(235, 128, 1024, True, False)
    verify_matmul_add(235, 128, 1024, False, True)
    verify_matmul_add(235, 128, 1024, True, True)
    verify_matmul_add(1, 16, 4)
    verify_matmul_add(1, 16, 3, True, False)
    # Previously a duplicate (False, False) case; cover transb-only instead.
    verify_matmul_add(1, 16, 3, False, True)
    verify_matmul_add(1, 16, 3, True, True)


def verify_batch_matmul(batch, m, l, n, transa=False, transb=False,
                        iterative=False, dtype="float32"):
    """Check ``cblas.batch_matmul`` against a NumPy reference.

    Each of the ``batch`` slices computes ``(n, l) x (l, m) -> (n, m)``.

    Parameters
    ----------
    batch : int
        Number of independent matrix products.
    m : int
        Number of columns of each result slice.
    l : int
        Reduction (inner) dimension.
    n : int
        Number of rows of each result slice.
    transa : bool
        If True, A is stored as ``(batch, l, n)``.
    transb : bool
        If True, B is stored as ``(batch, m, l)``.
    iterative : bool
        Forwarded to ``cblas.batch_matmul`` to select the iterative kernel.
    dtype : str
        Element type of A and B.

    The check is skipped (with a printed message) when the LLVM target or
    the extern function selected by ``iterative`` is unavailable.
    """
    ashape = (batch, l, n) if transa else (batch, n, l)
    bshape = (batch, m, l) if transb else (batch, l, m)
    A = te.placeholder(ashape, name='A', dtype=dtype)
    B = te.placeholder(bshape, name='B', dtype=dtype)
    # Forward `iterative`; previously it was accepted but ignored, so the
    # iterative kernel was never actually exercised.
    C = cblas.batch_matmul(A, B, transa, transb, iterative=iterative)
    D = te.compute(C.shape, lambda k, i, j: C[k, i, j], name="D")
    s = te.create_schedule(D.op)

    # Extern function the op above dispatches to; probe it (not the plain
    # `matmul` symbol) so the skip check matches what is actually called.
    extern_name = ("tvm.contrib.cblas.batch_matmul_iterative" if iterative
                   else "tvm.contrib.cblas.batch_matmul")

    def get_numpy(a, b, transa, transb):
        if transa:
            a = a.transpose(0, 2, 1)
        # The reference helper is called with B in (batch, m, l) layout,
        # so transpose only when B was *not* supplied transposed.
        if not transb:
            b = b.transpose(0, 2, 1)
        return topi.testing.batch_matmul(a, b)

    def verify(target="llvm"):
        if not tvm.runtime.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func(extern_name, True):
            print("skip because extern function is not available")
            return
        ctx = tvm.cpu(0)
        f = tvm.build(s, [A, B, D], target)
        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
        d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
        f(a, b, d)
        tvm.testing.assert_allclose(
            d.asnumpy(),
            get_numpy(a.asnumpy(), b.asnumpy(), transa, transb),
            rtol=1e-5)

    verify()


def test_batch_matmul():
    """Exercise every transpose combination plus the iterative kernel."""
    verify_batch_matmul(16, 235, 128, 1024)
    verify_batch_matmul(16, 235, 128, 1024, True, False)
    verify_batch_matmul(16, 235, 128, 1024, False, True)
    verify_batch_matmul(16, 235, 128, 1024, True, True)
    verify_batch_matmul(1, 1, 16, 3)
    verify_batch_matmul(1, 1, 16, 3, True, False)
    # Previously a duplicate (False, False) case; cover transb-only instead.
    verify_batch_matmul(1, 1, 16, 3, False, True)
    verify_batch_matmul(1, 1, 16, 3, True, True)
    verify_batch_matmul(1, 1, 16, 3, iterative=True)


if __name__ == "__main__":
    test_matmul_add()
    test_batch_matmul()