Commit edac6a8d by Tianqi Chen

Refactor, refactor code structure, fix pynq rpc (#29)

parent 5c5806ba
...@@ -37,10 +37,10 @@ remote = rpc.connect(host, port) ...@@ -37,10 +37,10 @@ remote = rpc.connect(host, port)
vta.program_fpga(remote, BITSTREAM_FILE) vta.program_fpga(remote, BITSTREAM_FILE)
if verbose: if verbose:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.DEBUG)
# Change to -device=vta-cpu to run cpu only inference. # Change to -device=vtacpu to run cpu only inference.
target = "llvm -device=vta" target = tvm.target.create("llvm -device=vta")
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon" target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
synset = eval(open(os.path.join(CATEG_FILE)).read()) synset = eval(open(os.path.join(CATEG_FILE)).read())
...@@ -109,7 +109,7 @@ dtype = "float32" ...@@ -109,7 +109,7 @@ dtype = "float32"
sym = vta.graph.remove_stochastic(sym) sym = vta.graph.remove_stochastic(sym)
sym = vta.graph.clean_cast(sym) sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym) sym = vta.graph.clean_conv_fuse(sym)
if "vta" in target: if target.device_name == "vta":
sym = vta.graph.pack(sym, shape_dict, factor) sym = vta.graph.pack(sym, shape_dict, factor)
graph_attr.set_shape_inputs(sym, shape_dict) graph_attr.set_shape_inputs(sym, shape_dict)
...@@ -118,10 +118,10 @@ graph_attr.set_dtype_inputs(sym, dtype_dict) ...@@ -118,10 +118,10 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
sym = sym.apply("InferType") sym = sym.apply("InferType")
with nnvm.compiler.build_config(opt_level=3): with nnvm.compiler.build_config(opt_level=3):
if "vta" not in target: if target.device_name != "vta":
graph, lib, params = nnvm.compiler.build( graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict, sym, target_host, shape_dict, dtype_dict,
params=params, target_host=target_host) params=params)
else: else:
with vta.build_config(): with vta.build_config():
graph, lib, params = nnvm.compiler.build( graph, lib, params = nnvm.compiler.build(
...@@ -133,7 +133,7 @@ temp = util.tempdir() ...@@ -133,7 +133,7 @@ temp = util.tempdir()
lib.save(temp.relpath("graphlib.o")) lib.save(temp.relpath("graphlib.o"))
remote.upload(temp.relpath("graphlib.o")) remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o") lib = remote.load_module("graphlib.o")
ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0) ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
print("Build complete...") print("Build complete...")
......
...@@ -3,11 +3,12 @@ from __future__ import absolute_import as _abs ...@@ -3,11 +3,12 @@ from __future__ import absolute_import as _abs
from .environment import get_env, Environment from .environment import get_env, Environment
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga from .rpc_client import reconfig_runtime, program_fpga
try: try:
from . import top
from .build_module import build_config, lower, build
from . import graph from . import graph
except (ImportError, RuntimeError): except (ImportError, RuntimeError):
pass pass
...@@ -75,20 +75,20 @@ def server_start(): ...@@ -75,20 +75,20 @@ def server_start():
pkg = PkgConfig(cfg, proj_root) pkg = PkgConfig(cfg, proj_root)
# check if the configuration is already the same # check if the configuration is already the same
if os.path.isfile(cfg_path): if os.path.isfile(cfg_path):
old_cfg = json.load(open(cfg_path)) old_cfg = json.loads(open(cfg_path, "r").read())
if pkg.same_config(old_cfg): if pkg.same_config(old_cfg):
logging.info("Skip reconfiguration because runtime config is the same") logging.info("Skip reconfig_runtime due to same config.")
return return
cflags += ["-O2", "-std=c++11"] cflags = ["-O2", "-std=c++11"]
cflags += pkg.cflags cflags += pkg.cflags
ldflags = pkg.ldflags ldflags = pkg.ldflags
lib_name = dll_path lib_name = dll_path
source = env.pkg_config.lib_source source = pkg.lib_source
logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s", logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
dll_path, str(cflags), str(source), str(ldflags)) dll_path, str(cflags), str(source), str(ldflags))
cc.create_shared(lib_name, source, cflags + ldflags) cc.create_shared(lib_name, source, cflags + ldflags)
with open(cfg_path, "w") as outputfile: with open(cfg_path, "w") as outputfile:
json.dump(pkg.cfg_json, outputfile) outputfile.write(pkg.cfg_json)
def main(): def main():
......
"""TVM TOPI connector, eventually most of these should go to TVM repo"""
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import vta_conv2d
from . import arm_conv2d
...@@ -44,7 +44,7 @@ _SCHEDULES = [ ...@@ -44,7 +44,7 @@ _SCHEDULES = [
Im2ColPack(7, 4, 1, 16, False), Im2ColPack(7, 4, 1, 16, False),
] ]
@_get_schedule.register(["tcpu", "vta"]) @_get_schedule.register(["vtacpu", "vta"])
def _schedule_conv2d(wkl): def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS: if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl)) raise ValueError("no schedule for such workload: {}".format(wkl))
...@@ -53,10 +53,10 @@ def _schedule_conv2d(wkl): ...@@ -53,10 +53,10 @@ def _schedule_conv2d(wkl):
return sch return sch
@conv2d.register(["tcpu", "vta"]) @conv2d.register(["vtacpu", "vta"])
def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype): def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
assert layout == 'NCHW', "only support NCHW convolution on tcpu" assert layout == 'NCHW', "only support NCHW convolution on vtacpu"
assert data.shape[0].value == 1, "only support batch size=1 convolution on tcpu" assert data.shape[0].value == 1, "only support batch size=1 convolution on vtacpu"
wkl = _get_workload(data, kernel, stride, padding, out_dtype) wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl) sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype) return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
...@@ -284,7 +284,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, ...@@ -284,7 +284,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
return s return s
@generic.schedule_conv2d_nchw.register(["tcpu", "vta"]) @generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
def schedule_conv2d(outs): def schedule_conv2d(outs):
"""Create schedule for tensors""" """Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
......
"""Namespace for supporting packed_conv2d + ewise variant of nnvm.""" """Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
from collections import namedtuple from collections import namedtuple
...@@ -7,7 +8,7 @@ import tvm ...@@ -7,7 +8,7 @@ import tvm
import topi import topi
from nnvm.top import registry as reg, OpPattern from nnvm.top import registry as reg, OpPattern
from . import environment as vta from ..environment import get_env
Workload = namedtuple("Conv2DWorkload", Workload = namedtuple("Conv2DWorkload",
...@@ -219,7 +220,7 @@ def schedule_packed_conv2d(outs): ...@@ -219,7 +220,7 @@ def schedule_packed_conv2d(outs):
wrkld = _get_workload(data, pad_data, kernel, output) wrkld = _get_workload(data, pad_data, kernel, output)
plan = _WL2PLAN[wrkld] plan = _WL2PLAN[wrkld]
env = vta.get_env() env = get_env()
load_inp = load_wgt = load_out = store_out = env.dma_copy load_inp = load_wgt = load_out = store_out = env.dma_copy
alu = env.alu alu = env.alu
...@@ -251,7 +252,7 @@ def schedule_packed_conv2d(outs): ...@@ -251,7 +252,7 @@ def schedule_packed_conv2d(outs):
# tile # tile
oc_factor = (plan.oc_factor if plan.oc_factor oc_factor = (plan.oc_factor if plan.oc_factor
else wrkld.out_filter // vta.BLOCK_OUT) else plan.out_filter // env.BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2]) h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3]) w_factor = (plan.w_factor if plan.w_factor else oshape[3])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment