import tvm
import numpy as np
from tvm.contrib import util

import vta.testing


def test_gemm():
    """Build, run and time a packed GEMM pipeline through ``vta.testing.run``.

    The pipeline is: load data/weight -> blocked matrix multiply ->
    right shift by 8 -> clip to ``[0, 2**(INP_WIDTH-1) - 1]`` -> cast/store.
    """

    def run_gemm_packed(env, remote, batch_size, channel, block):
        """Declare the packed GEMM computation and run the profiling variants.

        Parameters
        ----------
        env
            VTA environment; supplies the tiling factors (``BATCH``,
            ``BLOCK_IN``, ``BLOCK_OUT``), dtypes, scopes and pragma values.
        remote
            Session used to upload the built module and to get the device
            context (``remote.ext_dev(0)``).
        batch_size : int
            Number of rows of the (unpacked) data matrix.
        channel : int
            Both the reduction size and the number of output columns; the
            weight matrix is ``channel x channel``.
        block : int
            Tile size used by the blocked schedule. A falsy value selects the
            untiled schedule instead.
        """
        # Packed layouts: (outer rows, outer cols, inner rows, inner cols).
        data_shape = (batch_size // env.BATCH,
                      channel // env.BLOCK_IN,
                      env.BATCH,
                      env.BLOCK_IN)
        weight_shape = (channel // env.BLOCK_OUT,
                        channel // env.BLOCK_IN,
                        env.BLOCK_OUT,
                        env.BLOCK_IN)
        res_shape = (batch_size // env.BATCH,
                     channel // env.BLOCK_OUT,
                     env.BATCH,
                     env.BLOCK_OUT)
        # To compute number of ops, use a x2 factor for FMA
        num_ops = 2 * channel * channel * batch_size
        # Upper saturation bound, shared by the TVM pipeline and the NumPy
        # reference so the two cannot drift apart.
        clip_max = (1 << (env.INP_WIDTH - 1)) - 1

        ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
        ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')

        data = tvm.placeholder(data_shape,
                               name="data",
                               dtype=env.inp_dtype)
        weight = tvm.placeholder(weight_shape,
                                 name="weight",
                                 dtype=env.wgt_dtype)
        # Copy stages; run_schedule assigns them the input/weight scopes.
        data_buf = tvm.compute(data_shape,
                               lambda *i: data(*i),
                               "data_buf")
        weight_buf = tvm.compute(weight_shape,
                                 lambda *i: weight(*i),
                                 "weight_buf")
        # Blocked matrix multiply, accumulating in env.acc_dtype.
        res_gem = tvm.compute(
            res_shape,
            lambda bo, co, bi, ci: tvm.sum(
                data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
                weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
                axis=[ko, ki]),
            name="res_gem")
        res_shf = tvm.compute(res_shape,
                              lambda *i: res_gem(*i) >> 8,
                              name="res_shf")
        # relu: clamp negatives to zero
        res_max = tvm.compute(res_shape,
                              lambda *i: tvm.max(res_shf(*i), 0),
                              "res_max")
        # saturate at clip_max before narrowing to the input dtype
        res_min = tvm.compute(res_shape,
                              lambda *i: tvm.min(res_max(*i), clip_max),
                              "res_min")
        res = tvm.compute(res_shape,
                          lambda *i: res_min(*i).astype(env.inp_dtype),
                          name="res")

        def verify(s, check_correctness=True):
            """Build schedule ``s``, run it on the remote and time it.

            Returns the value produced by the module's time evaluator
            (callers read its ``.mean``). When ``check_correctness`` is true
            the device output is also compared against a NumPy reference.
            """
            mod = vta.build(s, [data, weight, res],
                            "ext_dev", env.target_host, name="gemm")
            temp = util.tempdir()
            obj_path = temp.relpath("gemm.o")
            mod.save(obj_path)
            remote.upload(obj_path)
            f = remote.load_module("gemm.o")
            # verify
            ctx = remote.ext_dev(0)
            # Data in original format
            data_orig = np.random.randint(
                -128, 128, size=(batch_size, channel)).astype(data.dtype)
            weight_orig = np.random.randint(
                -128, 128, size=(channel, channel)).astype(weight.dtype)
            # Split each dimension into (outer, inner) and bring the two
            # outer indices to the front to match the packed shapes above.
            data_packed = data_orig.reshape(
                batch_size // env.BATCH, env.BATCH,
                channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
            weight_packed = weight_orig.reshape(
                channel // env.BLOCK_OUT, env.BLOCK_OUT,
                channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
            res_np = np.zeros(res_shape).astype(res.dtype)
            data_arr = tvm.nd.array(data_packed, ctx)
            weight_arr = tvm.nd.array(weight_packed, ctx)
            res_arr = tvm.nd.array(res_np, ctx)

            # NumPy reference: accumulate one (BATCH x BLOCK_OUT) tile per
            # (b, i), summing over the reduction blocks j.
            res_ref = np.zeros(res_shape).astype(env.acc_dtype)
            for b in range(batch_size // env.BATCH):
                for i in range(channel // env.BLOCK_OUT):
                    for j in range(channel // env.BLOCK_IN):
                        lhs = data_packed[b, j, :].astype(env.acc_dtype)
                        rhs = weight_packed[i, j].T.astype(env.acc_dtype)
                        res_ref[b, i, :] += np.dot(lhs, rhs)
            res_ref = np.right_shift(res_ref, 8)
            res_ref = np.clip(res_ref, 0, clip_max).astype(res.dtype)

            time_f = f.time_evaluator("gemm", ctx, number=20)
            cost = time_f(data_arr, weight_arr, res_arr)
            res_unpack = res_arr.asnumpy().reshape(
                batch_size // env.BATCH, channel // env.BLOCK_OUT,
                env.BATCH, env.BLOCK_OUT)
            if check_correctness:
                tvm.testing.assert_allclose(res_unpack, res_ref)
            return cost

        def run_schedule(load_inp,
                         load_wgt,
                         gemm,
                         alu,
                         store_out,
                         print_ir,
                         check_correctness):
            """Create the schedule, annotate it and hand it to ``verify``.

            ``load_inp``, ``load_wgt``, ``alu`` and ``store_out`` are the
            pragma values attached to the load, ALU and store stages;
            ``gemm`` is the intrinsic used to tensorize ``res_gem``.
            Returns whatever ``verify`` returns.
            """
            s = tvm.create_schedule(res.op)
            s[data_buf].set_scope(env.inp_scope)
            s[weight_buf].set_scope(env.wgt_scope)
            for stage in (res_gem, res_shf, res_min, res_max):
                s[stage].set_scope(env.acc_scope)

            if block:
                bblock = block // env.BATCH
                iblock = block // env.BLOCK_IN
                oblock = block // env.BLOCK_OUT
                rbo, rco, _, _ = s[res].op.axis
                _, xco1, xb2, _ = s[res].tile(rbo, rco, bblock, oblock)
                store_pt = xb2

                for stage in (res_gem, res_shf, res_min, res_max):
                    s[stage].compute_at(s[res], xco1)

                xbo, xco, xbi, xci = s[res_gem].op.axis
                # Compute one line at a time
                ko1, ko2 = s[res_gem].split(ko, iblock)
                s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki)
                s[data_buf].compute_at(s[res_gem], ko1)
                s[weight_buf].compute_at(s[res_gem], ko1)
            else:
                store_pt = s[res].op.axis[0]
                xbo, xco, xbi, xci = s[res_gem].op.axis
                s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki)

            # Use VTA instructions (identical for both schedules; only the
            # axis carrying the store pragma differs).
            s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
            s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
            s[res_gem].tensorize(xbi, gemm)
            for stage in (res_shf, res_min, res_max):
                s[stage].pragma(s[stage].op.axis[0], alu)
            s[res].pragma(store_pt, store_out)

            if print_ir:
                print(tvm.lower(s, [data, weight, res], simple_mode=True))
            return verify(s, check_correctness)

        def report_cost(header, cost, moved_bits=None):
            """Print the header and the timing line for one measurement.

            ``moved_bits`` is the number of bits transferred by the stage
            under test; when given, a bandwidth figure is appended.
            """
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print(header)
            if moved_bits is None:
                print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
                return
            bandwidth = (moved_bits / cost.mean) / float(10 ** 9)
            print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
                cost.mean, gops, bandwidth))

        def gemm_normal(print_ir):
            """Run every stage for real and check the result."""
            print("----- GEMM GOPS End-to-End Test-------")
            with vta.build_config():
                cost = run_schedule(
                    env.dma_copy, env.dma_copy, env.gemm, env.alu,
                    env.dma_copy, print_ir, True)
                report_cost("NORMAL", cost)

        def gemm_unittest(print_ir):
            """Time the GEMM stage alone; every other stage uses env.mock."""
            mock = env.mock
            print("----- GEMM Unit Test-------")
            with vta.build_config():
                cost = run_schedule(
                    mock.dma_copy, mock.dma_copy, env.gemm, mock.alu,
                    mock.dma_copy, print_ir, False)
                report_cost("NORMAL", cost)

        def alu_unittest(print_ir):
            """Time the ALU stages alone; every other stage uses env.mock."""
            mock = env.mock
            print("----- ALU Unit Test-------")
            with vta.build_config():
                cost = run_schedule(
                    mock.dma_copy, mock.dma_copy, mock.gemm, env.alu,
                    mock.dma_copy, print_ir, False)
                report_cost("NORMAL", cost)
            print("")

        # NOTE(review): the three load/store unit tests below are defined but
        # not invoked at the end of this function -- confirm whether that is
        # intentional before enabling or deleting them.
        def load_inp_unittest(print_ir):
            """Time the input load alone and report its bandwidth."""
            mock = env.mock
            print("----- LoadInp Unit Test-------")
            with vta.build_config():
                cost = run_schedule(
                    env.dma_copy, mock.dma_copy, mock.gemm, mock.alu,
                    mock.dma_copy, print_ir, False)
                report_cost("NORMAL", cost,
                            batch_size * channel * env.INP_WIDTH)
            print("")

        def load_wgt_unittest(print_ir):
            """Time the weight load alone and report its bandwidth."""
            mock = env.mock
            print("----- LoadWgt Unit Test-------")
            with vta.build_config():
                cost = run_schedule(
                    mock.dma_copy, env.dma_copy, mock.gemm, mock.alu,
                    mock.dma_copy, print_ir, False)
                report_cost("NORMAL", cost,
                            channel * channel * env.WGT_WIDTH)
            print("")

        def store_out_unittest(print_ir):
            """Time the output store alone and report its bandwidth."""
            mock = env.mock
            print("----- StoreOut Unit Test-------")
            with vta.build_config():
                cost = run_schedule(
                    mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu,
                    env.dma_copy, print_ir, False)
                report_cost("NORMAL", cost,
                            batch_size * channel * env.OUT_WIDTH)
            print("")

        gemm_normal(False)
        gemm_unittest(False)
        alu_unittest(False)

    def _run(env, remote):
        """Callback given to ``vta.testing.run``: 128x128 GEMM, block 128."""
        print("========GEMM 128=========")
        run_gemm_packed(env, remote, 128, 128, 128)

    vta.testing.run(_run)


if __name__ == "__main__":
    test_gemm()