Commit edac6a8d by Tianqi Chen

Refactor, refactor code structure, fix pynq rpc (#29)

parent 5c5806ba
@@ -37,10 +37,10 @@ remote = rpc.connect(host, port)
 vta.program_fpga(remote, BITSTREAM_FILE)
 if verbose:
-    logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.DEBUG)
-# Change to -device=vta-cpu to run cpu only inference.
+# Change to -device=vtacpu to run cpu only inference.
-target = "llvm -device=vta"
+target = tvm.target.create("llvm -device=vta")
 target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
 synset = eval(open(os.path.join(CATEG_FILE)).read())
@@ -109,7 +109,7 @@ dtype = "float32"
 sym = vta.graph.remove_stochastic(sym)
 sym = vta.graph.clean_cast(sym)
 sym = vta.graph.clean_conv_fuse(sym)
-if "vta" in target:
+if target.device_name == "vta":
     sym = vta.graph.pack(sym, shape_dict, factor)
 graph_attr.set_shape_inputs(sym, shape_dict)
@@ -118,10 +118,10 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
 sym = sym.apply("InferType")
 with nnvm.compiler.build_config(opt_level=3):
-    if "vta" not in target:
+    if target.device_name != "vta":
         graph, lib, params = nnvm.compiler.build(
-            sym, target, shape_dict, dtype_dict,
-            params=params, target_host=target_host)
+            sym, target_host, shape_dict, dtype_dict,
+            params=params)
     else:
         with vta.build_config():
             graph, lib, params = nnvm.compiler.build(
@@ -133,7 +133,7 @@ temp = util.tempdir()
 lib.save(temp.relpath("graphlib.o"))
 remote.upload(temp.relpath("graphlib.o"))
 lib = remote.load_module("graphlib.o")
-ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0)
+ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
 print("Build complete...")
...
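A minimal sketch (not part of this commit) of the two target styles involved above, assuming the TVM 0.x API in use here: a raw option string only supports substring checks, while the object returned by tvm.target.create carries a parsed device_name, which is what the device checks in the hunks above rely on.

import tvm

# Raw option string: device checks fall back to substring matching.
target_str = "llvm -device=vta"
print("vta" in target_str)            # True, but would also match e.g. "-device=vtacpu"

# Structured target object: the device is a first-class attribute.
target = tvm.target.create("llvm -device=vta")
print(target.device_name == "vta")    # True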
@@ -3,11 +3,12 @@ from __future__ import absolute_import as _abs
 from .environment import get_env, Environment
-from . import arm_conv2d, vta_conv2d
-from .build_module import build_config, lower, build
 from .rpc_client import reconfig_runtime, program_fpga
 try:
+    from . import top
+    from .build_module import build_config, lower, build
     from . import graph
 except (ImportError, RuntimeError):
     pass
@@ -75,20 +75,20 @@ def server_start():
         pkg = PkgConfig(cfg, proj_root)
         # check if the configuration is already the same
         if os.path.isfile(cfg_path):
-            old_cfg = json.load(open(cfg_path))
+            old_cfg = json.loads(open(cfg_path, "r").read())
             if pkg.same_config(old_cfg):
-                logging.info("Skip reconfiguration because runtime config is the same")
+                logging.info("Skip reconfig_runtime due to same config.")
                 return
-        cflags += ["-O2", "-std=c++11"]
+        cflags = ["-O2", "-std=c++11"]
         cflags += pkg.cflags
         ldflags = pkg.ldflags
         lib_name = dll_path
-        source = env.pkg_config.lib_source
+        source = pkg.lib_source
         logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
                      dll_path, str(cflags), str(source), str(ldflags))
         cc.create_shared(lib_name, source, cflags + ldflags)
         with open(cfg_path, "w") as outputfile:
-            json.dump(pkg.cfg_json, outputfile)
+            outputfile.write(pkg.cfg_json)
 def main():
...
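For illustration only, a standalone sketch of the compare-then-rewrite pattern used by the runtime reconfiguration above, written against the standard library; write_config_if_changed, cfg_path and cfg_json are placeholder names, not part of the VTA API.

import json
import os

def write_config_if_changed(cfg_path, cfg_json):
    """Write cfg_json (a JSON string) to cfg_path unless an equivalent config is already there."""
    if os.path.isfile(cfg_path):
        old_cfg = json.loads(open(cfg_path, "r").read())
        if old_cfg == json.loads(cfg_json):
            return False  # same config on disk, caller can skip the rebuild
    with open(cfg_path, "w") as outputfile:
        outputfile.write(cfg_json)
    return True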
"""TVM TOPI connector, eventually most of these should go to TVM repo"""
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import vta_conv2d
from . import arm_conv2d
@@ -44,7 +44,7 @@ _SCHEDULES = [
     Im2ColPack(7, 4, 1, 16, False),
 ]
-@_get_schedule.register(["tcpu", "vta"])
+@_get_schedule.register(["vtacpu", "vta"])
 def _schedule_conv2d(wkl):
     if wkl not in _WORKLOADS:
         raise ValueError("no schedule for such workload: {}".format(wkl))
@@ -53,10 +53,10 @@ def _schedule_conv2d(wkl):
     return sch
-@conv2d.register(["tcpu", "vta"])
+@conv2d.register(["vtacpu", "vta"])
 def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
-    assert layout == 'NCHW', "only support NCHW convolution on tcpu"
-    assert data.shape[0].value == 1, "only support batch size=1 convolution on tcpu"
+    assert layout == 'NCHW', "only support NCHW convolution on vtacpu"
+    assert data.shape[0].value == 1, "only support batch size=1 convolution on vtacpu"
     wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)
     return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
@@ -284,7 +284,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
     return s
-@generic.schedule_conv2d_nchw.register(["tcpu", "vta"])
+@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
 def schedule_conv2d(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
...
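The registrations above key an implementation by target name. As a rough, self-contained illustration of how such keyed dispatch behaves (hypothetical function, not the TOPI internals), tvm.target.generic_func picks the variant registered under a key carried by the currently active target, otherwise the default runs.

import tvm

@tvm.target.generic_func
def tile_factor():
    return 8                          # generic fallback

@tile_factor.register(["vta", "vtacpu"])
def _tile_factor_vta():
    return 16                         # chosen when the active target matches a registered key

with tvm.target.create("llvm -device=vta"):
    print(tile_factor())              # 16, dispatched via the "vta" key
print(tile_factor())                  # 8, no VTA target in scope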
"""Namespace for supporting packed_conv2d + ewise variant of nnvm.""" """Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
from collections import namedtuple from collections import namedtuple
...@@ -7,7 +8,7 @@ import tvm ...@@ -7,7 +8,7 @@ import tvm
import topi import topi
from nnvm.top import registry as reg, OpPattern from nnvm.top import registry as reg, OpPattern
from . import environment as vta from ..environment import get_env
Workload = namedtuple("Conv2DWorkload", Workload = namedtuple("Conv2DWorkload",
...@@ -219,7 +220,7 @@ def schedule_packed_conv2d(outs): ...@@ -219,7 +220,7 @@ def schedule_packed_conv2d(outs):
wrkld = _get_workload(data, pad_data, kernel, output) wrkld = _get_workload(data, pad_data, kernel, output)
plan = _WL2PLAN[wrkld] plan = _WL2PLAN[wrkld]
env = vta.get_env() env = get_env()
load_inp = load_wgt = load_out = store_out = env.dma_copy load_inp = load_wgt = load_out = store_out = env.dma_copy
alu = env.alu alu = env.alu
...@@ -251,7 +252,7 @@ def schedule_packed_conv2d(outs): ...@@ -251,7 +252,7 @@ def schedule_packed_conv2d(outs):
# tile # tile
oc_factor = (plan.oc_factor if plan.oc_factor oc_factor = (plan.oc_factor if plan.oc_factor
else wrkld.out_filter // vta.BLOCK_OUT) else plan.out_filter // env.BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2]) h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3]) w_factor = (plan.w_factor if plan.w_factor else oshape[3])
......
"""Testing if we can generate code in topi style""" """Testing if we can generate code in topi style"""
import topi
import tvm import tvm
from tvm.contrib import util, rpc from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize
import topi
import topi.testing
import vta import vta
from vta import vta_conv2d import vta.testing
import numpy as np import numpy as np
import mxnet as mx
Workload = vta_conv2d.Workload Workload = vta.top.vta_conv2d.Workload
@tvm.tag_scope(tag=topi.tag.ELEMWISE) @tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max): def my_clip(x, a_min, a_max):
...@@ -19,128 +20,136 @@ def my_clip(x, a_min, a_max): ...@@ -19,128 +20,136 @@ def my_clip(x, a_min, a_max):
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x return x
host = "pynq"
port = 9091 def test_vta_conv2d():
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
print_ir = False data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
wl.height, wl.width, env.BLOCK_IN)
kernel_shape = (wl.out_filter // env.BLOCK_OUT,
def test_vta_conv2d(key, batch_size, wl, profile=True): wl.in_filter // env.BLOCK_IN,
env = vta.get_env() wl.hkernel, wl.wkernel,
data_shape = (batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN)
wl.height, wl.width, env.BLOCK_IN) bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
kernel_shape = (wl.out_filter // env.BLOCK_OUT,
wl.in_filter // env.BLOCK_IN,
wl.hkernel, wl.wkernel, fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
env.BLOCK_OUT, env.BLOCK_IN) fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT) data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 res_conv = vta.top.packed_conv2d(
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) res = topi.right_shift(res_conv, 8)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype) res = topi.broadcast_add(res, bias)
res = my_clip(res, 0, 127)
res_conv = vta_conv2d.packed_conv2d( res = topi.cast(res, "int8")
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
res = topi.right_shift(res_conv, 8) num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
res = topi.broadcast_add(res, bias)
res = my_clip(res, 0, 127) a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
res = topi.cast(res, "int8") w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
stride = (wl.hstride, wl.wstride)
num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter data_dtype = data.dtype
acc_dtype = env.acc_dtype
def verify(s, check_correctness): assert wl.hpad == wl.wpad
mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d") padding = wl.hpad
temp = util.tempdir()
remote = rpc.connect(host, port) @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
def get_ref_data():
mod.save(temp.relpath("conv2d.o")) a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
remote.upload(temp.relpath("conv2d.o")) w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype)
f = remote.load_module("conv2d.o") a_np = np.abs(a_np)
# verify w_np = np.abs(w_np)
ctx = remote.ext_dev(0) b_np = topi.testing.conv2d_nchw_python(
# Data in original format a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
data_orig = (np.random.uniform( return a_np, w_np, b_np
size=(batch_size, wl.in_filter, wl.height, wl.width)) * 4).astype(data.dtype)
kernel_orig = (np.random.uniform(
size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)) * 4).astype(kernel.dtype) def verify(s, check_correctness):
bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32") mod = vta.build(s, [data, kernel, bias, res], "ext_dev",
env.target_host, name="conv2d")
data_orig = np.abs(data_orig) temp = util.tempdir()
kernel_orig = np.abs(kernel_orig)
bias_orig = np.abs(bias_orig) mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
data_packed = data_orig.reshape( f = remote.load_module("conv2d.o")
batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN, # verify
wl.height, wl.width).transpose((0, 1, 3, 4, 2)) ctx = remote.ext_dev(0)
kernel_packed = kernel_orig.reshape( # Data in original format
wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT, data_orig, kernel_orig, res_ref = get_ref_data()
wl.in_filter // env.BLOCK_IN, env.BLOCK_IN, bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) bias_orig = np.abs(bias_orig)
bias_packed = bias_orig.reshape(
wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT) data_packed = data_orig.reshape(
res_shape = topi.util.get_const_tuple(res.shape) batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 1, 3, 4, 2))
res_np = np.zeros(res_shape).astype(res.dtype) kernel_packed = kernel_orig.reshape(
data_arr = tvm.nd.array(data_packed, ctx) wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
kernel_arr = tvm.nd.array(kernel_packed, ctx) wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
bias_arr = tvm.nd.array(bias_packed, ctx) wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
res_arr = tvm.nd.array(res_np, ctx) bias_packed = bias_orig.reshape(
time_f = f.time_evaluator("conv2d", ctx, number=10) wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) res_shape = topi.util.get_const_tuple(res.shape)
res_unpack = res_arr.asnumpy().transpose(
(0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) res_np = np.zeros(res_shape).astype(res.dtype)
if check_correctness: data_arr = tvm.nd.array(data_packed, ctx)
res_ref = mx.nd.Convolution( kernel_arr = tvm.nd.array(kernel_packed, ctx)
mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)), bias_arr = tvm.nd.array(bias_packed, ctx)
mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)), res_arr = tvm.nd.array(res_np, ctx)
stride=(wl.hstride, wl.wstride), time_f = f.time_evaluator("conv2d", ctx, number=5)
kernel=(wl.hkernel, wl.wkernel), cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
num_filter=wl.out_filter, res_unpack = res_arr.asnumpy().transpose(
no_bias=True, (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype) if check_correctness:
res_ref = res_ref >> 8 assert wl.hpad == wl.wpad
res_ref += bias_orig.reshape(wl.out_filter, 1, 1) stride = (wl.hstride, wl.wstride)
res_ref = np.clip(res_ref, 0, 127).astype("int8") padding = wl.hpad
np.testing.assert_allclose(res_unpack, res_ref) res_ref = res_ref >> 8
print("Correctness check pass...") res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
return cost res_ref = np.clip(res_ref, 0, 127).astype("int8")
np.testing.assert_allclose(res_unpack, res_ref)
def conv_normal(print_ir): return cost
print("----- CONV2D End-to-End Test-------")
with vta.build_config(): def conv_normal(print_ir):
s = vta_conv2d.schedule_packed_conv2d([res]) print("----- CONV2D End-to-End Test-------")
if print_ir: with vta.build_config():
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True)) s = vta.top.schedule_packed_conv2d([res])
if print_ir:
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True) cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
conv_normal(print_ir) conv_normal(False)
# ResNet18 workloads def _run(env, remote):
resnet = { # ResNet18 workloads
# Workloads of resnet18 on imagenet resnet = {
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2), # Workloads of resnet18 on imagenet
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), 0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), 1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), 2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), 3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), 4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), 5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), 6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), 7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), 8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), 9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), 10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
} 11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
batch_size = 1
for i in range(0, len(resnet)): batch_size = 1
wl = resnet[i] for i in range(0, len(resnet)):
key = "resnet-cfg[%d]" % i wl = resnet[i]
print "key=%s" % key key = "resnet-cfg[%d]" % i
print wl print("key=%s" % key)
test_vta_conv2d(key, batch_size, wl) print(wl)
run_vta_conv2d(env, remote, key, batch_size, wl)
vta.testing.run(_run)
if __name__ == "__main__":
test_vta_conv2d()