Commit 35485307 by Yizhi Liu, committed by Tianqi Chen

android gemm for topi/recipe (#628)

parent 8fea0879
"""Example code to do square matrix multiplication on Android Phone."""
import tvm
import os
from tvm.contrib import rpc, util, ndk
import numpy as np
# Set to be address of tvm proxy.
proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"]
proxy_port = 9090
key = "android"
# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
target = "llvm -target=%s-linux-android" % arch
def ngflops(N):
    return 2.0 * float(N * N * N) / (10**9)
dtype = 'float32'
def evaluate(func, ctx, N, times):
    a_np = np.random.uniform(size=(N, N)).astype(dtype)
    b_np = np.random.uniform(size=(N, N)).astype(dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)
    time_f = func.time_evaluator(func.entry_name, ctx, number=times)
    cost = time_f(a, b, c).mean
    gf = ngflops(N) / cost
    print('%g secs/op, %g GFLOPS' % (cost, gf))
    np.testing.assert_almost_equal(c.asnumpy(), a_np.dot(b_np), decimal=2)
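# Declare, schedule and build an NxN GEMM for the phone's OpenCL device,
# then upload it over RPC and benchmark it on the device.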
def test_gemm_gpu(N, times, bn, num_block, num_thread):
    assert bn <= N
    assert num_thread * num_thread * 16 <= N
    assert num_block * num_block * 2 <= N
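    # Pack B into an (N, N/bn, bn) layout so that bn consecutive elements of
    # each row are contiguous in memory and can be loaded with vector ops.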
    A = tvm.placeholder((N, N), name='A')
    B = tvm.placeholder((N, N), name='Btmp')
    k = tvm.reduce_axis((0, N), name='k')

    packedB = tvm.compute((N, N // bn, bn),
                          lambda x, y, z: B[x, y * bn + z], name='B')

    C = tvm.compute(
        (N, N),
        lambda ii, jj: tvm.sum(A[ii, k] * packedB[k, jj / bn, jj % bn], axis=k),
        name='C')
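    # Create the schedule and stage C through a thread-local write cache.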
    s = tvm.create_schedule(C.op)
    CC = s.cache_write(C, "local")

    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_x = tvm.thread_axis("threadIdx.x")
    thread_y = tvm.thread_axis("threadIdx.y")
    thread_xz = tvm.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, 2), "vthread", name="vy")
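    # Schedule the B packing stage: split its rows and columns across the
    # thread grid and vectorize the innermost lanes.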
    pby, pbi = s[packedB].split(packedB.op.axis[0], nparts=num_thread)
    pbx, pbj = s[packedB].split(packedB.op.axis[1], nparts=num_thread)
    s[packedB].bind(pby, thread_y)
    s[packedB].bind(pbx, thread_x)
    pbz, pbk = s[packedB].split(packedB.op.axis[2], factor=8)
    s[packedB].vectorize(pbk)
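    # Tile C over blocks, virtual threads and threads, then vectorize the
    # innermost 8 output elements.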
    by, yi = s[C].split(C.op.axis[0], nparts=num_block)
    bx, xi = s[C].split(C.op.axis[1], nparts=num_thread)

    s[C].bind(by, block_y)
    s[C].bind(bx, thread_y)
    s[C].reorder(by, bx, yi, xi)

    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_block)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)

    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, block_x)
    s[C].bind(tx, thread_x)

    xyi, xxi = s[C].split(xi, factor=8)
    s[C].reorder(tyz, txz, ty, tx, yi, xyi, xxi)
    s[C].vectorize(xxi)
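    # Compute the local write cache CC inside each output tile, vectorizing
    # its inner columns and unrolling the reduction by a factor of 2.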
    s[CC].compute_at(s[C], yi)
    yo, xo = CC.op.axis
    s[CC].reorder(k, yo, xo)
    xo, xi = s[CC].split(xo, factor=8)
    s[CC].vectorize(xi)

    ko, ki = s[CC].split(k, factor=2)
    s[CC].unroll(ki)
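    # Print the lowered IR, build an OpenCL kernel with the ARM llvm host
    # target, and export a shared library with the Android NDK toolchain.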
    print(tvm.lower(s, [A, B, C], simple_mode=True))

    f = tvm.build(s, [A, B, C], "opencl", target_host=target, name="gemm_gpu")
    temp = util.tempdir()
    path_dso = temp.relpath("gemm_gpu.so")
    f.export_library(path_dso, ndk.create_shared)
    # connect to the proxy
    remote = rpc.connect(proxy_host, proxy_port, key=key)
    ctx = remote.cl(0)
    remote.upload(path_dso)
    f = remote.load_module("gemm_gpu.so")
    evaluate(f, ctx, N, times)
if __name__ == "__main__":
    test_gemm_gpu(1024, times=5, bn=8, num_block=2, num_thread=8)