"""Example code to do square matrix multiplication on Android Phone.""" import tvm import os from tvm.contrib import rpc, util, ndk import numpy as np # Set to be address of tvm proxy. proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"] proxy_port = 9090 key = "android" # Change target configuration. # Run `adb shell cat /proc/cpuinfo` to find the arch. arch = "arm64" target = "llvm -target=%s-linux-android" % arch def ngflops(N): return 2.0 * float(N * N * N) / (10**9) dtype = 'float32' def evaluate(func, ctx, N, times): a_np = np.random.uniform(size=(N, N)).astype(dtype) b_np = np.random.uniform(size=(N, N)).astype(dtype) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx) time_f = func.time_evaluator(func.entry_name, ctx, number=times) cost = time_f(a, b, c).mean gf = ngflops(N) / cost print('%g secs/op, %g GFLOPS' % (cost, gf)) np.testing.assert_almost_equal(c.asnumpy(), a_np.dot(b_np), decimal=2) def test_gemm_gpu(N, times, bn, num_block, num_thread): assert(bn <= N) assert(num_thread * num_thread * 16 <= N) assert(num_block * num_block * 2 <= N) A = tvm.placeholder((N, N), name='A') B = tvm.placeholder((N, N), name='Btmp') k = tvm.reduce_axis((0, N), name='k') packedB = tvm.compute((N, N / bn, bn), lambda x, y, z: B[x, y * bn + z], name = 'B') C = tvm.compute( (N, N), lambda ii, jj: tvm.sum(A[ii, k] * packedB[k, jj / bn, jj % bn], axis=k), name='C') s = tvm.create_schedule(C.op) CC = s.cache_write(C, "local") block_x = tvm.thread_axis("blockIdx.x") block_y = tvm.thread_axis("blockIdx.y") thread_x = tvm.thread_axis("threadIdx.x") thread_y = tvm.thread_axis("threadIdx.y") thread_xz = tvm.thread_axis((0, 2), "vthread", name="vx") thread_yz = tvm.thread_axis((0, 2), "vthread", name="vy") pby, pbi = s[packedB].split(packedB.op.axis[0], nparts=num_thread) pbx, pbj = s[packedB].split(packedB.op.axis[1], nparts=num_thread) s[packedB].bind(pby, thread_y) s[packedB].bind(pbx, thread_x) pbz, pbk = s[packedB].split(packedB.op.axis[2], factor=8) s[packedB].vectorize(pbk) by, yi = s[C].split(C.op.axis[0], nparts=num_block) bx, xi = s[C].split(C.op.axis[1], nparts=num_thread) s[C].bind(by, block_y) s[C].bind(bx, thread_y) s[C].reorder(by, bx, yi, xi) tyz, yi = s[C].split(yi, nparts=2) ty, yi = s[C].split(yi, nparts=num_block) txz, xi = s[C].split(xi, nparts=2) tx, xi = s[C].split(xi, nparts=num_thread) s[C].reorder(tyz, txz, ty, tx, yi, xi) s[C].bind(tyz, thread_yz) s[C].bind(txz, thread_xz) s[C].bind(ty, block_x) s[C].bind(tx, thread_x) xyi, xxi = s[C].split(xi, factor=8) s[C].reorder(tyz, txz, ty, tx, yi, xyi, xxi) s[C].vectorize(xxi) s[CC].compute_at(s[C], yi) yo, xo = CC.op.axis s[CC].reorder(k, yo, xo) xo, xi = s[CC].split(xo, factor=8) s[CC].vectorize(xi) ko, ki = s[CC].split(k, factor=2) s[CC].unroll(ki) print(tvm.lower(s, [A, B, C], simple_mode=True)) f = tvm.build(s, [A, B, C], "opencl", target_host=target, name="gemm_gpu") temp = util.tempdir() path_dso = temp.relpath("gemm_gpu.so") f.export_library(path_dso, ndk.create_shared) # connect to the proxy remote = rpc.connect(proxy_host, proxy_port, key=key) ctx = remote.cl(0) remote.upload(path_dso) f = remote.load_module("gemm_gpu.so") evaluate(f, ctx, N, times) if __name__ == "__main__": test_gemm_gpu(1024, times=5, bn=8, num_block=2, num_thread=8)