Commit edac6a8d by Tianqi Chen

Refactor, refactor code structure, fix pynq rpc (#29)

parent 5c5806ba
......@@ -37,10 +37,10 @@ remote = rpc.connect(host, port)
vta.program_fpga(remote, BITSTREAM_FILE)
if verbose:
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)
# Change to -device=vta-cpu to run cpu only inference.
target = "llvm -device=vta"
# Change to -device=vtacpu to run cpu only inference.
target = tvm.target.create("llvm -device=vta")
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
synset = eval(open(os.path.join(CATEG_FILE)).read())
......@@ -109,7 +109,7 @@ dtype = "float32"
sym = vta.graph.remove_stochastic(sym)
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
if "vta" in target:
if target.device_name == "vta":
sym = vta.graph.pack(sym, shape_dict, factor)
graph_attr.set_shape_inputs(sym, shape_dict)
......@@ -118,10 +118,10 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
sym = sym.apply("InferType")
with nnvm.compiler.build_config(opt_level=3):
if "vta" not in target:
if target.device_name != "vta":
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
sym, target_host, shape_dict, dtype_dict,
params=params)
else:
with vta.build_config():
graph, lib, params = nnvm.compiler.build(
......@@ -133,7 +133,7 @@ temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0)
ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
print("Build complete...")
......
......@@ -3,11 +3,12 @@ from __future__ import absolute_import as _abs
from .environment import get_env, Environment
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
try:
from . import top
from .build_module import build_config, lower, build
from . import graph
except (ImportError, RuntimeError):
pass
......@@ -75,20 +75,20 @@ def server_start():
pkg = PkgConfig(cfg, proj_root)
# check if the configuration is already the same
if os.path.isfile(cfg_path):
old_cfg = json.load(open(cfg_path))
old_cfg = json.loads(open(cfg_path, "r").read())
if pkg.same_config(old_cfg):
logging.info("Skip reconfiguration because runtime config is the same")
logging.info("Skip reconfig_runtime due to same config.")
return
cflags += ["-O2", "-std=c++11"]
cflags = ["-O2", "-std=c++11"]
cflags += pkg.cflags
ldflags = pkg.ldflags
lib_name = dll_path
source = env.pkg_config.lib_source
source = pkg.lib_source
logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
dll_path, str(cflags), str(source), str(ldflags))
cc.create_shared(lib_name, source, cflags + ldflags)
with open(cfg_path, "w") as outputfile:
json.dump(pkg.cfg_json, outputfile)
outputfile.write(pkg.cfg_json)
def main():
......
"""TVM TOPI connector, eventually most of these should go to TVM repo"""
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import vta_conv2d
from . import arm_conv2d
......@@ -44,7 +44,7 @@ _SCHEDULES = [
Im2ColPack(7, 4, 1, 16, False),
]
@_get_schedule.register(["tcpu", "vta"])
@_get_schedule.register(["vtacpu", "vta"])
def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
......@@ -53,10 +53,10 @@ def _schedule_conv2d(wkl):
return sch
@conv2d.register(["tcpu", "vta"])
@conv2d.register(["vtacpu", "vta"])
def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
assert layout == 'NCHW', "only support NCHW convolution on tcpu"
assert data.shape[0].value == 1, "only support batch size=1 convolution on tcpu"
assert layout == 'NCHW', "only support NCHW convolution on vtacpu"
assert data.shape[0].value == 1, "only support batch size=1 convolution on vtacpu"
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
......@@ -284,7 +284,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
return s
@generic.schedule_conv2d_nchw.register(["tcpu", "vta"])
@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
def schedule_conv2d(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
......
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
from collections import namedtuple
......@@ -7,7 +8,7 @@ import tvm
import topi
from nnvm.top import registry as reg, OpPattern
from . import environment as vta
from ..environment import get_env
Workload = namedtuple("Conv2DWorkload",
......@@ -219,7 +220,7 @@ def schedule_packed_conv2d(outs):
wrkld = _get_workload(data, pad_data, kernel, output)
plan = _WL2PLAN[wrkld]
env = vta.get_env()
env = get_env()
load_inp = load_wgt = load_out = store_out = env.dma_copy
alu = env.alu
......@@ -251,7 +252,7 @@ def schedule_packed_conv2d(outs):
# tile
oc_factor = (plan.oc_factor if plan.oc_factor
else wrkld.out_filter // vta.BLOCK_OUT)
else plan.out_filter // env.BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3])
......
"""Testing if we can generate code in topi style"""
import topi
import tvm
from tvm.contrib import util, rpc
from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize
import topi
import topi.testing
import vta
from vta import vta_conv2d
import vta.testing
import numpy as np
import mxnet as mx
Workload = vta_conv2d.Workload
Workload = vta.top.vta_conv2d.Workload
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
......@@ -19,14 +20,9 @@ def my_clip(x, a_min, a_max):
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
print_ir = False
def test_vta_conv2d(key, batch_size, wl, profile=True):
env = vta.get_env()
def test_vta_conv2d():
def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
wl.height, wl.width, env.BLOCK_IN)
kernel_shape = (wl.out_filter // env.BLOCK_OUT,
......@@ -42,7 +38,7 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
res_conv = vta_conv2d.packed_conv2d(
res_conv = vta.top.packed_conv2d(
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
res = topi.right_shift(res_conv, 8)
res = topi.broadcast_add(res, bias)
......@@ -51,10 +47,29 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
stride = (wl.hstride, wl.wstride)
data_dtype = data.dtype
acc_dtype = env.acc_dtype
assert wl.hpad == wl.wpad
padding = wl.hpad
@memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
def get_ref_data():
a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype)
a_np = np.abs(a_np)
w_np = np.abs(w_np)
b_np = topi.testing.conv2d_nchw_python(
a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
return a_np, w_np, b_np
def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d")
mod = vta.build(s, [data, kernel, bias, res], "ext_dev",
env.target_host, name="conv2d")
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
......@@ -62,14 +77,8 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
# verify
ctx = remote.ext_dev(0)
# Data in original format
data_orig = (np.random.uniform(
size=(batch_size, wl.in_filter, wl.height, wl.width)) * 4).astype(data.dtype)
kernel_orig = (np.random.uniform(
size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)) * 4).astype(kernel.dtype)
data_orig, kernel_orig, res_ref = get_ref_data()
bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
data_orig = np.abs(data_orig)
kernel_orig = np.abs(kernel_orig)
bias_orig = np.abs(bias_orig)
data_packed = data_orig.reshape(
......@@ -88,40 +97,35 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
kernel_arr = tvm.nd.array(kernel_packed, ctx)
bias_arr = tvm.nd.array(bias_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=10)
time_f = f.time_evaluator("conv2d", ctx, number=5)
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
res_unpack = res_arr.asnumpy().transpose(
(0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness:
res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter,
no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
assert wl.hpad == wl.wpad
stride = (wl.hstride, wl.wstride)
padding = wl.hpad
res_ref = res_ref >> 8
res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
res_ref = np.clip(res_ref, 0, 127).astype("int8")
np.testing.assert_allclose(res_unpack, res_ref)
print("Correctness check pass...")
return cost
def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------")
with vta.build_config():
s = vta_conv2d.schedule_packed_conv2d([res])
s = vta.top.schedule_packed_conv2d([res])
if print_ir:
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
conv_normal(print_ir)
conv_normal(False)
# ResNet18 workloads
resnet = {
def _run(env, remote):
# ResNet18 workloads
resnet = {
# Workloads of resnet18 on imagenet
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
......@@ -135,12 +139,17 @@ resnet = {
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
}
batch_size = 1
for i in range(0, len(resnet)):
batch_size = 1
for i in range(0, len(resnet)):
wl = resnet[i]
key = "resnet-cfg[%d]" % i
print "key=%s" % key
print wl
test_vta_conv2d(key, batch_size, wl)
print("key=%s" % key)
print(wl)
run_vta_conv2d(env, remote, key, batch_size, wl)
vta.testing.run(_run)
if __name__ == "__main__":
test_vta_conv2d()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment