Commit a96a4a9b by Thierry Moreau Committed by Tianqi Chen

[TOPI] Automated schedule in conv2d TOPI lib, moving to GEMM intrinsic (#35)

* removing programming out of end to end example for now

* updating TOPI library to use gemm tensor intrinsic

* bug fix, autoschedule in TOPI conv lib

* removing the deprecated GEVM intrinsic

* refactoring, fixed lint test

* fix for integer division bug

* python3 bug fix for non matching types due to float division

* comment
parent dae77cdb
......@@ -31,11 +31,6 @@ for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITST
print ("Downloading {}".format(file))
wget.download(url+file)
# Program the FPGA remotely
assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port)
vta.program_fpga(remote, BITSTREAM_FILE)
if verbose:
logging.basicConfig(level=logging.DEBUG)
......@@ -129,8 +124,10 @@ with nnvm.compiler.build_config(opt_level=3):
params=params, target_host=target_host)
assert tvm.module.enabled("rpc")
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
remote = rpc.connect(host, port)
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
......
......@@ -55,7 +55,6 @@ class DevContext(object):
self.DEBUG_NO_SYNC = False
env._dev_ctx = self
self.gemm = intrin.gemm(env, env.mock_mode)
self.gevm = intrin.gevm(env, env.mock_mode)
def get_task_qid(self, qid):
"""Get transformed queue index."""
......@@ -205,11 +204,6 @@ class Environment(object):
return self.dev.gemm
@property
def gevm(self):
"""GEVM intrinsic"""
return self.dev.gevm
@property
def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
......
......@@ -3,88 +3,6 @@ from __future__ import absolute_import as _abs
import tvm
def gevm(env, mock=False):
"""Vector-matrix multiply intrinsic
Parameters
----------
env : Environment
The Environment
mock : bool
Whether create a mock version.
"""
wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % env.WGT_WIDTH,
name=env.wgt_scope)
inp = tvm.placeholder((wgt_shape[1], ),
dtype="int%d" % env.INP_WIDTH,
name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((wgt_shape[0],),
lambda i: tvm.sum(inp[k].astype(out_dtype) *
wgt[i, k].astype(out_dtype),
axis=[k]),
name="out")
wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, env.wgt_scope,
scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, env.inp_scope,
scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer(
out.shape, out.dtype, env.acc_scope,
scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs):
"""Vector-matrix multiply intrinsic function"""
dinp, dwgt = ins
dout = outs[0]
def instr(index):
"""Generate vector-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create()
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope",
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0, 0))
else:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 1,
dout.access_ptr("rw", "int32"),
0,
0,
0, 0, 0))
return irb.get()
# return a triple of normal-set, reset, update
nop = tvm.make.Evaluate(0)
if mock:
return (nop, nop, nop)
return (instr(0), instr(1), instr(2))
return tvm.decl_tensor_intrin(out.op, intrin_func,
name="GEVM",
binds={inp: inp_layout,
wgt: wgt_layout,
out: out_layout})
def gemm(env, mock=False):
"""Matrix-matrix multiply intrinsic
......
......@@ -315,7 +315,7 @@ VTAGenericInsn getGEMMInsn(int uop_offset, int batch, int in_feat, int out_feat,
int push_next_dep) {
// Converter
union VTAInsn converter;
// GEVM instruction initialization
// GEMM instruction initialization
VTAGemInsn insn;
insn.opcode = VTA_OPCODE_GEMM;
insn.pop_prev_dep = pop_prev_dep;
......@@ -394,7 +394,7 @@ VTAGenericInsn getALUInsn(int opcode, int vector_size, bool use_imm, int imm, bo
VTAGenericInsn getFinishInsn(bool pop_prev, bool pop_next) {
// Converter
union VTAInsn converter;
// GEVM instruction initialization
// GEMM instruction initialization
VTAGemInsn insn;
insn.opcode = VTA_OPCODE_FINISH;
insn.pop_prev_dep = pop_prev;
......@@ -649,7 +649,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) {
}
} else if (c.mem.opcode == VTA_OPCODE_GEMM) {
// Print instruction field information
printf("GEVM\n");
printf("GEMM\n");
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
static_cast<int>(c.mem.pop_prev_dep),
static_cast<int>(c.mem.pop_next_dep),
......
......@@ -168,7 +168,7 @@ def test_gemm():
with vta.build_config():
run_test("NORMAL", print_ir, True)
def gevm_unittest(print_ir):
def gemm_unittest(print_ir):
mock = env.mock
print("----- GEMM Unit Test-------")
def run_test(header, print_ir):
......@@ -244,7 +244,7 @@ def test_gemm():
gemm_normal(False)
gevm_unittest(False)
gemm_unittest(False)
alu_unittest(False)
def _run(env, remote):
......
......@@ -23,14 +23,11 @@ def my_clip(x, a_min, a_max):
def test_vta_conv2d():
def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
wl.height, wl.width, env.BLOCK_IN)
kernel_shape = (wl.out_filter // env.BLOCK_OUT,
wl.in_filter // env.BLOCK_IN,
wl.hkernel, wl.wkernel,
env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN,
wl.height, wl.width, env.BATCH, env.BLOCK_IN)
kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (1, wl.out_filter//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
......@@ -45,12 +42,13 @@ def test_vta_conv2d():
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
stride = (wl.hstride, wl.wstride)
data_dtype = data.dtype
kernel_dtype = kernel.dtype
acc_dtype = env.acc_dtype
assert wl.hpad == wl.wpad
padding = wl.hpad
......@@ -58,7 +56,7 @@ def test_vta_conv2d():
@memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
def get_ref_data():
a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype)
w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype)
a_np = np.abs(a_np)
w_np = np.abs(w_np)
b_np = topi.testing.conv2d_nchw_python(
......@@ -82,14 +80,15 @@ def test_vta_conv2d():
bias_orig = np.abs(bias_orig)
data_packed = data_orig.reshape(
batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 1, 3, 4, 2))
batch_size//env.BATCH, env.BATCH,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
kernel_packed = kernel_orig.reshape(
wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_packed = bias_orig.reshape(
wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
1, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
res_shape = topi.util.get_const_tuple(res.shape)
res_np = np.zeros(res_shape).astype(res.dtype)
......@@ -100,7 +99,7 @@ def test_vta_conv2d():
time_f = f.time_evaluator("conv2d", ctx, number=5)
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
res_unpack = res_arr.asnumpy().transpose(
(0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
(0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness:
assert wl.hpad == wl.wpad
stride = (wl.hstride, wl.wstride)
......@@ -127,18 +126,18 @@ def test_vta_conv2d():
# ResNet18 workloads
resnet = {
# Workloads of resnet18 on imagenet
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
batch_size = 1
......@@ -148,6 +147,7 @@ def test_vta_conv2d():
print("key=%s" % key)
print(wl)
run_vta_conv2d(env, remote, key, batch_size, wl)
vta.testing.run(_run)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment