Commit dea167a8 by Tianqi Chen

[COMPILER] Refactor compiler to enable configuration (#21)

parent f5960ba6
...@@ -39,8 +39,9 @@ vta.program_fpga(remote, BITSTREAM_FILE) ...@@ -39,8 +39,9 @@ vta.program_fpga(remote, BITSTREAM_FILE)
if verbose: if verbose:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# Change to -device=tcpu to run cpu only inference. # Change to -device=vta-cpu to run cpu only inference.
target = "llvm -device=vta" target = "llvm -device=vta"
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
synset = eval(open(os.path.join(CATEG_FILE)).read()) synset = eval(open(os.path.join(CATEG_FILE)).read())
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224)) image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
...@@ -117,15 +118,16 @@ graph_attr.set_dtype_inputs(sym, dtype_dict) ...@@ -117,15 +118,16 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
sym = sym.apply("InferType") sym = sym.apply("InferType")
with nnvm.compiler.build_config(opt_level=3): with nnvm.compiler.build_config(opt_level=3):
bdict = {}
if "vta" not in target: if "vta" not in target:
bdict = {"add_lower_pass": []} graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
else: else:
bdict = {"add_lower_pass": vta.debug_mode(0)} with vta.build_config():
with tvm.build_config(**bdict):
graph, lib, params = nnvm.compiler.build( graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict, sym, target, shape_dict, dtype_dict,
params=params) params=params, target_host=target_host)
temp = util.tempdir() temp = util.tempdir()
lib.save(temp.relpath("graphlib.o")) lib.save(temp.relpath("graphlib.o"))
......
"""TVM-based VTA Compiler Toolchain""" """TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .hw_spec import * from .environment import get_env, Environment
try: try:
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU # allow optional import in config mode.
from .intrin import GEVM, GEMM
from .build import debug_mode
from . import mock, ir_pass
from . import arm_conv2d, vta_conv2d from . import arm_conv2d, vta_conv2d
except AttributeError: from .build_module import build_config, lower, build
pass from .rpc_client import reconfig_runtime, program_fpga
from .rpc_client import reconfig_runtime, program_fpga
try:
from . import graph from . import graph
except ImportError: except ImportError:
pass pass
"""Runtime function related hooks""" """VTA specific buildin for runtime."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import build_module
from . runtime import CB_HANDLE
from . import ir_pass from . import ir_pass
from .environment import get_env
def lift_coproc_scope(x): def lift_coproc_scope(x):
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x) x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False) x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x return x
def early_rewrite(stmt): def early_rewrite(stmt):
"""Try to do storage rewrite in early pass."""
try: try:
return tvm.ir_pass.StorageRewrite(stmt) return tvm.ir_pass.StorageRewrite(stmt)
except tvm.TVMError: except tvm.TVMError:
return stmt return stmt
def debug_mode(debug_flag): def build_config(debug_flag=0, **kwargs):
"""Pass to enable vta debug mode. """Build a build config for VTA.
Parameters Parameters
---------- ----------
debug_flag : int debug_flag : int
The dbeug flag to be passed. The dbeug flag to be passed.
kwargs : dict
Additional configurations.
Returns Returns
------- -------
pass_list: list of function build_config: BuildConfig
The pass to set to build_config(add_lower_pass=vta.debug_mode(mode)) The build config that can be used in TVM.
Example
--------
.. code-block:: python
# build a vta module.
with vta.build_config():
vta_module = tvm.build(s, ...)
""" """
env = get_env()
def add_debug(stmt): def add_debug(stmt):
debug = tvm.call_extern( debug = tvm.call_extern(
"int32", "VTASetDebugMode", CB_HANDLE, debug_flag) "int32", "VTASetDebugMode",
env.dev.command_handle,
debug_flag)
return tvm.make.stmt_seq(debug, stmt) return tvm.make.stmt_seq(debug, stmt)
pass_list = [(1, ir_pass.inject_dma_intrin), pass_list = [(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy), (1, ir_pass.inject_skip_copy),
...@@ -48,8 +64,38 @@ def debug_mode(debug_flag): ...@@ -48,8 +64,38 @@ def debug_mode(debug_flag):
pass_list.append((2, ir_pass.inject_alu_intrin)) pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite)) pass_list.append((3, ir_pass.cpu_access_rewrite))
return pass_list return tvm.build_config(add_lower_pass=pass_list, **kwargs)
def lower(*args, **kwargs):
"""Thin wrapper of tvm.lower
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
# Add a lower pass to sync uop See Also
build_module.current_build_config().add_lower_pass = debug_mode(0) --------
tvm.lower : The original TVM's lower function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.lower(*args, **kwargs)
return tvm.lower(*args, **kwargs)
def build(*args, **kwargs):
"""Thin wrapper of tvm.build
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
See Also
--------
tvm.build : The original TVM's build function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.build(*args, **kwargs)
return tvm.build(*args, **kwargs)
"""Configurable VTA Hareware Environment scope."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import os
import copy
try:
# Allow missing import in config mode.
import tvm
from . import intrin
except ImportError:
pass
class DevContext(object):
"""Internal development context
This contains all the non-user facing compiler
internal context that is hold by the Environment.
Parameters
----------
env : Environment
The environment hosting the DevContext
Note
----
This class is introduced so we have a clear separation
of developer related stuffs and user facing attributes.
"""
# Memory id for DMA
MEM_ID_UOP = 0
MEM_ID_WGT = 1
MEM_ID_INP = 2
MEM_ID_ACC = 3
MEM_ID_OUT = 4
# VTA ALU Opcodes
ALU_OPCODE_MIN = 0
ALU_OPCODE_MAX = 1
ALU_OPCODE_ADD = 2
ALU_OPCODE_SHR = 3
# Task queue id (pipeline stage)
QID_LOAD_INP = 1
QID_LOAD_WGT = 1
QID_LOAD_OUT = 2
QID_STORE_OUT = 3
QID_COMPUTE = 2
QID_STORE_INP = 3
def __init__(self, env):
self.vta_axis = tvm.thread_axis("vta")
self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
self.command_handle = tvm.make.Call(
"handle", "tvm_thread_context", [ctx],
tvm.expr.Call.Intrinsic, None, 0)
self.DEBUG_NO_SYNC = False
env._dev_ctx = self
self.gemm = intrin.gemm(env, env.mock_mode)
self.gevm = intrin.gevm(env, env.mock_mode)
def get_task_qid(self, qid):
"""Get transformed queue index."""
return 1 if self.DEBUG_NO_SYNC else qid
class Environment(object):
"""Hareware configuration object.
This object contains all the information
needed for compiling to a specific VTA backend.
Parameters
----------
cfg : dict of str to value.
The configuration parameters.
Example
--------
.. code-block:: python
# the following code reconfigures the environment
# temporarily to attributes specified in new_cfg.json
new_cfg = json.load(json.load(open("new_cfg.json")))
with vta.Environment(new_cfg):
# env works on the new environment
env = vta.get_env()
"""
current = None
cfg_keys = [
"target",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
"LOG_BATCH",
"LOG_BLOCK_IN",
"LOG_BLOCK_OUT",
"LOG_UOP_BUFF_SIZE",
"LOG_INP_BUFF_SIZE",
"LOG_WGT_BUFF_SIZE",
"LOG_ACC_BUFF_SIZE",
]
# constants
MAX_XFER = 1 << 22
# debug flags
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
# memory scopes
inp_scope = "local.inp_buffer"
wgt_scope = "local.wgt_buffer"
acc_scope = "local.acc_buffer"
# initialization function
def __init__(self, cfg):
# Log of input/activation width in bits
self.__dict__.update(cfg)
for key in self.cfg_keys:
if key not in cfg:
raise ValueError("Expect key %s in cfg" % key)
self.LOG_OUT_WIDTH = self.LOG_INP_WIDTH
self.LOG_OUT_BUFF_SIZE = (
self.LOG_ACC_BUFF_SIZE +
self.LOG_OUT_WIDTH -
self.LOG_ACC_WIDTH)
# width
self.INP_WIDTH = 1 << self.LOG_INP_WIDTH
self.WGT_WIDTH = 1 << self.LOG_WGT_WIDTH
self.ACC_WIDTH = 1 << self.LOG_ACC_WIDTH
self.BATCH = 1 << self.LOG_BATCH
self.BLOCK_IN = 1 << self.LOG_BLOCK_IN
self.BLOCK_OUT = 1 << self.LOG_BLOCK_OUT
self.OUT_WIDTH = self.INP_WIDTH
# buffer size
self.UOP_BUFF_SIZE = 1 << self.LOG_UOP_BUFF_SIZE
self.INP_BUFF_SIZE = 1 << self.LOG_INP_BUFF_SIZE
self.WGT_BUFF_SIZE = 1 << self.LOG_WGT_BUFF_SIZE
self.ACC_BUFF_SIZE = 1 << self.LOG_ACC_BUFF_SIZE
self.OUT_BUFF_SIZE = 1 << self.LOG_OUT_BUFF_SIZE
# bytes per buffer
self.INP_ELEM_BITS = (self.BATCH *
self.BLOCK_IN *
self.INP_WIDTH)
self.WGT_ELEM_BITS = (self.BLOCK_OUT *
self.BLOCK_IN *
self.WGT_WIDTH)
self.ACC_ELEM_BITS = (self.BATCH *
self.BLOCK_IN *
self.ACC_WIDTH)
self.INP_ELEM_BYTES = self.INP_ELEM_BITS // 8
self.WGT_ELEM_BYTES = self.WGT_ELEM_BITS // 8
self.ACC_ELEM_BYTES = self.ACC_ELEM_BITS // 8
# dtypes
self.acc_dtype = "int%d" % self.ACC_WIDTH
self.inp_dtype = "int%d" % self.INP_WIDTH
self.wgt_dtype = "int%d" % self.WGT_WIDTH
# lazy cached members
self.mock_mode = False
self._mock_env = None
self._dev_ctx = None
@property
def dev(self):
"""Developer context"""
if self._dev_ctx is None:
self._dev_ctx = DevContext(self)
return self._dev_ctx
@property
def mock(self):
"""A mock version of the Environment
The ALU, dma_copy and intrinsics will be
mocked to be nop.
"""
if self.mock_mode:
return self
if self._mock_env is None:
self._mock_env = copy.copy(self)
self._mock_env._dev_ctx = None
self._mock_env.mock_mode = True
return self._mock_env
@property
def dma_copy(self):
"""DMA copy pragma"""
return ("dma_copy"
if not self.mock_mode
else "skip_dma_copy")
@property
def alu(self):
"""ALU pragma"""
return ("alu"
if not self.mock_mode
else "skip_alu")
@property
def gemm(self):
"""GEMM intrinsic"""
return self.dev.gemm
@property
def gevm(self):
"""GEMM intrinsic"""
return self.dev.gevm
def get_env():
"""Get the current VTA Environment.
Returns
-------
env : Environment
The current environment.
"""
return Environment.current
# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.INP_ELEM_BITS,
max_simd_bits=spec.INP_ELEM_BITS,
max_num_bits=spec.INP_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.WGT_ELEM_BITS,
max_simd_bits=spec.WGT_ELEM_BITS,
max_num_bits=spec.WGT_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_out_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
max_simd_bits=spec.ACC_ELEM_BITS,
max_num_bits=spec.ACC_BUFF_SIZE * 8,
head_address=None)
# TVM related registration
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op):
_ = op
return tvm.call_extern(
"int32", "VTASynchronize",
get_env().dev.command_handle, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush",
get_env().dev.command_handle,
op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
"int32", "VTADepPop",
get_env().dev.command_handle,
op.args[0], op.args[1])
def _init_env():
"""Iniitalize the default global env"""
python_vta_dir = os.path.dirname(__file__)
filename = os.path.join(python_vta_dir, '../../config.mk')
keys = set()
for k in Environment.cfg_keys:
keys.add("VTA_" + k)
if not os.path.isfile(filename):
raise RuntimeError(
"Error: {} not found.make sure you have config.mk in your vta root"
.format(filename))
cfg = {}
with open(filename) as f:
for line in f:
for k in keys:
if k +" =" in line:
val = line.split("=")[1].strip()
cfg[k[4:]] = int(val)
cfg["target"] = "pynq"
return Environment(cfg)
Environment.current = _init_env()
...@@ -9,11 +9,13 @@ import argparse ...@@ -9,11 +9,13 @@ import argparse
import os import os
import ctypes import ctypes
import tvm import tvm
from tvm.contrib import rpc, util, cc from tvm.contrib import rpc, cc
@tvm.register_func("tvm.contrib.rpc.server.start", override=True) @tvm.register_func("tvm.contrib.rpc.server.start", override=True)
def server_start(): def server_start():
"""callback when server starts."""
# pylint: disable=unused-variable
curr_path = os.path.dirname( curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__))) os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath( dll_path = os.path.abspath(
......
"""VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs
import os
python_vta_dir = os.path.dirname(__file__)
print python_vta_dir
filename = os.path.join(python_vta_dir, '../../config.mk')
VTA_PYNQ_BRAM_W = 32
VTA_PYNQ_BRAM_D = 1024
VTA_PYNQ_NUM_BRAM = 124
keys = ["VTA_LOG_INP_WIDTH", "VTA_LOG_WGT_WIDTH", "VTA_LOG_ACC_WIDTH",
"VTA_LOG_BATCH", "VTA_LOG_BLOCK_IN", "VTA_LOG_BLOCK_OUT",
"VTA_LOG_UOP_BUFF_SIZE", "VTA_LOG_INP_BUFF_SIZE",
"VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE"]
params = {}
if os.path.isfile(filename):
with open(filename) as f:
for line in f:
for k in keys:
if k+" =" in line:
val = line.split("=")[1].strip()
params[k] = int(val)
# print params
else:
print "Error: {} not found. Please make sure you have config.mk in your vta root".format(filename)
exit()
# The Constants
VTA_INP_WIDTH = 1 << params["VTA_LOG_INP_WIDTH"]
VTA_WGT_WIDTH = 1 << params["VTA_LOG_WGT_WIDTH"]
VTA_OUT_WIDTH = 1 << params["VTA_LOG_ACC_WIDTH"]
VTA_TARGET = "VTA_PYNQ_TARGET"
# Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1 << params["VTA_LOG_BATCH"]
VTA_BLOCK_IN = 1 << params["VTA_LOG_BLOCK_IN"]
VTA_BLOCK_OUT = 1 << params["VTA_LOG_BLOCK_OUT"]
# log-2 On-chip uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = params["VTA_LOG_UOP_BUFF_SIZE"]
# log-2 On-chip input buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = params["VTA_LOG_INP_BUFF_SIZE"]
# log-2 On-chip wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = params["VTA_LOG_WGT_BUFF_SIZE"]
# log-2 On-chip output buffer size in Bytes
VTA_LOG_OUT_BUFF_SIZE = params["VTA_LOG_ACC_BUFF_SIZE"]
# Uop buffer size
VTA_UOP_BUFF_SIZE = 1 << VTA_LOG_UOP_BUFF_SIZE
# Input buffer size
VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE
# On-chip wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE
# Output buffer size.
VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE
# Number of bytes per buffer
VTA_INP_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_IN*VTA_INP_WIDTH//8)
VTA_WGT_ELEM_BYTES = (VTA_BLOCK_OUT*VTA_BLOCK_IN*VTA_WGT_WIDTH//8)
VTA_OUT_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_OUT*VTA_OUT_WIDTH//8)
# Maximum external buffer size in bytes
VTA_MAX_XFER = 1 << 22
# Number of elements
VTA_INP_BUFF_DEPTH = VTA_INP_BUFF_SIZE//VTA_INP_ELEM_BYTES
VTA_WGT_BUFF_DEPTH = VTA_WGT_BUFF_SIZE//VTA_WGT_ELEM_BYTES
VTA_OUT_BUFF_DEPTH = VTA_OUT_BUFF_SIZE//VTA_OUT_ELEM_BYTES
# Memory id for DMA
VTA_MEM_ID_UOP = 0
VTA_MEM_ID_WGT = 1
VTA_MEM_ID_INP = 2
VTA_MEM_ID_ACC = 3
VTA_MEM_ID_OUT = 4
# VTA ALU Opcodes
VTA_ALU_OPCODE_MIN = 0
VTA_ALU_OPCODE_MAX = 1
VTA_ALU_OPCODE_ADD = 2
VTA_ALU_OPCODE_SHR = 3
# Task queue id (pipeline stage)
VTA_QID_LOAD_INP = 1
VTA_QID_LOAD_WGT = 1
VTA_QID_LOAD_OUT = 2
VTA_QID_STORE_OUT = 3
VTA_QID_COMPUTE = 2
VTA_QID_STORE_INP = 3
# Debug flags
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
...@@ -2,65 +2,46 @@ ...@@ -2,65 +2,46 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import hw_spec as spec
from .runtime import VTA_AXIS, VTA_PUSH_UOP, get_task_qid
from .runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT
# The memory information for the compiler def gevm(env, mock=False):
@tvm.register_func("tvm.info.mem.%s" % SCOPE_INP) """Vector-matrix multiply intrinsic
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_num_bits=spec.VTA_INP_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % SCOPE_WGT) Parameters
def mem_info_wgt_buffer(): ----------
return tvm.make.node("MemoryInfo", env : Environment
unit_bits=spec.VTA_WGT_ELEM_BYTES * 8, The Environment
max_simd_bits=spec.VTA_WGT_ELEM_BYTES * 8,
max_num_bits=spec.VTA_WGT_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % SCOPE_OUT) mock : bool
def mem_info_out_buffer(): Whether create a mock version.
return tvm.make.node("MemoryInfo", """
unit_bits=spec.VTA_OUT_ELEM_BYTES * 8, wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
max_simd_bits=spec.VTA_OUT_ELEM_BYTES * 8, assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
max_num_bits=spec.VTA_OUT_BUFF_SIZE * 8, wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
head_address=None)
def intrin_gevm(mock=False):
"""Vector-matrix multiply intrinsic"""
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH, dtype="int%d" % env.WGT_WIDTH,
name=SCOPE_WGT) name=env.wgt_scope)
inp = tvm.placeholder((wgt_shape[1], ), inp = tvm.placeholder((wgt_shape[1], ),
dtype="int%d" % spec.VTA_INP_WIDTH, dtype="int%d" % env.INP_WIDTH,
name=SCOPE_INP) name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k") k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((wgt_shape[0],), out = tvm.compute((wgt_shape[0],),
lambda i: tvm.sum(inp[k].astype(out_dtype) * lambda i: tvm.sum(inp[k].astype(out_dtype) *
wgt[i, k].astype(out_dtype), wgt[i, k].astype(out_dtype),
axis=[k]), axis=[k]),
name="out") name="out")
wgt_layout = tvm.decl_buffer( wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT, wgt.shape, wgt.dtype, env.wgt_scope,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes) scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer( inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP, inp.shape, inp.dtype, env.inp_scope,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes) scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer( out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT, out.shape, out.dtype, env.acc_scope,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes) scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs): def intrin_func(ins, outs):
"""Vector-matrix multiply intrinsic function""" """Vector-matrix multiply intrinsic function"""
...@@ -69,8 +50,11 @@ def intrin_gevm(mock=False): ...@@ -69,8 +50,11 @@ def intrin_gevm(mock=False):
def instr(index): def instr(index):
"""Generate vector-matrix multiply VTA instruction""" """Generate vector-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) dev = env.dev
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP) irb.scope_attr(dev.vta_axis, "coproc_scope",
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2: if index == 0 or index == 2:
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopPush", "int32", "VTAUopPush",
...@@ -101,45 +85,54 @@ def intrin_gevm(mock=False): ...@@ -101,45 +85,54 @@ def intrin_gevm(mock=False):
out: out_layout}) out: out_layout})
def intrin_gemm(mock=False): def gemm(env, mock=False):
"""Matrix-matrix multiply intrinsic""" """Matrix-matrix multiply intrinsic
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN Parameters
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN) ----------
env : Environment
The Environment
mock : bool
Whether create a mock version.
"""
wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
assert inp_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_IN assert inp_lanes == env.BATCH * env.BLOCK_IN
inp_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_IN) inp_shape = (env.BATCH, env.BLOCK_IN)
assert inp_shape[0] * inp_shape[1] == inp_lanes assert inp_shape[0] * inp_shape[1] == inp_lanes
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
assert out_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_OUT assert out_lanes == env.BATCH * env.BLOCK_OUT
out_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_OUT) out_shape = (env.BATCH, env.BLOCK_OUT)
assert out_shape[0] * out_shape[1] == out_lanes assert out_shape[0] * out_shape[1] == out_lanes
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH, dtype="int%d" % env.WGT_WIDTH,
name=SCOPE_WGT) name=env.wgt_scope)
inp = tvm.placeholder((inp_shape[0], inp_shape[1]), inp = tvm.placeholder((inp_shape[0], inp_shape[1]),
dtype="int%d" % spec.VTA_INP_WIDTH, dtype="int%d" % env.INP_WIDTH,
name=SCOPE_INP) name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k") k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((out_shape[0], out_shape[1]), out = tvm.compute((out_shape[0], out_shape[1]),
lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) * lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) *
wgt[j, k].astype(out_dtype), wgt[j, k].astype(out_dtype),
axis=[k]), axis=[k]),
name="out") name="out")
wgt_layout = tvm.decl_buffer( wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT, wgt.shape, wgt.dtype, env.wgt_scope,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes) scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer( inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP, inp.shape, inp.dtype, env.inp_scope,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes) scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer( out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT, out.shape, out.dtype, env.acc_scope,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes) scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs): def intrin_func(ins, outs):
"""Matrix-matrix multiply intrinsic function""" """Matrix-matrix multiply intrinsic function"""
...@@ -148,8 +141,11 @@ def intrin_gemm(mock=False): ...@@ -148,8 +141,11 @@ def intrin_gemm(mock=False):
def instr(index): def instr(index):
"""Generate matrix-matrix multiply VTA instruction""" """Generate matrix-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) dev = env.dev
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP) irb.scope_attr(dev.vta_axis, "coproc_scope",
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2: if index == 0 or index == 2:
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopPush", "int32", "VTAUopPush",
...@@ -178,6 +174,3 @@ def intrin_gemm(mock=False): ...@@ -178,6 +174,3 @@ def intrin_gemm(mock=False):
binds={inp: inp_layout, binds={inp: inp_layout,
wgt: wgt_layout, wgt: wgt_layout,
out: out_layout}) out: out_layout})
GEMM = intrin_gemm()
GEVM = intrin_gevm()
"""Mock interface for skip part of compute """
from .intrin import intrin_gevm, intrin_gemm
GEMM = intrin_gemm(True)
GEVM = intrin_gevm(True)
DMA_COPY = "skip_dma_copy"
ALU = "skip_alu"
"""VTA RPC client function""" """VTA RPC client function"""
import os import os
from . import hw_spec as spec
from .environment import get_env
def reconfig_runtime(remote): def reconfig_runtime(remote):
"""Reconfigure remote runtime based on current hardware spec. """Reconfigure remote runtime based on current hardware spec.
...@@ -10,6 +11,7 @@ def reconfig_runtime(remote): ...@@ -10,6 +11,7 @@ def reconfig_runtime(remote):
remote : RPCSession remote : RPCSession
The TVM RPC session The TVM RPC session
""" """
env = get_env()
keys = ["VTA_LOG_WGT_WIDTH", keys = ["VTA_LOG_WGT_WIDTH",
"VTA_LOG_INP_WIDTH", "VTA_LOG_INP_WIDTH",
"VTA_LOG_ACC_WIDTH", "VTA_LOG_ACC_WIDTH",
...@@ -22,9 +24,10 @@ def reconfig_runtime(remote): ...@@ -22,9 +24,10 @@ def reconfig_runtime(remote):
"VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"] "VTA_LOG_OUT_BUFF_SIZE"]
cflags = ["-D%s" % spec.VTA_TARGET]
cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
for k in keys: for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(spec, k)))] cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime") freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
freconfig(" ".join(cflags)) freconfig(" ".join(cflags))
......
"""Runtime function related hooks"""
from __future__ import absolute_import as _abs
import tvm
def thread_local_command_buffer():
"""Get thread local command buffer"""
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
return tvm.make.Call(
"handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0)
CB_HANDLE = thread_local_command_buffer()
VTA_AXIS = tvm.thread_axis("vta")
VTA_PUSH_UOP = tvm.make.StringImm("VTAPushGEMMOp")
SCOPE_INP = "local.inp_buffer"
SCOPE_OUT = "local.out_buffer"
SCOPE_WGT = "local.wgt_buffer"
DMA_COPY = "dma_copy"
ALU = "alu"
DEBUG_NO_SYNC = False
def get_task_qid(qid):
"""Get transformed queue index."""
return 1 if DEBUG_NO_SYNC else qid
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op):
return tvm.call_extern(
"int32", "VTASynchronize", CB_HANDLE, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush", CB_HANDLE, op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
"int32", "VTADepPop", CB_HANDLE, op.args[0], op.args[1])
...@@ -7,17 +7,13 @@ import tvm ...@@ -7,17 +7,13 @@ import tvm
import topi import topi
from nnvm.top import registry as reg, OpPattern from nnvm.top import registry as reg, OpPattern
from . import environment as vta
from . import intrin, runtime as vta
from intrin import GEVM
TARGET_BOARD = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
Workload = namedtuple("Conv2DWorkload", Workload = namedtuple("Conv2DWorkload",
['height', 'width', 'in_filter', 'out_filter', ['height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
def packed_conv2d(data, def packed_conv2d(data,
kernel, kernel,
padding, padding,
...@@ -58,10 +54,9 @@ def packed_conv2d(data, ...@@ -58,10 +54,9 @@ def packed_conv2d(data,
def _build(funcs, target, target_host): def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target) tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta": if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", return tvm.build(funcs, target="ext_dev", target_host=target_host)
target_host=TARGET_BOARD) elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vta-cpu":
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "tcpu": return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=TARGET_BOARD)
return tvm.build(funcs, target=target) return tvm.build(funcs, target=target)
...@@ -224,36 +219,39 @@ def schedule_packed_conv2d(outs): ...@@ -224,36 +219,39 @@ def schedule_packed_conv2d(outs):
wrkld = _get_workload(data, pad_data, kernel, output) wrkld = _get_workload(data, pad_data, kernel, output)
plan = _WL2PLAN[wrkld] plan = _WL2PLAN[wrkld]
load_inp = load_wgt = load_out = store_out = "dma_copy" env = vta.get_env()
alu = "alu"
gevm = GEVM load_inp = load_wgt = load_out = store_out = env.dma_copy
alu = env.alu
gevm = env.gevm
# schedule1 # schedule1
oshape = topi.util.get_const_tuple(output.shape) oshape = topi.util.get_const_tuple(output.shape)
s = tvm.create_schedule(output.op) s = tvm.create_schedule(output.op)
# setup pad # setup pad
if pad_data is not None: if pad_data is not None:
cdata = pad_data cdata = pad_data
s[pad_data].set_scope(vta.SCOPE_INP) s[pad_data].set_scope(env.inp_scope)
else: else:
cdata = s.cache_read(data, vta.SCOPE_INP, [conv2d_stage]) cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
ckernel = s.cache_read(kernel, vta.SCOPE_WGT, [conv2d_stage]) ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
s[conv2d_stage].set_scope(vta.SCOPE_OUT) s[conv2d_stage].set_scope(env.acc_scope)
# cache read input # cache read input
cache_read_ewise = [] cache_read_ewise = []
for consumer, tensor in ewise_inputs: for consumer, tensor in ewise_inputs:
cache_read_ewise.append( cache_read_ewise.append(
s.cache_read(tensor, vta.SCOPE_OUT, [consumer])) s.cache_read(tensor, env.acc_scope, [consumer]))
# set ewise scope # set ewise scope
for op in ewise_ops: for op in ewise_ops:
s[op].set_scope(vta.SCOPE_OUT) s[op].set_scope(env.acc_scope)
s[op].pragma(s[op].op.axis[0], alu) s[op].pragma(s[op].op.axis[0], alu)
# tile # tile
oc_factor = (plan.oc_factor if plan.oc_factor oc_factor = (plan.oc_factor if plan.oc_factor
else wrkld.out_filter // vta.VTA_BLOCK_OUT) else wrkld.out_filter // vta.BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2]) h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3]) w_factor = (plan.w_factor if plan.w_factor else oshape[3])
......
...@@ -21,26 +21,26 @@ def my_clip(x, a_min, a_max): ...@@ -21,26 +21,26 @@ def my_clip(x, a_min, a_max):
host = "pynq" host = "pynq"
port = 9091 port = 9091
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
print_ir = False print_ir = False
def test_vta_conv2d(key, batch_size, wl, profile=True): def test_vta_conv2d(key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN, env = vta.get_env()
wl.height, wl.width, vta.VTA_BLOCK_IN) data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN, wl.height, wl.width, env.BLOCK_IN)
wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN) kernel_shape = (wl.out_filter // env.BLOCK_OUT,
bias_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT) wl.in_filter // env.BLOCK_IN,
wl.hkernel, wl.wkernel,
env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype) data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype) kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=out_dtype) bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
res_conv = vta_conv2d.packed_conv2d( res_conv = vta_conv2d.packed_conv2d(
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride)) data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
...@@ -73,14 +73,14 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): ...@@ -73,14 +73,14 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
bias_orig = np.abs(bias_orig) bias_orig = np.abs(bias_orig)
data_packed = data_orig.reshape( data_packed = data_orig.reshape(
batch_size, wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 1, 3, 4, 2)) wl.height, wl.width).transpose((0, 1, 3, 4, 2))
kernel_packed = kernel_orig.reshape( kernel_packed = kernel_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT, wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_packed = bias_orig.reshape( bias_packed = bias_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT) wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
res_shape = topi.util.get_const_tuple(res.shape) res_shape = topi.util.get_const_tuple(res.shape)
res_np = np.zeros(res_shape).astype(res.dtype) res_np = np.zeros(res_shape).astype(res.dtype)
...@@ -94,13 +94,13 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): ...@@ -94,13 +94,13 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
(0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness: if check_correctness:
res_ref = mx.nd.Convolution( res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)), mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)), mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride), stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel), kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter, num_filter=wl.out_filter,
no_bias=True, no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype) pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
res_ref = res_ref >> 8 res_ref = res_ref >> 8
res_ref += bias_orig.reshape(wl.out_filter, 1, 1) res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
res_ref = np.clip(res_ref, 0, 127).astype("int8") res_ref = np.clip(res_ref, 0, 127).astype("int8")
...@@ -110,10 +110,10 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): ...@@ -110,10 +110,10 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
def conv_normal(print_ir): def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------") print("----- CONV2D End-to-End Test-------")
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
s = vta_conv2d.schedule_packed_conv2d([res]) s = vta_conv2d.schedule_packed_conv2d([res])
if print_ir: if print_ir:
print(tvm.lower(s, [data, kernel, bias, res], simple_mode=True)) print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True) cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
......
...@@ -3,14 +3,15 @@ import vta ...@@ -3,14 +3,15 @@ import vta
import os import os
from tvm.contrib import rpc, util from tvm.contrib import rpc, util
env = vta.get_env()
host = "pynq" host = "pynq"
port = 9091 port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf" target = "llvm -target=armv7-none-linux-gnueabihf"
bit = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns.bit".format( bit = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns.bit".format(
vta.VTA_BATCH, vta.VTA_BLOCK_IN, vta.VTA_BLOCK_OUT, env.BATCH, env.BLOCK_IN, env.BLOCK_OUT,
vta.VTA_INP_WIDTH, vta.VTA_WGT_WIDTH, env.INP_WIDTH, env.WGT_WIDTH,
vta.VTA_LOG_UOP_BUFF_SIZE, vta.VTA_LOG_INP_BUFF_SIZE, env.LOG_UOP_BUFF_SIZE, env.LOG_INP_BUFF_SIZE,
vta.VTA_LOG_WGT_BUFF_SIZE, vta.VTA_LOG_OUT_BUFF_SIZE) env.LOG_WGT_BUFF_SIZE, env.LOG_ACC_BUFF_SIZE)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
bitstream = os.path.join(curr_path, "../../../../vta_bitstreams/bitstreams/", bit) bitstream = os.path.join(curr_path, "../../../../vta_bitstreams/bitstreams/", bit)
......
import vta
def test_env():
env = vta.get_env()
mock = env.mock
assert mock.alu == "skip_alu"
if __name__ == "__main__":
test_env()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment