Commit dea167a8 by Tianqi Chen

[COMPILER] Refactor compiler to enable configuration (#21)

parent f5960ba6
...@@ -39,8 +39,9 @@ vta.program_fpga(remote, BITSTREAM_FILE) ...@@ -39,8 +39,9 @@ vta.program_fpga(remote, BITSTREAM_FILE)
if verbose: if verbose:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# Change to -device=tcpu to run cpu only inference. # Change to -device=vta-cpu to run cpu only inference.
target = "llvm -device=vta" target = "llvm -device=vta"
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
synset = eval(open(os.path.join(CATEG_FILE)).read()) synset = eval(open(os.path.join(CATEG_FILE)).read())
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224)) image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
...@@ -117,15 +118,16 @@ graph_attr.set_dtype_inputs(sym, dtype_dict) ...@@ -117,15 +118,16 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
sym = sym.apply("InferType") sym = sym.apply("InferType")
with nnvm.compiler.build_config(opt_level=3): with nnvm.compiler.build_config(opt_level=3):
bdict = {}
if "vta" not in target: if "vta" not in target:
bdict = {"add_lower_pass": []} graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
else: else:
bdict = {"add_lower_pass": vta.debug_mode(0)} with vta.build_config():
with tvm.build_config(**bdict):
graph, lib, params = nnvm.compiler.build( graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict, sym, target, shape_dict, dtype_dict,
params=params) params=params, target_host=target_host)
temp = util.tempdir() temp = util.tempdir()
lib.save(temp.relpath("graphlib.o")) lib.save(temp.relpath("graphlib.o"))
......
"""TVM-based VTA Compiler Toolchain""" """TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .hw_spec import * from .environment import get_env, Environment
try: try:
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU # allow optional import in config mode.
from .intrin import GEVM, GEMM
from .build import debug_mode
from . import mock, ir_pass
from . import arm_conv2d, vta_conv2d from . import arm_conv2d, vta_conv2d
except AttributeError: from .build_module import build_config, lower, build
pass from .rpc_client import reconfig_runtime, program_fpga
from .rpc_client import reconfig_runtime, program_fpga
try:
from . import graph from . import graph
except ImportError: except ImportError:
pass pass
"""Runtime function related hooks""" """VTA specific buildin for runtime."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import build_module
from . runtime import CB_HANDLE
from . import ir_pass from . import ir_pass
from .environment import get_env
def lift_coproc_scope(x): def lift_coproc_scope(x):
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x) x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False) x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x return x
def early_rewrite(stmt): def early_rewrite(stmt):
"""Try to do storage rewrite in early pass."""
try: try:
return tvm.ir_pass.StorageRewrite(stmt) return tvm.ir_pass.StorageRewrite(stmt)
except tvm.TVMError: except tvm.TVMError:
return stmt return stmt
def debug_mode(debug_flag): def build_config(debug_flag=0, **kwargs):
"""Pass to enable vta debug mode. """Build a build config for VTA.
Parameters Parameters
---------- ----------
debug_flag : int debug_flag : int
The dbeug flag to be passed. The dbeug flag to be passed.
kwargs : dict
Additional configurations.
Returns Returns
------- -------
pass_list: list of function build_config: BuildConfig
The pass to set to build_config(add_lower_pass=vta.debug_mode(mode)) The build config that can be used in TVM.
Example
--------
.. code-block:: python
# build a vta module.
with vta.build_config():
vta_module = tvm.build(s, ...)
""" """
env = get_env()
def add_debug(stmt): def add_debug(stmt):
debug = tvm.call_extern( debug = tvm.call_extern(
"int32", "VTASetDebugMode", CB_HANDLE, debug_flag) "int32", "VTASetDebugMode",
env.dev.command_handle,
debug_flag)
return tvm.make.stmt_seq(debug, stmt) return tvm.make.stmt_seq(debug, stmt)
pass_list = [(1, ir_pass.inject_dma_intrin), pass_list = [(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy), (1, ir_pass.inject_skip_copy),
...@@ -48,8 +64,38 @@ def debug_mode(debug_flag): ...@@ -48,8 +64,38 @@ def debug_mode(debug_flag):
pass_list.append((2, ir_pass.inject_alu_intrin)) pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite)) pass_list.append((3, ir_pass.cpu_access_rewrite))
return pass_list return tvm.build_config(add_lower_pass=pass_list, **kwargs)
def lower(*args, **kwargs):
"""Thin wrapper of tvm.lower
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
# Add a lower pass to sync uop See Also
build_module.current_build_config().add_lower_pass = debug_mode(0) --------
tvm.lower : The original TVM's lower function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.lower(*args, **kwargs)
return tvm.lower(*args, **kwargs)
def build(*args, **kwargs):
"""Thin wrapper of tvm.build
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
See Also
--------
tvm.build : The original TVM's build function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.build(*args, **kwargs)
return tvm.build(*args, **kwargs)
"""Configurable VTA Hareware Environment scope."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import os
import copy
try:
# Allow missing import in config mode.
import tvm
from . import intrin
except ImportError:
pass
class DevContext(object):
"""Internal development context
This contains all the non-user facing compiler
internal context that is hold by the Environment.
Parameters
----------
env : Environment
The environment hosting the DevContext
Note
----
This class is introduced so we have a clear separation
of developer related stuffs and user facing attributes.
"""
# Memory id for DMA
MEM_ID_UOP = 0
MEM_ID_WGT = 1
MEM_ID_INP = 2
MEM_ID_ACC = 3
MEM_ID_OUT = 4
# VTA ALU Opcodes
ALU_OPCODE_MIN = 0
ALU_OPCODE_MAX = 1
ALU_OPCODE_ADD = 2
ALU_OPCODE_SHR = 3
# Task queue id (pipeline stage)
QID_LOAD_INP = 1
QID_LOAD_WGT = 1
QID_LOAD_OUT = 2
QID_STORE_OUT = 3
QID_COMPUTE = 2
QID_STORE_INP = 3
def __init__(self, env):
self.vta_axis = tvm.thread_axis("vta")
self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
self.command_handle = tvm.make.Call(
"handle", "tvm_thread_context", [ctx],
tvm.expr.Call.Intrinsic, None, 0)
self.DEBUG_NO_SYNC = False
env._dev_ctx = self
self.gemm = intrin.gemm(env, env.mock_mode)
self.gevm = intrin.gevm(env, env.mock_mode)
def get_task_qid(self, qid):
"""Get transformed queue index."""
return 1 if self.DEBUG_NO_SYNC else qid
class Environment(object):
"""Hareware configuration object.
This object contains all the information
needed for compiling to a specific VTA backend.
Parameters
----------
cfg : dict of str to value.
The configuration parameters.
Example
--------
.. code-block:: python
# the following code reconfigures the environment
# temporarily to attributes specified in new_cfg.json
new_cfg = json.load(json.load(open("new_cfg.json")))
with vta.Environment(new_cfg):
# env works on the new environment
env = vta.get_env()
"""
current = None
cfg_keys = [
"target",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
"LOG_BATCH",
"LOG_BLOCK_IN",
"LOG_BLOCK_OUT",
"LOG_UOP_BUFF_SIZE",
"LOG_INP_BUFF_SIZE",
"LOG_WGT_BUFF_SIZE",
"LOG_ACC_BUFF_SIZE",
]
# constants
MAX_XFER = 1 << 22
# debug flags
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
# memory scopes
inp_scope = "local.inp_buffer"
wgt_scope = "local.wgt_buffer"
acc_scope = "local.acc_buffer"
# initialization function
def __init__(self, cfg):
# Log of input/activation width in bits
self.__dict__.update(cfg)
for key in self.cfg_keys:
if key not in cfg:
raise ValueError("Expect key %s in cfg" % key)
self.LOG_OUT_WIDTH = self.LOG_INP_WIDTH
self.LOG_OUT_BUFF_SIZE = (
self.LOG_ACC_BUFF_SIZE +
self.LOG_OUT_WIDTH -
self.LOG_ACC_WIDTH)
# width
self.INP_WIDTH = 1 << self.LOG_INP_WIDTH
self.WGT_WIDTH = 1 << self.LOG_WGT_WIDTH
self.ACC_WIDTH = 1 << self.LOG_ACC_WIDTH
self.BATCH = 1 << self.LOG_BATCH
self.BLOCK_IN = 1 << self.LOG_BLOCK_IN
self.BLOCK_OUT = 1 << self.LOG_BLOCK_OUT
self.OUT_WIDTH = self.INP_WIDTH
# buffer size
self.UOP_BUFF_SIZE = 1 << self.LOG_UOP_BUFF_SIZE
self.INP_BUFF_SIZE = 1 << self.LOG_INP_BUFF_SIZE
self.WGT_BUFF_SIZE = 1 << self.LOG_WGT_BUFF_SIZE
self.ACC_BUFF_SIZE = 1 << self.LOG_ACC_BUFF_SIZE
self.OUT_BUFF_SIZE = 1 << self.LOG_OUT_BUFF_SIZE
# bytes per buffer
self.INP_ELEM_BITS = (self.BATCH *
self.BLOCK_IN *
self.INP_WIDTH)
self.WGT_ELEM_BITS = (self.BLOCK_OUT *
self.BLOCK_IN *
self.WGT_WIDTH)
self.ACC_ELEM_BITS = (self.BATCH *
self.BLOCK_IN *
self.ACC_WIDTH)
self.INP_ELEM_BYTES = self.INP_ELEM_BITS // 8
self.WGT_ELEM_BYTES = self.WGT_ELEM_BITS // 8
self.ACC_ELEM_BYTES = self.ACC_ELEM_BITS // 8
# dtypes
self.acc_dtype = "int%d" % self.ACC_WIDTH
self.inp_dtype = "int%d" % self.INP_WIDTH
self.wgt_dtype = "int%d" % self.WGT_WIDTH
# lazy cached members
self.mock_mode = False
self._mock_env = None
self._dev_ctx = None
@property
def dev(self):
"""Developer context"""
if self._dev_ctx is None:
self._dev_ctx = DevContext(self)
return self._dev_ctx
@property
def mock(self):
"""A mock version of the Environment
The ALU, dma_copy and intrinsics will be
mocked to be nop.
"""
if self.mock_mode:
return self
if self._mock_env is None:
self._mock_env = copy.copy(self)
self._mock_env._dev_ctx = None
self._mock_env.mock_mode = True
return self._mock_env
@property
def dma_copy(self):
"""DMA copy pragma"""
return ("dma_copy"
if not self.mock_mode
else "skip_dma_copy")
@property
def alu(self):
"""ALU pragma"""
return ("alu"
if not self.mock_mode
else "skip_alu")
@property
def gemm(self):
"""GEMM intrinsic"""
return self.dev.gemm
@property
def gevm(self):
"""GEMM intrinsic"""
return self.dev.gevm
def get_env():
"""Get the current VTA Environment.
Returns
-------
env : Environment
The current environment.
"""
return Environment.current
# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.INP_ELEM_BITS,
max_simd_bits=spec.INP_ELEM_BITS,
max_num_bits=spec.INP_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.WGT_ELEM_BITS,
max_simd_bits=spec.WGT_ELEM_BITS,
max_num_bits=spec.WGT_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_out_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
max_simd_bits=spec.ACC_ELEM_BITS,
max_num_bits=spec.ACC_BUFF_SIZE * 8,
head_address=None)
# TVM related registration
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op):
_ = op
return tvm.call_extern(
"int32", "VTASynchronize",
get_env().dev.command_handle, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush",
get_env().dev.command_handle,
op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
"int32", "VTADepPop",
get_env().dev.command_handle,
op.args[0], op.args[1])
def _init_env():
"""Iniitalize the default global env"""
python_vta_dir = os.path.dirname(__file__)
filename = os.path.join(python_vta_dir, '../../config.mk')
keys = set()
for k in Environment.cfg_keys:
keys.add("VTA_" + k)
if not os.path.isfile(filename):
raise RuntimeError(
"Error: {} not found.make sure you have config.mk in your vta root"
.format(filename))
cfg = {}
with open(filename) as f:
for line in f:
for k in keys:
if k +" =" in line:
val = line.split("=")[1].strip()
cfg[k[4:]] = int(val)
cfg["target"] = "pynq"
return Environment(cfg)
Environment.current = _init_env()
...@@ -9,11 +9,13 @@ import argparse ...@@ -9,11 +9,13 @@ import argparse
import os import os
import ctypes import ctypes
import tvm import tvm
from tvm.contrib import rpc, util, cc from tvm.contrib import rpc, cc
@tvm.register_func("tvm.contrib.rpc.server.start", override=True) @tvm.register_func("tvm.contrib.rpc.server.start", override=True)
def server_start(): def server_start():
"""callback when server starts."""
# pylint: disable=unused-variable
curr_path = os.path.dirname( curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__))) os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath( dll_path = os.path.abspath(
......
"""VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs
import os
python_vta_dir = os.path.dirname(__file__)
print python_vta_dir
filename = os.path.join(python_vta_dir, '../../config.mk')
VTA_PYNQ_BRAM_W = 32
VTA_PYNQ_BRAM_D = 1024
VTA_PYNQ_NUM_BRAM = 124
keys = ["VTA_LOG_INP_WIDTH", "VTA_LOG_WGT_WIDTH", "VTA_LOG_ACC_WIDTH",
"VTA_LOG_BATCH", "VTA_LOG_BLOCK_IN", "VTA_LOG_BLOCK_OUT",
"VTA_LOG_UOP_BUFF_SIZE", "VTA_LOG_INP_BUFF_SIZE",
"VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE"]
params = {}
if os.path.isfile(filename):
with open(filename) as f:
for line in f:
for k in keys:
if k+" =" in line:
val = line.split("=")[1].strip()
params[k] = int(val)
# print params
else:
print "Error: {} not found. Please make sure you have config.mk in your vta root".format(filename)
exit()
# The Constants
VTA_INP_WIDTH = 1 << params["VTA_LOG_INP_WIDTH"]
VTA_WGT_WIDTH = 1 << params["VTA_LOG_WGT_WIDTH"]
VTA_OUT_WIDTH = 1 << params["VTA_LOG_ACC_WIDTH"]
VTA_TARGET = "VTA_PYNQ_TARGET"
# Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1 << params["VTA_LOG_BATCH"]
VTA_BLOCK_IN = 1 << params["VTA_LOG_BLOCK_IN"]
VTA_BLOCK_OUT = 1 << params["VTA_LOG_BLOCK_OUT"]
# log-2 On-chip uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = params["VTA_LOG_UOP_BUFF_SIZE"]
# log-2 On-chip input buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = params["VTA_LOG_INP_BUFF_SIZE"]
# log-2 On-chip wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = params["VTA_LOG_WGT_BUFF_SIZE"]
# log-2 On-chip output buffer size in Bytes
VTA_LOG_OUT_BUFF_SIZE = params["VTA_LOG_ACC_BUFF_SIZE"]
# Uop buffer size
VTA_UOP_BUFF_SIZE = 1 << VTA_LOG_UOP_BUFF_SIZE
# Input buffer size
VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE
# On-chip wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE
# Output buffer size.
VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE
# Number of bytes per buffer
VTA_INP_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_IN*VTA_INP_WIDTH//8)
VTA_WGT_ELEM_BYTES = (VTA_BLOCK_OUT*VTA_BLOCK_IN*VTA_WGT_WIDTH//8)
VTA_OUT_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_OUT*VTA_OUT_WIDTH//8)
# Maximum external buffer size in bytes
VTA_MAX_XFER = 1 << 22
# Number of elements
VTA_INP_BUFF_DEPTH = VTA_INP_BUFF_SIZE//VTA_INP_ELEM_BYTES
VTA_WGT_BUFF_DEPTH = VTA_WGT_BUFF_SIZE//VTA_WGT_ELEM_BYTES
VTA_OUT_BUFF_DEPTH = VTA_OUT_BUFF_SIZE//VTA_OUT_ELEM_BYTES
# Memory id for DMA
VTA_MEM_ID_UOP = 0
VTA_MEM_ID_WGT = 1
VTA_MEM_ID_INP = 2
VTA_MEM_ID_ACC = 3
VTA_MEM_ID_OUT = 4
# VTA ALU Opcodes
VTA_ALU_OPCODE_MIN = 0
VTA_ALU_OPCODE_MAX = 1
VTA_ALU_OPCODE_ADD = 2
VTA_ALU_OPCODE_SHR = 3
# Task queue id (pipeline stage)
VTA_QID_LOAD_INP = 1
VTA_QID_LOAD_WGT = 1
VTA_QID_LOAD_OUT = 2
VTA_QID_STORE_OUT = 3
VTA_QID_COMPUTE = 2
VTA_QID_STORE_INP = 3
# Debug flags
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
...@@ -2,65 +2,46 @@ ...@@ -2,65 +2,46 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import hw_spec as spec
from .runtime import VTA_AXIS, VTA_PUSH_UOP, get_task_qid
from .runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT
# The memory information for the compiler def gevm(env, mock=False):
@tvm.register_func("tvm.info.mem.%s" % SCOPE_INP) """Vector-matrix multiply intrinsic
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_num_bits=spec.VTA_INP_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % SCOPE_WGT) Parameters
def mem_info_wgt_buffer(): ----------
return tvm.make.node("MemoryInfo", env : Environment
unit_bits=spec.VTA_WGT_ELEM_BYTES * 8, The Environment
max_simd_bits=spec.VTA_WGT_ELEM_BYTES * 8,
max_num_bits=spec.VTA_WGT_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % SCOPE_OUT) mock : bool
def mem_info_out_buffer(): Whether create a mock version.
return tvm.make.node("MemoryInfo", """
unit_bits=spec.VTA_OUT_ELEM_BYTES * 8, wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
max_simd_bits=spec.VTA_OUT_ELEM_BYTES * 8, assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
max_num_bits=spec.VTA_OUT_BUFF_SIZE * 8, wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
head_address=None)
def intrin_gevm(mock=False):
"""Vector-matrix multiply intrinsic"""
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH, dtype="int%d" % env.WGT_WIDTH,
name=SCOPE_WGT) name=env.wgt_scope)
inp = tvm.placeholder((wgt_shape[1], ), inp = tvm.placeholder((wgt_shape[1], ),
dtype="int%d" % spec.VTA_INP_WIDTH, dtype="int%d" % env.INP_WIDTH,
name=SCOPE_INP) name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k") k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((wgt_shape[0],), out = tvm.compute((wgt_shape[0],),
lambda i: tvm.sum(inp[k].astype(out_dtype) * lambda i: tvm.sum(inp[k].astype(out_dtype) *
wgt[i, k].astype(out_dtype), wgt[i, k].astype(out_dtype),
axis=[k]), axis=[k]),
name="out") name="out")
wgt_layout = tvm.decl_buffer( wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT, wgt.shape, wgt.dtype, env.wgt_scope,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes) scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer( inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP, inp.shape, inp.dtype, env.inp_scope,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes) scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer( out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT, out.shape, out.dtype, env.acc_scope,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes) scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs): def intrin_func(ins, outs):
"""Vector-matrix multiply intrinsic function""" """Vector-matrix multiply intrinsic function"""
...@@ -69,8 +50,11 @@ def intrin_gevm(mock=False): ...@@ -69,8 +50,11 @@ def intrin_gevm(mock=False):
def instr(index): def instr(index):
"""Generate vector-matrix multiply VTA instruction""" """Generate vector-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) dev = env.dev
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP) irb.scope_attr(dev.vta_axis, "coproc_scope",
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2: if index == 0 or index == 2:
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopPush", "int32", "VTAUopPush",
...@@ -101,45 +85,54 @@ def intrin_gevm(mock=False): ...@@ -101,45 +85,54 @@ def intrin_gevm(mock=False):
out: out_layout}) out: out_layout})
def intrin_gemm(mock=False): def gemm(env, mock=False):
"""Matrix-matrix multiply intrinsic""" """Matrix-matrix multiply intrinsic
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN Parameters
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN) ----------
env : Environment
The Environment
mock : bool
Whether create a mock version.
"""
wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
assert inp_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_IN assert inp_lanes == env.BATCH * env.BLOCK_IN
inp_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_IN) inp_shape = (env.BATCH, env.BLOCK_IN)
assert inp_shape[0] * inp_shape[1] == inp_lanes assert inp_shape[0] * inp_shape[1] == inp_lanes
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
assert out_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_OUT assert out_lanes == env.BATCH * env.BLOCK_OUT
out_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_OUT) out_shape = (env.BATCH, env.BLOCK_OUT)
assert out_shape[0] * out_shape[1] == out_lanes assert out_shape[0] * out_shape[1] == out_lanes
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH, dtype="int%d" % env.WGT_WIDTH,
name=SCOPE_WGT) name=env.wgt_scope)
inp = tvm.placeholder((inp_shape[0], inp_shape[1]), inp = tvm.placeholder((inp_shape[0], inp_shape[1]),
dtype="int%d" % spec.VTA_INP_WIDTH, dtype="int%d" % env.INP_WIDTH,
name=SCOPE_INP) name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k") k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((out_shape[0], out_shape[1]), out = tvm.compute((out_shape[0], out_shape[1]),
lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) * lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) *
wgt[j, k].astype(out_dtype), wgt[j, k].astype(out_dtype),
axis=[k]), axis=[k]),
name="out") name="out")
wgt_layout = tvm.decl_buffer( wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT, wgt.shape, wgt.dtype, env.wgt_scope,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes) scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer( inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP, inp.shape, inp.dtype, env.inp_scope,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes) scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer( out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT, out.shape, out.dtype, env.acc_scope,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes) scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs): def intrin_func(ins, outs):
"""Matrix-matrix multiply intrinsic function""" """Matrix-matrix multiply intrinsic function"""
...@@ -148,8 +141,11 @@ def intrin_gemm(mock=False): ...@@ -148,8 +141,11 @@ def intrin_gemm(mock=False):
def instr(index): def instr(index):
"""Generate matrix-matrix multiply VTA instruction""" """Generate matrix-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) dev = env.dev
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP) irb.scope_attr(dev.vta_axis, "coproc_scope",
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2: if index == 0 or index == 2:
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopPush", "int32", "VTAUopPush",
...@@ -178,6 +174,3 @@ def intrin_gemm(mock=False): ...@@ -178,6 +174,3 @@ def intrin_gemm(mock=False):
binds={inp: inp_layout, binds={inp: inp_layout,
wgt: wgt_layout, wgt: wgt_layout,
out: out_layout}) out: out_layout})
GEMM = intrin_gemm()
GEVM = intrin_gevm()
"""Additional IR Pass for VTA""" """Additional IR Pass for VTA"""
# pylint: disable=len-as-condition
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from topi import util as util from topi import util as util
from . import hw_spec as spec from .environment import get_env
from . runtime import CB_HANDLE, VTA_AXIS, VTA_PUSH_UOP
from . runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT, DMA_COPY, get_task_qid
def fold_uop_loop(stmt_in): def fold_uop_loop(stmt_in):
"""Pass to fold uop loops""" """Detect and fold uop loop.
VTA support uop programming model
that recognizes loop structure.
This pass detect the loop structure
and extract that into uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Output statement.
"""
env = get_env()
def _fold_outermost_loop(body): def _fold_outermost_loop(body):
stmt = body stmt = body
while not isinstance(stmt, tvm.stmt.For): while not isinstance(stmt, tvm.stmt.For):
if isinstance(stmt, (tvm.stmt.ProducerConsumer, )): if isinstance(stmt, (tvm.stmt.ProducerConsumer,)):
stmt = stmt.body stmt = stmt.body
else: else:
return None, body, None return None, body, None
...@@ -30,7 +47,8 @@ def fold_uop_loop(stmt_in): ...@@ -30,7 +47,8 @@ def fold_uop_loop(stmt_in):
args = [] args = []
args += op.args[:base_args] args += op.args[:base_args]
for i in range(3): for i in range(3):
m = tvm.arith.DetectLinearEquation(op.args[i + base_args], [loop_var]) m = tvm.arith.DetectLinearEquation(
op.args[i + base_args], [loop_var])
if not m: if not m:
fail[0] = True fail[0] = True
return op return op
...@@ -68,7 +86,7 @@ def fold_uop_loop(stmt_in): ...@@ -68,7 +86,7 @@ def fold_uop_loop(stmt_in):
def _do_fold(stmt): def _do_fold(stmt):
if (stmt.attr_key == "coproc_uop_scope" and if (stmt.attr_key == "coproc_uop_scope" and
isinstance(stmt.value, tvm.expr.StringImm) and isinstance(stmt.value, tvm.expr.StringImm) and
stmt.value.value == VTA_PUSH_UOP.value): stmt.value.value == env.dev.vta_push_uop.value):
body = stmt.body body = stmt.body
begins = [] begins = []
ends = [] ends = []
...@@ -98,7 +116,12 @@ def fold_uop_loop(stmt_in): ...@@ -98,7 +116,12 @@ def fold_uop_loop(stmt_in):
def cpu_access_rewrite(stmt_in): def cpu_access_rewrite(stmt_in):
"""Rewrite the code when there is CPU access happening. """Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that do not
correspond to address in CPU.
This pass detect CPU access and rewrite to use pointer
returned VTABufferCPUPtr for CPU access.
Parameters Parameters
---------- ----------
...@@ -110,6 +133,7 @@ def cpu_access_rewrite(stmt_in): ...@@ -110,6 +133,7 @@ def cpu_access_rewrite(stmt_in):
stmt_out : Stmt stmt_out : Stmt
Transformed statement Transformed statement
""" """
env = get_env()
rw_info = {} rw_info = {}
def _post_order(op): def _post_order(op):
if isinstance(op, tvm.stmt.Allocate): if isinstance(op, tvm.stmt.Allocate):
...@@ -119,7 +143,8 @@ def cpu_access_rewrite(stmt_in): ...@@ -119,7 +143,8 @@ def cpu_access_rewrite(stmt_in):
new_var = rw_info[buffer_var] new_var = rw_info[buffer_var]
let_stmt = tvm.make.LetStmt( let_stmt = tvm.make.LetStmt(
new_var, tvm.call_extern( new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr", CB_HANDLE, "handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), op.body) buffer_var), op.body)
alloc = tvm.make.Allocate( alloc = tvm.make.Allocate(
buffer_var, op.dtype, op.extents, buffer_var, op.dtype, op.extents,
...@@ -147,7 +172,8 @@ def cpu_access_rewrite(stmt_in): ...@@ -147,7 +172,8 @@ def cpu_access_rewrite(stmt_in):
for buffer_var, new_var in rw_info.items(): for buffer_var, new_var in rw_info.items():
stmt = tvm.make.LetStmt( stmt = tvm.make.LetStmt(
new_var, tvm.call_extern( new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr", CB_HANDLE, "handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), stmt) buffer_var), stmt)
return stmt return stmt
...@@ -184,9 +210,6 @@ def lift_alloc_to_scope_begin(stmt_in): ...@@ -184,9 +210,6 @@ def lift_alloc_to_scope_begin(stmt_in):
else: else:
raise RuntimeError("unexpected op") raise RuntimeError("unexpected op")
del slist[:] del slist[:]
# n = len(slist)
# for i in range(n):
# slist.pop()
return body return body
def _pre_order(op): def _pre_order(op):
...@@ -220,7 +243,7 @@ def lift_alloc_to_scope_begin(stmt_in): ...@@ -220,7 +243,7 @@ def lift_alloc_to_scope_begin(stmt_in):
def inject_skip_copy(stmt_in): def inject_skip_copy(stmt_in):
"""Pass to inject skip copy stmt, used in debug. """Pass to inject skip copy stmt, used for debug purpose.
Parameters Parameters
---------- ----------
...@@ -233,7 +256,7 @@ def inject_skip_copy(stmt_in): ...@@ -233,7 +256,7 @@ def inject_skip_copy(stmt_in):
Transformed statement Transformed statement
""" """
def _do_fold(stmt): def _do_fold(stmt):
if stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy": if (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy"):
return tvm.make.Evaluate(0) return tvm.make.Evaluate(0)
return None return None
return tvm.ir_pass.IRTransform( return tvm.ir_pass.IRTransform(
...@@ -286,6 +309,7 @@ def inject_dma_intrin(stmt_in): ...@@ -286,6 +309,7 @@ def inject_dma_intrin(stmt_in):
stmt_out : Stmt stmt_out : Stmt
Transformed statement Transformed statement
""" """
env = get_env()
def _check_compact(buf): def _check_compact(buf):
ndim = len(buf.shape) ndim = len(buf.shape)
size = tvm.const(1, buf.shape[0].dtype) size = tvm.const(1, buf.shape[0].dtype)
...@@ -321,7 +345,8 @@ def inject_dma_intrin(stmt_in): ...@@ -321,7 +345,8 @@ def inject_dma_intrin(stmt_in):
x_stride = buf.strides[ndim - base] x_stride = buf.strides[ndim - base]
next_base = base next_base = base
if not util.equal_const_int(x_stride % elem_block, 0): if not util.equal_const_int(x_stride % elem_block, 0):
raise RuntimeError("scope %s need to have block=%d, shape=%s, strides=%s" % ( raise RuntimeError(
"scope %s need to have block=%d, shape=%s, strides=%s" % (
scope, elem_block, buf.shape, buf.strides)) scope, elem_block, buf.shape, buf.strides))
for i in range(base, ndim + 1): for i in range(base, ndim + 1):
k = ndim - i k = ndim - i
...@@ -338,7 +363,6 @@ def inject_dma_intrin(stmt_in): ...@@ -338,7 +363,6 @@ def inject_dma_intrin(stmt_in):
shape = list(reversed(shape)) shape = list(reversed(shape))
return shape, strides return shape, strides
def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
elem_block = elem_bytes * 8 // elem_width elem_block = elem_bytes * 8 // elem_width
if buf.dtype != dtype: if buf.dtype != dtype:
...@@ -425,47 +449,50 @@ def inject_dma_intrin(stmt_in): ...@@ -425,47 +449,50 @@ def inject_dma_intrin(stmt_in):
def _inject_copy(src, dst, pad_before, pad_after, pad_value): def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored... # FIXME: pad_value is ignored...
_ = pad_value
if dst.scope == "global": if dst.scope == "global":
# Store # Store
if pad_before or pad_after: if pad_before or pad_after:
raise RuntimeError("Do not support copy into DRAM with pad") raise RuntimeError("Do not support copy into DRAM with pad")
if src.scope == SCOPE_OUT: if src.scope == env.acc_scope:
elem_width = spec.VTA_INP_WIDTH # output compression to inp type elem_width = env.INP_WIDTH # output compression to inp type
elem_bytes = spec.VTA_INP_ELEM_BYTES # output compression to inp type elem_bytes = env.INP_ELEM_BYTES # output compression to inp type
mem_type = spec.VTA_MEM_ID_OUT mem_type = env.dev.MEM_ID_OUT
data_type = "int%d" % spec.VTA_INP_WIDTH data_type = "int%d" % env.INP_WIDTH
task_qid = spec.VTA_QID_STORE_OUT task_qid = env.dev.QID_STORE_OUT
else: else:
raise RuntimeError("Do not support copy %s->dram" % (src.scope)) raise RuntimeError("Do not support copy %s->dram" % (src.scope))
_check_compact(src) _check_compact(src)
x_size, y_size, x_stride, offset = _get_2d_pattern( x_size, y_size, x_stride, offset = _get_2d_pattern(
dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True) dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid)) irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(task_qid))
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAStoreBuffer2D", CB_HANDLE, "int32", "VTAStoreBuffer2D",
env.dev.command_handle,
src.access_ptr("r", "int32"), src.access_ptr("r", "int32"),
mem_type, dst.data, offset, x_size, y_size, x_stride)) mem_type, dst.data, offset, x_size, y_size, x_stride))
return irb.get() return irb.get()
elif src.scope == "global": elif src.scope == "global":
if dst.scope == SCOPE_OUT: if dst.scope == env.acc_scope:
elem_width = spec.VTA_OUT_WIDTH elem_width = env.ACC_WIDTH
elem_bytes = spec.VTA_OUT_ELEM_BYTES elem_bytes = env.ACC_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_ACC mem_type = env.dev.MEM_ID_ACC
data_type = "int%d" % spec.VTA_OUT_WIDTH data_type = "int%d" % env.ACC_WIDTH
task_qid = spec.VTA_QID_LOAD_OUT task_qid = env.dev.QID_LOAD_OUT
elif dst.scope == SCOPE_INP: elif dst.scope == env.inp_scope:
elem_width = spec.VTA_INP_WIDTH elem_width = env.INP_WIDTH
elem_bytes = spec.VTA_INP_ELEM_BYTES elem_bytes = env.INP_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_INP mem_type = env.dev.MEM_ID_INP
data_type = "int%d" % spec.VTA_INP_WIDTH data_type = "int%d" % env.INP_WIDTH
task_qid = spec.VTA_QID_LOAD_INP task_qid = env.dev.QID_LOAD_INP
elif dst.scope == SCOPE_WGT: elif dst.scope == env.wgt_scope:
elem_width = spec.VTA_WGT_WIDTH elem_width = env.WGT_WIDTH
elem_bytes = spec.VTA_WGT_ELEM_BYTES elem_bytes = env.WGT_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_WGT mem_type = env.dev.MEM_ID_WGT
data_type = "int%d" % spec.VTA_WGT_WIDTH data_type = "int%d" % env.WGT_WIDTH
task_qid = spec.VTA_QID_LOAD_WGT task_qid = env.dev.QID_LOAD_WGT
else: else:
raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
# collect pad statistics # collect pad statistics
...@@ -498,13 +525,16 @@ def inject_dma_intrin(stmt_in): ...@@ -498,13 +525,16 @@ def inject_dma_intrin(stmt_in):
_check_compact(dst) _check_compact(dst)
x_size, y_size, x_stride, offset = _get_2d_pattern( x_size, y_size, x_stride, offset = _get_2d_pattern(
src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold) src, elem_width, elem_bytes, data_type,
dst.scope, allow_fold=allow_fold)
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid)) irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(task_qid))
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTALoadBuffer2D", CB_HANDLE, "int32", "VTALoadBuffer2D",
env.dev.command_handle,
src.data, offset, x_size, y_size, x_stride, src.data, offset, x_size, y_size, x_stride,
x_pad_before, y_pad_before, x_pad_before, y_pad_before,
x_pad_after, y_pad_after, x_pad_after, y_pad_after,
...@@ -514,7 +544,7 @@ def inject_dma_intrin(stmt_in): ...@@ -514,7 +544,7 @@ def inject_dma_intrin(stmt_in):
else: else:
raise RuntimeError("Donot support copy %s->%s" % (src.scope, dst.scope)) raise RuntimeError("Donot support copy %s->%s" % (src.scope, dst.scope))
return tvm.ir_pass.InjectCopyIntrin(stmt_in, DMA_COPY, _inject_copy) return tvm.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy)
def annotate_alu_coproc_scope(stmt_in): def annotate_alu_coproc_scope(stmt_in):
...@@ -530,11 +560,14 @@ def annotate_alu_coproc_scope(stmt_in): ...@@ -530,11 +560,14 @@ def annotate_alu_coproc_scope(stmt_in):
stmt_out : Stmt stmt_out : Stmt
Transformed statement Transformed statement
""" """
env = get_env()
def _do_fold(stmt): def _do_fold(stmt):
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"): if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE)) irb.scope_attr(env.dev.vta_axis, "coproc_scope",
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", tvm.make.StringImm("VTAPushALUOp")) env.dev.get_task_qid(env.dev.QID_COMPUTE))
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt) irb.emit(stmt)
return irb.get() return irb.get()
elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"): elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"):
...@@ -560,6 +593,7 @@ def inject_alu_intrin(stmt_in): ...@@ -560,6 +593,7 @@ def inject_alu_intrin(stmt_in):
stmt_out : Stmt stmt_out : Stmt
Transformed statement Transformed statement
""" """
env = get_env()
def _do_fold(stmt): def _do_fold(stmt):
def _equal(x, y): def _equal(x, y):
return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0) return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
...@@ -618,32 +652,32 @@ def inject_alu_intrin(stmt_in): ...@@ -618,32 +652,32 @@ def inject_alu_intrin(stmt_in):
tmp_body = tmp_body.body tmp_body = tmp_body.body
# Derive opcode # Derive opcode
if isinstance(loop_body.value, tvm.expr.Add): if isinstance(loop_body.value, tvm.expr.Add):
alu_opcode = spec.VTA_ALU_OPCODE_ADD alu_opcode = env.dev.ALU_OPCODE_ADD
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Sub): elif isinstance(loop_body.value, tvm.expr.Sub):
alu_opcode = spec.VTA_ALU_OPCODE_SUB alu_opcode = env.dev.ALU_OPCODE_SUB
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Mul): elif isinstance(loop_body.value, tvm.expr.Mul):
alu_opcode = spec.VTA_ALU_OPCODE_MUL alu_opcode = env.dev.ALU_OPCODE_MUL
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Min): elif isinstance(loop_body.value, tvm.expr.Min):
alu_opcode = spec.VTA_ALU_OPCODE_MIN alu_opcode = env.dev.ALU_OPCODE_MIN
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Max): elif isinstance(loop_body.value, tvm.expr.Max):
alu_opcode = spec.VTA_ALU_OPCODE_MAX alu_opcode = env.dev.ALU_OPCODE_MAX
lhs = loop_body.value.a lhs = loop_body.value.a
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Call): elif isinstance(loop_body.value, tvm.expr.Call):
if loop_body.value.name == 'shift_left': if loop_body.value.name == 'shift_left':
alu_opcode = spec.VTA_ALU_OPCODE_SHR alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0] lhs = loop_body.value.args[0]
rhs = tvm.ir_pass.Simplify(-loop_body.value.args[1]) rhs = tvm.ir_pass.Simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right': elif loop_body.value.name == 'shift_right':
alu_opcode = spec.VTA_ALU_OPCODE_SHR alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0] lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1] rhs = loop_body.value.args[1]
else: else:
...@@ -696,26 +730,26 @@ def inject_alu_intrin(stmt_in): ...@@ -696,26 +730,26 @@ def inject_alu_intrin(stmt_in):
extents = list(extents) extents = list(extents)
assert len(src_coeff) > 1 assert len(src_coeff) > 1
assert len(dst_coeff) > 1 assert len(dst_coeff) > 1
assert len(extents) > 0 assert len(extents) != 0
assert tvm.ir_pass.Equal( assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify( tvm.ir_pass.Simplify(
src_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0) src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal( assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify( tvm.ir_pass.Simplify(
dst_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0) dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(src_coeff[-2], 1) assert tvm.ir_pass.Equal(src_coeff[-2], 1)
assert tvm.ir_pass.Equal(dst_coeff[-2], 1) assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
if spec.VTA_BATCH > 1: if env.BATCH > 1:
assert len(src_coeff) > 2 assert len(src_coeff) > 2
assert len(dst_coeff) > 2 assert len(dst_coeff) > 2
assert len(extents) > 1 assert len(extents) > 1
assert tvm.ir_pass.Equal(src_coeff[-3], spec.VTA_BLOCK_OUT) assert tvm.ir_pass.Equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.ir_pass.Equal(dst_coeff[-3], spec.VTA_BLOCK_OUT) assert tvm.ir_pass.Equal(dst_coeff[-3], env.BLOCK_OUT)
# Apply tensorization of the loop coefficients # Apply tensorization of the loop coefficients
src_offset = src_coeff[-1] src_offset = src_coeff[-1]
dst_offset = dst_coeff[-1] dst_offset = dst_coeff[-1]
if spec.VTA_BATCH == 1: if env.BATCH == 1:
src_coeff = src_coeff[:-2] src_coeff = src_coeff[:-2]
dst_coeff = dst_coeff[:-2] dst_coeff = dst_coeff[:-2]
extents = extents[:-1] extents = extents[:-1]
...@@ -726,9 +760,9 @@ def inject_alu_intrin(stmt_in): ...@@ -726,9 +760,9 @@ def inject_alu_intrin(stmt_in):
src_coeff.append(src_offset) src_coeff.append(src_offset)
dst_coeff.append(dst_offset) dst_coeff.append(dst_offset)
src_coeff = [ src_coeff = [
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in src_coeff] tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
dst_coeff = [ dst_coeff = [
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in dst_coeff] tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops # Flatten the outer loops
if extents: if extents:
...@@ -750,13 +784,25 @@ def inject_alu_intrin(stmt_in): ...@@ -750,13 +784,25 @@ def inject_alu_intrin(stmt_in):
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopLoopEnd")) "int32", "VTAUopLoopEnd"))
return irb.get() return irb.get()
return stmt return stmt
stmt_out = tvm.ir_pass.IRTransform( stmt_out = tvm.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"]) stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out return stmt_out
def debug_print(stmt): def debug_print(stmt):
print stmt """A debug pass that print the stmt
Parameters
----------
stmt : Stmt
The input statement
Returns
-------
stmt : Stmt
The
"""
print(stmt)
return stmt return stmt
"""Mock interface for skip part of compute """
from .intrin import intrin_gevm, intrin_gemm
GEMM = intrin_gemm(True)
GEVM = intrin_gevm(True)
DMA_COPY = "skip_dma_copy"
ALU = "skip_alu"
"""VTA RPC client function""" """VTA RPC client function"""
import os import os
from . import hw_spec as spec
from .environment import get_env
def reconfig_runtime(remote): def reconfig_runtime(remote):
"""Reconfigure remote runtime based on current hardware spec. """Reconfigure remote runtime based on current hardware spec.
...@@ -10,6 +11,7 @@ def reconfig_runtime(remote): ...@@ -10,6 +11,7 @@ def reconfig_runtime(remote):
remote : RPCSession remote : RPCSession
The TVM RPC session The TVM RPC session
""" """
env = get_env()
keys = ["VTA_LOG_WGT_WIDTH", keys = ["VTA_LOG_WGT_WIDTH",
"VTA_LOG_INP_WIDTH", "VTA_LOG_INP_WIDTH",
"VTA_LOG_ACC_WIDTH", "VTA_LOG_ACC_WIDTH",
...@@ -22,9 +24,10 @@ def reconfig_runtime(remote): ...@@ -22,9 +24,10 @@ def reconfig_runtime(remote):
"VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"] "VTA_LOG_OUT_BUFF_SIZE"]
cflags = ["-D%s" % spec.VTA_TARGET]
cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
for k in keys: for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(spec, k)))] cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime") freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
freconfig(" ".join(cflags)) freconfig(" ".join(cflags))
......
"""Runtime function related hooks"""
from __future__ import absolute_import as _abs
import tvm
def thread_local_command_buffer():
"""Get thread local command buffer"""
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
return tvm.make.Call(
"handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0)
CB_HANDLE = thread_local_command_buffer()
VTA_AXIS = tvm.thread_axis("vta")
VTA_PUSH_UOP = tvm.make.StringImm("VTAPushGEMMOp")
SCOPE_INP = "local.inp_buffer"
SCOPE_OUT = "local.out_buffer"
SCOPE_WGT = "local.wgt_buffer"
DMA_COPY = "dma_copy"
ALU = "alu"
DEBUG_NO_SYNC = False
def get_task_qid(qid):
"""Get transformed queue index."""
return 1 if DEBUG_NO_SYNC else qid
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op):
return tvm.call_extern(
"int32", "VTASynchronize", CB_HANDLE, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush", CB_HANDLE, op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
"int32", "VTADepPop", CB_HANDLE, op.args[0], op.args[1])
...@@ -7,17 +7,13 @@ import tvm ...@@ -7,17 +7,13 @@ import tvm
import topi import topi
from nnvm.top import registry as reg, OpPattern from nnvm.top import registry as reg, OpPattern
from . import environment as vta
from . import intrin, runtime as vta
from intrin import GEVM
TARGET_BOARD = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
Workload = namedtuple("Conv2DWorkload", Workload = namedtuple("Conv2DWorkload",
['height', 'width', 'in_filter', 'out_filter', ['height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
def packed_conv2d(data, def packed_conv2d(data,
kernel, kernel,
padding, padding,
...@@ -58,10 +54,9 @@ def packed_conv2d(data, ...@@ -58,10 +54,9 @@ def packed_conv2d(data,
def _build(funcs, target, target_host): def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target) tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta": if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", return tvm.build(funcs, target="ext_dev", target_host=target_host)
target_host=TARGET_BOARD) elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vta-cpu":
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "tcpu": return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=TARGET_BOARD)
return tvm.build(funcs, target=target) return tvm.build(funcs, target=target)
...@@ -224,36 +219,39 @@ def schedule_packed_conv2d(outs): ...@@ -224,36 +219,39 @@ def schedule_packed_conv2d(outs):
wrkld = _get_workload(data, pad_data, kernel, output) wrkld = _get_workload(data, pad_data, kernel, output)
plan = _WL2PLAN[wrkld] plan = _WL2PLAN[wrkld]
load_inp = load_wgt = load_out = store_out = "dma_copy" env = vta.get_env()
alu = "alu"
gevm = GEVM load_inp = load_wgt = load_out = store_out = env.dma_copy
alu = env.alu
gevm = env.gevm
# schedule1 # schedule1
oshape = topi.util.get_const_tuple(output.shape) oshape = topi.util.get_const_tuple(output.shape)
s = tvm.create_schedule(output.op) s = tvm.create_schedule(output.op)
# setup pad # setup pad
if pad_data is not None: if pad_data is not None:
cdata = pad_data cdata = pad_data
s[pad_data].set_scope(vta.SCOPE_INP) s[pad_data].set_scope(env.inp_scope)
else: else:
cdata = s.cache_read(data, vta.SCOPE_INP, [conv2d_stage]) cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
ckernel = s.cache_read(kernel, vta.SCOPE_WGT, [conv2d_stage]) ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
s[conv2d_stage].set_scope(vta.SCOPE_OUT) s[conv2d_stage].set_scope(env.acc_scope)
# cache read input # cache read input
cache_read_ewise = [] cache_read_ewise = []
for consumer, tensor in ewise_inputs: for consumer, tensor in ewise_inputs:
cache_read_ewise.append( cache_read_ewise.append(
s.cache_read(tensor, vta.SCOPE_OUT, [consumer])) s.cache_read(tensor, env.acc_scope, [consumer]))
# set ewise scope # set ewise scope
for op in ewise_ops: for op in ewise_ops:
s[op].set_scope(vta.SCOPE_OUT) s[op].set_scope(env.acc_scope)
s[op].pragma(s[op].op.axis[0], alu) s[op].pragma(s[op].op.axis[0], alu)
# tile # tile
oc_factor = (plan.oc_factor if plan.oc_factor oc_factor = (plan.oc_factor if plan.oc_factor
else wrkld.out_filter // vta.VTA_BLOCK_OUT) else wrkld.out_filter // vta.BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2]) h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3]) w_factor = (plan.w_factor if plan.w_factor else oshape[3])
......
...@@ -11,9 +11,6 @@ import pandas as pd ...@@ -11,9 +11,6 @@ import pandas as pd
host = "pynq" host = "pynq"
port = 9091 port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
Workload = namedtuple("Conv2DWorkload", Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter', ['batch', 'height', 'width', 'in_filter', 'out_filter',
...@@ -46,12 +43,13 @@ class Conv2DSchedule(object): ...@@ -46,12 +43,13 @@ class Conv2DSchedule(object):
Schedule = Conv2DSchedule Schedule = Conv2DSchedule
def get_insn_count(layer, sched): def get_insn_count(layer, sched):
env = vta.get_env()
b, h, w, ci, co = sched b, h, w, ci, co = sched
b_factor = b b_factor = b
h_factor = layer.height / h h_factor = layer.height / h
w_factor = layer.width / w w_factor = layer.width / w
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * vta.VTA_BLOCK_IN))) ci_factor = int(np.ceil(float(layer.in_filter) / (ci * env.BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * vta.VTA_BLOCK_OUT))) co_factor = int(np.ceil(float(layer.out_filter) / (co * env.BLOCK_OUT)))
input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
output_xfers = b_factor * h_factor * w_factor * co_factor output_xfers = b_factor * h_factor * w_factor * co_factor
...@@ -62,6 +60,7 @@ def get_insn_count(layer, sched): ...@@ -62,6 +60,7 @@ def get_insn_count(layer, sched):
return insn_count return insn_count
def find_schedules(layer, mtOnly=False, bestOnly=False): def find_schedules(layer, mtOnly=False, bestOnly=False):
env = vta.get_env()
# Helper function to get factors # Helper function to get factors
def find_factors(n): def find_factors(n):
factors = [] factors = []
...@@ -70,11 +69,11 @@ def find_schedules(layer, mtOnly=False, bestOnly=False): ...@@ -70,11 +69,11 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
factors.append(i) factors.append(i)
return factors return factors
# Scheduling exploration # Scheduling exploration
batch_factors = find_factors(int(np.ceil(float(layer.batch) / vta.VTA_BATCH))) batch_factors = find_factors(int(np.ceil(float(layer.batch) / env.BATCH)))
height_factors = find_factors(layer.height / layer.hstride) height_factors = find_factors(layer.height / layer.hstride)
width_factors = find_factors(layer.width / layer.wstride) width_factors = find_factors(layer.width / layer.wstride)
cin_factors = find_factors(int(np.ceil(float(layer.in_filter) / vta.VTA_BLOCK_IN))) cin_factors = find_factors(int(np.ceil(float(layer.in_filter) / env.BLOCK_IN)))
cout_factors = find_factors(int(np.ceil(float(layer.out_filter) / vta.VTA_BLOCK_OUT))) cout_factors = find_factors(int(np.ceil(float(layer.out_filter) / env.BLOCK_OUT)))
ht_factors = [1, 2] ht_factors = [1, 2]
cot_factors = [1, 2] cot_factors = [1, 2]
# Explore schedules # Explore schedules
...@@ -90,12 +89,12 @@ def find_schedules(layer, mtOnly=False, bestOnly=False): ...@@ -90,12 +89,12 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
if ci == 1: if ci == 1:
schedules.append([b, h, w, ci, co]) schedules.append([b, h, w, ci, co])
# Filter the schedules that wouldn't work in the available BRAM sizes # Filter the schedules that wouldn't work in the available BRAM sizes
input_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH input_elem_size_b = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
weight_elem_size_b = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH weight_elem_size_b = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
output_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH output_elem_size_b = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
input_brams_capacity_b = vta.VTA_INP_BUFF_SIZE * 8 input_brams_capacity_b = env.INP_BUFF_SIZE * 8
weight_brams_capacity_b = vta.VTA_WGT_BUFF_SIZE * 8 weight_brams_capacity_b = env.WGT_BUFF_SIZE * 8
output_brams_capacity_b = vta.VTA_OUT_BUFF_SIZE * 8 output_brams_capacity_b = env.OUT_BUFF_SIZE * 8
fil_sched = [] fil_sched = []
xfer_size = [] xfer_size = []
for sched in schedules: for sched in schedules:
...@@ -125,7 +124,7 @@ def find_schedules(layer, mtOnly=False, bestOnly=False): ...@@ -125,7 +124,7 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
if input_tile_elems*input_elem_size_b <= input_brams_capacity_b/(cot*ht) and \ if input_tile_elems*input_elem_size_b <= input_brams_capacity_b/(cot*ht) and \
weight_tile_elems*weight_elem_size_b <= weight_brams_capacity_b and \ weight_tile_elems*weight_elem_size_b <= weight_brams_capacity_b and \
output_tile_elems*output_elem_size_b <= output_brams_capacity_b/(cot*ht) and \ output_tile_elems*output_elem_size_b <= output_brams_capacity_b/(cot*ht) and \
insn_count <= vta.VTA_MAX_XFER / 16 and \ insn_count <= env.MAX_XFER / 16 and \
h > 2 and w > 2: h > 2 and w > 2:
schedule = Schedule(oc_factor=co, ko_factor=ci, h_factor=h, schedule = Schedule(oc_factor=co, ko_factor=ci, h_factor=h,
w_factor=w, oc_nthread=cot, h_nthread=ht) w_factor=w, oc_nthread=cot, h_nthread=ht)
...@@ -137,10 +136,11 @@ def find_schedules(layer, mtOnly=False, bestOnly=False): ...@@ -137,10 +136,11 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
return fil_sched return fil_sched
def get_data_movementB(sched, layer): def get_data_movementB(sched, layer):
env = vta.get_env()
# Derive data movement # Derive data movement
input_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH input_elem_size_b = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
weight_elem_size_b = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH weight_elem_size_b = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
output_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH output_elem_size_b = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
b = sched.b_factor b = sched.b_factor
h = sched.h_factor h = sched.h_factor
w = sched.w_factor w = sched.w_factor
...@@ -154,11 +154,11 @@ def get_data_movementB(sched, layer): ...@@ -154,11 +154,11 @@ def get_data_movementB(sched, layer):
weight_tile_elems = layer.hkernel * layer.wkernel * ci weight_tile_elems = layer.hkernel * layer.wkernel * ci
output_tile_elems = b * h * w * co output_tile_elems = b * h * w * co
# Derive factors # Derive factors
b_factor = int(np.ceil(float(layer.batch) / (b * vta.VTA_BATCH))) b_factor = int(np.ceil(float(layer.batch) / (b * env.BATCH)))
h_factor = (layer.height / layer.hstride) / h h_factor = (layer.height / layer.hstride) / h
w_factor = (layer.width / layer.wstride) / w w_factor = (layer.width / layer.wstride) / w
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * vta.VTA_BLOCK_IN))) ci_factor = int(np.ceil(float(layer.in_filter) / (ci * env.BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * vta.VTA_BLOCK_OUT))) co_factor = int(np.ceil(float(layer.out_filter) / (co * env.BLOCK_OUT)))
# Derive transfers # Derive transfers
input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
...@@ -171,19 +171,20 @@ def get_data_movementB(sched, layer): ...@@ -171,19 +171,20 @@ def get_data_movementB(sched, layer):
return total_xfer_B return total_xfer_B
def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True): def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True):
assert batch_size % vta.VTA_BATCH == 0 env = vta.get_env()
assert wl.in_filter % vta.VTA_BLOCK_IN == 0 assert batch_size % env.BATCH == 0
assert wl.out_filter % vta.VTA_BLOCK_OUT == 0 assert wl.in_filter % env.BLOCK_IN == 0
data_shape = (batch_size//vta.VTA_BATCH, wl.in_filter//vta.VTA_BLOCK_IN, assert wl.out_filter % env.BLOCK_OUT == 0
wl.height, wl.width, vta.VTA_BATCH, vta.VTA_BLOCK_IN) data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN,
kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN, wl.height, wl.width, env.BATCH, env.BLOCK_IN)
wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN) kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
res_shape = (batch_size//vta.VTA_BATCH, wl.out_filter//vta.VTA_BLOCK_OUT, res_shape = (batch_size//env.BATCH, wl.out_filter//env.BLOCK_OUT,
fout_height, fout_width, vta.VTA_BATCH, vta.VTA_BLOCK_OUT) fout_height, fout_width, env.BATCH, env.BLOCK_OUT)
data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype) data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype) kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
if wl.hpad or wl.wpad: if wl.hpad or wl.wpad:
data_buf = topi.nn.pad(data, [0, 0, wl.hpad, wl.wpad, 0, 0], name="data_buf") data_buf = topi.nn.pad(data, [0, 0, wl.hpad, wl.wpad, 0, 0], name="data_buf")
else: else:
...@@ -191,22 +192,22 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -191,22 +192,22 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
kernel_buf = tvm.compute(kernel_shape, lambda *i: kernel(*i), "kernel_buf") kernel_buf = tvm.compute(kernel_shape, lambda *i: kernel(*i), "kernel_buf")
di = tvm.reduce_axis((0, wl.hkernel), name='di') di = tvm.reduce_axis((0, wl.hkernel), name='di')
dj = tvm.reduce_axis((0, wl.wkernel), name='dj') dj = tvm.reduce_axis((0, wl.wkernel), name='dj')
ko = tvm.reduce_axis((0, wl.in_filter//vta.VTA_BLOCK_IN), name='ko') ko = tvm.reduce_axis((0, wl.in_filter//env.BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki') ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
res_cnv = tvm.compute( res_cnv = tvm.compute(
res_shape, res_shape,
lambda bo, co, i, j, bi, ci: tvm.sum( lambda bo, co, i, j, bi, ci: tvm.sum(
data_buf[bo, ko, i*wl.hstride+di, j*wl.wstride+dj, bi, ki].astype(out_dtype) * data_buf[bo, ko, i*wl.hstride+di, j*wl.wstride+dj, bi, ki].astype(env.acc_dtype) *
kernel_buf[co, ko, di, dj, ci, ki].astype(out_dtype), kernel_buf[co, ko, di, dj, ci, ki].astype(env.acc_dtype),
axis=[ko, di, dj, ki]), axis=[ko, di, dj, ki]),
name="res_cnv") name="res_cnv")
res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf") res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf")
res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(inp_dtype), name="res") res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(env.inp_dtype), name="res")
num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
total_xfer_B = get_data_movementB(sched, wl) total_xfer_B = get_data_movementB(sched, wl)
def verify(s, check_correctness): def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, res], "ext_dev", target, name="conv2d") mod = vta.build(s, [data, kernel, res], "ext_dev", target, name="conv2d")
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("conv2d.o")) mod.save(temp.relpath("conv2d.o"))
...@@ -220,12 +221,12 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -220,12 +221,12 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
kernel_orig = np.random.randint( kernel_orig = np.random.randint(
-128, 128, size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype) -128, 128, size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype)
data_packed = data_orig.reshape( data_packed = data_orig.reshape(
batch_size//vta.VTA_BATCH, vta.VTA_BATCH, batch_size//env.BATCH, env.BATCH,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
kernel_packed = kernel_orig.reshape( kernel_packed = kernel_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT, wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype) res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx) data_arr = tvm.nd.array(data_packed, ctx)
...@@ -237,13 +238,13 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -237,13 +238,13 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
(0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness: if check_correctness:
res_ref = mx.nd.Convolution( res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)), mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)), mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride), stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel), kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter, num_filter=wl.out_filter,
no_bias=True, no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype) pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
res_ref = np.right_shift(res_ref, 8).astype(res.dtype) res_ref = np.right_shift(res_ref, 8).astype(res.dtype)
np.testing.assert_allclose(res_unpack, res_ref) np.testing.assert_allclose(res_unpack, res_ref)
print("Correctness check pass...") print("Correctness check pass...")
...@@ -253,13 +254,13 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -253,13 +254,13 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
print_ir, check_correctness): print_ir, check_correctness):
# schedule1 # schedule1
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
s[data_buf].set_scope(vta.SCOPE_INP) s[data_buf].set_scope(env.inp_scope)
s[kernel_buf].set_scope(vta.SCOPE_WGT) s[kernel_buf].set_scope(env.wgt_scope)
s[res_cnv].set_scope(vta.SCOPE_OUT) s[res_cnv].set_scope(env.acc_scope)
s[res_shf].set_scope(vta.SCOPE_OUT) s[res_shf].set_scope(env.acc_scope)
# tile # tile
oc_factor = (sched.oc_factor if sched.oc_factor oc_factor = (sched.oc_factor if sched.oc_factor
else wl.out_filter // vta.VTA_BLOCK_OUT) else wl.out_filter // env.BLOCK_OUT)
h_factor = (sched.h_factor if sched.h_factor else fout_height) h_factor = (sched.h_factor if sched.h_factor else fout_height)
w_factor = (sched.w_factor if sched.w_factor else fout_width) w_factor = (sched.w_factor if sched.w_factor else fout_width)
xbo, xco, xi, xj, xbi, xci = s[res].op.axis xbo, xco, xi, xj, xbi, xci = s[res].op.axis
...@@ -304,8 +305,8 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -304,8 +305,8 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
def run_test(header, print_ir, check_correctness): def run_test(header, print_ir, check_correctness):
s = [1, sched.oc_factor, sched.ko_factor, sched.h_factor, sched.w_factor] s = [1, sched.oc_factor, sched.ko_factor, sched.h_factor, sched.w_factor]
cost = run_schedule( cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY, env.dma_copy, env.dma_copy,
vta.GEMM, vta.ALU, vta.DMA_COPY, env.gemm, env.alu, env.dma_copy,
print_ir, check_correctness) print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
...@@ -316,15 +317,15 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -316,15 +317,15 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
log_frame["total-gops"].append(gops) log_frame["total-gops"].append(gops)
log_frame["total-cost"].append(cost.mean) log_frame["total-cost"].append(cost.mean)
log_frame["total-insn"].append(get_insn_count(wl, s)) log_frame["total-insn"].append(get_insn_count(wl, s))
log_frame["block-batch"].append(vta.VTA_BATCH) log_frame["block-batch"].append(env.BATCH)
log_frame["block-in"].append(vta.VTA_BLOCK_IN) log_frame["block-in"].append(env.BLOCK_IN)
log_frame["block-out"].append(vta.VTA_BLOCK_OUT) log_frame["block-out"].append(env.BLOCK_OUT)
log_frame["inp-width"].append(vta.VTA_INP_WIDTH) log_frame["inp-width"].append(env.INP_WIDTH)
log_frame["wgt-width"].append(vta.VTA_WGT_WIDTH) log_frame["wgt-width"].append(env.WGT_WIDTH)
log_frame["uop-size"].append(vta.VTA_UOP_BUFF_SIZE) log_frame["uop-size"].append(env.UOP_BUFF_SIZE)
log_frame["inp-size"].append(vta.VTA_INP_BUFF_SIZE) log_frame["inp-size"].append(env.INP_BUFF_SIZE)
log_frame["wgt-size"].append(vta.VTA_WGT_BUFF_SIZE) log_frame["wgt-size"].append(env.WGT_BUFF_SIZE)
log_frame["out-size"].append(vta.VTA_OUT_BUFF_SIZE) log_frame["out-size"].append(env.OUT_BUFF_SIZE)
log_frame["oc-factor"].append(sched.oc_factor) log_frame["oc-factor"].append(sched.oc_factor)
log_frame["ic-factor"].append(sched.ko_factor) log_frame["ic-factor"].append(sched.ko_factor)
log_frame["h-factor"].append(sched.h_factor) log_frame["h-factor"].append(sched.h_factor)
...@@ -333,16 +334,16 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -333,16 +334,16 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
log_frame["h-threads"].append(sched.h_nthread) log_frame["h-threads"].append(sched.h_nthread)
log_frame["threaded"].append(True if sched.oc_nthread > 1 or sched.h_nthread > 1 else False) log_frame["threaded"].append(True if sched.oc_nthread > 1 or sched.h_nthread > 1 else False)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir, True) run_test("NORMAL", print_ir, True)
def skip_alu_unittest(print_ir): def skip_alu_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- Skip ALU Unit Test-------") print("----- Skip ALU Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY, env.dma_copy, env.dma_copy,
vta.GEMM, mock.ALU, vta.DMA_COPY, env.gemm, mock.alu, env.dma_copy,
print_ir, False) print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
...@@ -350,118 +351,118 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True) ...@@ -350,118 +351,118 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
log_frame["skip-alu-gops"].append(gops) log_frame["skip-alu-gops"].append(gops)
log_frame["skip-alu-cost"].append(cost.mean) log_frame["skip-alu-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def gemm_unittest(print_ir): def gemm_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- GEMM Unit Test-------") print("----- GEMM Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.dma_copy, mock.dma_copy,
vta.GEMM, mock.ALU, mock.DMA_COPY, env.gemm, mock.alu, mock.dma_copy,
print_ir, False) print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["gemm-gops"].append(gops) log_frame["gemm-gops"].append(gops)
log_frame["gemm-cost"].append(cost.mean) log_frame["gemm-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def alu_unittest(print_ir): def alu_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- ALU Unit Test-------") print("----- ALU Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.dma_copy, mock.dma_copy,
mock.GEMM, vta.ALU, mock.DMA_COPY, env.gemm, env.alu, mock.dma_copy,
print_ir, False) print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["alu-gops"].append(gops) log_frame["alu-gops"].append(gops)
log_frame["alu-cost"].append(cost.mean) log_frame["alu-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def load_inp_unittest(print_ir): def load_inp_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- LoadInp Unit Test-------") print("----- LoadInp Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
vta.DMA_COPY, mock.DMA_COPY, env.dma_copy, mock.dma_copy,
mock.GEMM, mock.ALU, mock.DMA_COPY, env.gemm, mock.alu, mock.dma_copy,
print_ir, False) print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * wl.in_filter * wl.height * bandwith = (batch_size * wl.in_filter * wl.height *
wl.width * vta.INP_WIDTH / cost.mean) / float(10 ** 9) wl.width * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % ( print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith)) cost.mean, gops, bandwith))
log_frame["ld-inp-gbits"].append(bandwith) log_frame["ld-inp-gbits"].append(bandwith)
log_frame["ld-inp-cost"].append(cost.mean) log_frame["ld-inp-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def load_wgt_unittest(print_ir): def load_wgt_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- LoadWgt Unit Test-------") print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, vta.DMA_COPY, mock.dma_copy, env.dma_copy,
mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, env.gemm, mock.alu, mock.dma_copy, print_ir,
False) False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (wl.out_filter * wl.in_filter * wl.hkernel * bandwith = (wl.out_filter * wl.in_filter * wl.hkernel *
wl.wkernel * vta.WGT_WIDTH / cost.mean) / float(10 ** 9) wl.wkernel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % ( print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith)) cost.mean, gops, bandwith))
log_frame["ld-wgt-gbits"].append(bandwith) log_frame["ld-wgt-gbits"].append(bandwith)
log_frame["ld-wgt-cost"].append(cost.mean) log_frame["ld-wgt-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def store_out_unittest(print_ir): def store_out_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- StoreOut Unit Test-------") print("----- StoreOut Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.dma_copy, mock.dma_copy,
mock.GEMM, mock.ALU, vta.DMA_COPY, print_ir, env.gemm, mock.alu, env.dma_copy, print_ir,
False) False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * wl.out_filter * fout_height * bandwith = (batch_size * wl.out_filter * fout_height *
fout_width * vta.OUT_WIDTH / cost.mean) / float(10 ** 9) fout_width * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % ( print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith)) cost.mean, gops, bandwith))
log_frame["st-out-gbits"].append(bandwith) log_frame["st-out-gbits"].append(bandwith)
log_frame["st-out-cost"].append(cost.mean) log_frame["st-out-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def manual_unittest(print_ir): def manual_unittest(print_ir):
# Manual section used to teak the components # Manual section used to teak the components
mock = vta.mock mock = env.mock
print("----- Manual Unit Test-------") print("----- Manual Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY, env.dma_copy, env.dma_copy,
vta.GEMM, vta.ALU, mock.DMA_COPY, print_ir, env.gemm, env.alu, mock.dma_copy, print_ir,
False) False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % ( print("\tTime cost = %g sec/op, %g GFLOPS" % (
cost.mean, gops)) cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
......
...@@ -8,34 +8,33 @@ from tvm.contrib import rpc, util ...@@ -8,34 +8,33 @@ from tvm.contrib import rpc, util
host = "pynq" host = "pynq"
port = 9091 port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf" target = "llvm -target=armv7-none-linux-gnueabihf"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
def test_gemm_packed(batch_size, channel, block): def test_gemm_packed(batch_size, channel, block):
data_shape = (batch_size//vta.VTA_BATCH, env = vta.get_env()
channel//vta.VTA_BLOCK_IN, data_shape = (batch_size // env.BATCH,
vta.VTA_BATCH, channel // env.BLOCK_IN,
vta.VTA_BLOCK_IN) env.BATCH,
weight_shape = (channel//vta.VTA_BLOCK_OUT, env.BLOCK_IN)
channel//vta.VTA_BLOCK_IN, weight_shape = (channel // env.BLOCK_OUT,
vta.VTA_BLOCK_OUT, channel // env.BLOCK_IN,
vta.VTA_BLOCK_IN) env.BLOCK_OUT,
res_shape = (batch_size//vta.VTA_BATCH, env.BLOCK_IN)
channel//vta.VTA_BLOCK_OUT, res_shape = (batch_size // env.BATCH,
vta.VTA_BATCH, channel // env.BLOCK_OUT,
vta.VTA_BLOCK_OUT) env.BATCH,
env.BLOCK_OUT)
num_ops = channel * channel * batch_size num_ops = channel * channel * batch_size
ko = tvm.reduce_axis((0, channel//vta.VTA_BLOCK_IN), name='ko') ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki') ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
data = tvm.placeholder(data_shape, data = tvm.placeholder(data_shape,
name="data", name="data",
dtype=inp_dtype) dtype=env.inp_dtype)
weight = tvm.placeholder(weight_shape, weight = tvm.placeholder(weight_shape,
name="weight", name="weight",
dtype=wgt_dtype) dtype=env.wgt_dtype)
data_buf = tvm.compute(data_shape, data_buf = tvm.compute(data_shape,
lambda *i: data(*i), lambda *i: data(*i),
"data_buf") "data_buf")
...@@ -44,8 +43,8 @@ def test_gemm_packed(batch_size, channel, block): ...@@ -44,8 +43,8 @@ def test_gemm_packed(batch_size, channel, block):
"weight_buf") "weight_buf")
res_gem = tvm.compute(res_shape, res_gem = tvm.compute(res_shape,
lambda bo, co, bi, ci: tvm.sum( lambda bo, co, bi, ci: tvm.sum(
data_buf[bo, ko, bi, ki].astype(out_dtype) * data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
weight_buf[co, ko, ci, ki].astype(out_dtype), weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
axis=[ko, ki]), axis=[ko, ki]),
name="res_gem") name="res_gem")
res_shf = tvm.compute(res_shape, res_shf = tvm.compute(res_shape,
...@@ -55,14 +54,14 @@ def test_gemm_packed(batch_size, channel, block): ...@@ -55,14 +54,14 @@ def test_gemm_packed(batch_size, channel, block):
lambda *i: tvm.max(res_shf(*i), 0), lambda *i: tvm.max(res_shf(*i), 0),
"res_max") #relu "res_max") #relu
res_min = tvm.compute(res_shape, res_min = tvm.compute(res_shape,
lambda *i: tvm.min(res_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1), lambda *i: tvm.min(res_max(*i), (1<<(env.INP_WIDTH-1))-1),
"res_min") #relu "res_min") #relu
res = tvm.compute(res_shape, res = tvm.compute(res_shape,
lambda *i: res_min(*i).astype(inp_dtype), lambda *i: res_min(*i).astype(env.inp_dtype),
name="res") name="res")
def verify(s, check_correctness=True): def verify(s, check_correctness=True):
mod = tvm.build(s, [data, weight, res], "ext_dev", target, name="gemm") mod = vta.build(s, [data, weight, res], "ext_dev", target, name="gemm")
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o")) mod.save(temp.relpath("gemm.o"))
...@@ -76,29 +75,29 @@ def test_gemm_packed(batch_size, channel, block): ...@@ -76,29 +75,29 @@ def test_gemm_packed(batch_size, channel, block):
weight_orig = np.random.randint( weight_orig = np.random.randint(
-128, 128, size=(channel, channel)).astype(weight.dtype) -128, 128, size=(channel, channel)).astype(weight.dtype)
data_packed = data_orig.reshape( data_packed = data_orig.reshape(
batch_size//vta.VTA_BATCH, vta.VTA_BATCH, batch_size // env.BATCH, env.BATCH,
channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3)) channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight_orig.reshape( weight_packed = weight_orig.reshape(
channel//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT, channel // env.BLOCK_OUT, env.BLOCK_OUT,
channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3)) channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype) res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx) data_arr = tvm.nd.array(data_packed, ctx)
weight_arr = tvm.nd.array(weight_packed, ctx) weight_arr = tvm.nd.array(weight_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx) res_arr = tvm.nd.array(res_np, ctx)
res_ref = np.zeros(res_shape).astype(out_dtype) res_ref = np.zeros(res_shape).astype(env.acc_dtype)
for b in range(batch_size//vta.VTA_BATCH): for b in range(batch_size // env.BATCH):
for i in range(channel//vta.VTA_BLOCK_OUT): for i in range(channel // env.BLOCK_OUT):
for j in range(channel//vta.VTA_BLOCK_IN): for j in range(channel // env.BLOCK_IN):
res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(out_dtype), res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(env.acc_dtype),
weight_packed[i,j].T.astype(out_dtype)) weight_packed[i,j].T.astype(env.acc_dtype))
res_ref = np.right_shift(res_ref, 8) res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype) res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20) time_f = f.time_evaluator("gemm", ctx, number=20)
cost = time_f(data_arr, weight_arr, res_arr) cost = time_f(data_arr, weight_arr, res_arr)
res_unpack = res_arr.asnumpy().reshape(batch_size//vta.VTA_BATCH, res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
channel//vta.VTA_BLOCK_OUT, channel // env.BLOCK_OUT,
vta.VTA_BATCH, env.BATCH,
vta.VTA_BLOCK_OUT) env.BLOCK_OUT)
if check_correctness: if check_correctness:
np.testing.assert_allclose(res_unpack, res_ref) np.testing.assert_allclose(res_unpack, res_ref)
return cost return cost
...@@ -111,17 +110,17 @@ def test_gemm_packed(batch_size, channel, block): ...@@ -111,17 +110,17 @@ def test_gemm_packed(batch_size, channel, block):
print_ir, print_ir,
check_correctness): check_correctness):
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
s[data_buf].set_scope(vta.SCOPE_INP) s[data_buf].set_scope(env.inp_scope)
s[weight_buf].set_scope(vta.SCOPE_WGT) s[weight_buf].set_scope(env.wgt_scope)
s[res_gem].set_scope(vta.SCOPE_OUT) s[res_gem].set_scope(env.acc_scope)
s[res_shf].set_scope(vta.SCOPE_OUT) s[res_shf].set_scope(env.acc_scope)
s[res_min].set_scope(vta.SCOPE_OUT) s[res_min].set_scope(env.acc_scope)
s[res_max].set_scope(vta.SCOPE_OUT) s[res_max].set_scope(env.acc_scope)
if block: if block:
bblock = block // vta.VTA_BATCH bblock = block // env.BATCH
iblock = block // vta.VTA_BLOCK_IN iblock = block // env.BLOCK_IN
oblock = block // vta.VTA_BLOCK_OUT oblock = block // env.BLOCK_OUT
xbo, xco, xbi, xci = s[res].op.axis xbo, xco, xbi, xci = s[res].op.axis
xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock) xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)
store_pt = xb2 store_pt = xb2
...@@ -162,91 +161,91 @@ def test_gemm_packed(batch_size, channel, block): ...@@ -162,91 +161,91 @@ def test_gemm_packed(batch_size, channel, block):
return verify(s, check_correctness) return verify(s, check_correctness)
def gemm_normal(print_ir): def gemm_normal(print_ir):
mock = vta.mock mock = env.mock
print("----- GEMM GFLOPS End-to-End Test-------") print("----- GEMM GFLOPS End-to-End Test-------")
def run_test(header, print_ir, check_correctness): def run_test(header, print_ir, check_correctness):
cost = run_schedule( cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY, vta.GEMM, vta.ALU, vta.DMA_COPY, env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
print_ir, check_correctness) print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)): with vta.build_config():
run_test("NORMAL", print_ir, True) run_test("NORMAL", print_ir, True)
print("") print("")
def gevm_unittest(print_ir): def gevm_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- GEMM Unit Test-------") print("----- GEMM Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, vta.GEMM, mock.ALU, mock.DMA_COPY, mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy,
print_ir, False) print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def alu_unittest(print_ir): def alu_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- ALU Unit Test-------") print("----- ALU Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, vta.ALU, mock.DMA_COPY, mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy,
print_ir, False) print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def load_inp_unittest(print_ir): def load_inp_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- LoadInp Unit Test-------") print("----- LoadInp Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
vta.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False) env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * vta.VTA_INP_WIDTH / cost.mean) / float(10 ** 9) bandwith = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % ( print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith)) cost.mean, gops, bandwith))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def load_wgt_unittest(print_ir): def load_wgt_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- LoadWgt Unit Test-------") print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, vta.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False) mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (channel * channel * vta.VTA_WGT_WIDTH / cost.mean) / float(10 ** 9) bandwith = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % ( print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith)) cost.mean, gops, bandwith))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
def store_out_unittest(print_ir): def store_out_unittest(print_ir):
mock = vta.mock mock = env.mock
print("----- StoreOut Unit Test-------") print("----- StoreOut Unit Test-------")
def run_test(header, print_ir): def run_test(header, print_ir):
cost = run_schedule( cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, vta.DMA_COPY, mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
print_ir, False) print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * vta.VTA_OUT_WIDTH / cost.mean) / float(10 ** 9) bandwith = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % ( print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith)) cost.mean, gops, bandwith))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
run_test("NORMAL", print_ir) run_test("NORMAL", print_ir)
print("") print("")
......
...@@ -21,26 +21,26 @@ def my_clip(x, a_min, a_max): ...@@ -21,26 +21,26 @@ def my_clip(x, a_min, a_max):
host = "pynq" host = "pynq"
port = 9091 port = 9091
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
print_ir = False print_ir = False
def test_vta_conv2d(key, batch_size, wl, profile=True): def test_vta_conv2d(key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN, env = vta.get_env()
wl.height, wl.width, vta.VTA_BLOCK_IN) data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN, wl.height, wl.width, env.BLOCK_IN)
wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN) kernel_shape = (wl.out_filter // env.BLOCK_OUT,
bias_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT) wl.in_filter // env.BLOCK_IN,
wl.hkernel, wl.wkernel,
env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype) data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype) kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=out_dtype) bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
res_conv = vta_conv2d.packed_conv2d( res_conv = vta_conv2d.packed_conv2d(
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride)) data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
...@@ -73,14 +73,14 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): ...@@ -73,14 +73,14 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
bias_orig = np.abs(bias_orig) bias_orig = np.abs(bias_orig)
data_packed = data_orig.reshape( data_packed = data_orig.reshape(
batch_size, wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 1, 3, 4, 2)) wl.height, wl.width).transpose((0, 1, 3, 4, 2))
kernel_packed = kernel_orig.reshape( kernel_packed = kernel_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT, wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_packed = bias_orig.reshape( bias_packed = bias_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT) wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
res_shape = topi.util.get_const_tuple(res.shape) res_shape = topi.util.get_const_tuple(res.shape)
res_np = np.zeros(res_shape).astype(res.dtype) res_np = np.zeros(res_shape).astype(res.dtype)
...@@ -94,13 +94,13 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): ...@@ -94,13 +94,13 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
(0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness: if check_correctness:
res_ref = mx.nd.Convolution( res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)), mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)), mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride), stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel), kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter, num_filter=wl.out_filter,
no_bias=True, no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype) pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
res_ref = res_ref >> 8 res_ref = res_ref >> 8
res_ref += bias_orig.reshape(wl.out_filter, 1, 1) res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
res_ref = np.clip(res_ref, 0, 127).astype("int8") res_ref = np.clip(res_ref, 0, 127).astype("int8")
...@@ -110,10 +110,10 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): ...@@ -110,10 +110,10 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
def conv_normal(print_ir): def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------") print("----- CONV2D End-to-End Test-------")
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with vta.build_config():
s = vta_conv2d.schedule_packed_conv2d([res]) s = vta_conv2d.schedule_packed_conv2d([res])
if print_ir: if print_ir:
print(tvm.lower(s, [data, kernel, bias, res], simple_mode=True)) print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True) cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9) gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
......
...@@ -3,14 +3,15 @@ import vta ...@@ -3,14 +3,15 @@ import vta
import os import os
from tvm.contrib import rpc, util from tvm.contrib import rpc, util
env = vta.get_env()
host = "pynq" host = "pynq"
port = 9091 port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf" target = "llvm -target=armv7-none-linux-gnueabihf"
bit = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns.bit".format( bit = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns.bit".format(
vta.VTA_BATCH, vta.VTA_BLOCK_IN, vta.VTA_BLOCK_OUT, env.BATCH, env.BLOCK_IN, env.BLOCK_OUT,
vta.VTA_INP_WIDTH, vta.VTA_WGT_WIDTH, env.INP_WIDTH, env.WGT_WIDTH,
vta.VTA_LOG_UOP_BUFF_SIZE, vta.VTA_LOG_INP_BUFF_SIZE, env.LOG_UOP_BUFF_SIZE, env.LOG_INP_BUFF_SIZE,
vta.VTA_LOG_WGT_BUFF_SIZE, vta.VTA_LOG_OUT_BUFF_SIZE) env.LOG_WGT_BUFF_SIZE, env.LOG_ACC_BUFF_SIZE)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
bitstream = os.path.join(curr_path, "../../../../vta_bitstreams/bitstreams/", bit) bitstream = os.path.join(curr_path, "../../../../vta_bitstreams/bitstreams/", bit)
......
"""Unit test TPU's instructions """ """Unit test VTA's instructions """
import tvm import tvm
import vta import vta
import mxnet as mx import mxnet as mx
...@@ -9,44 +9,42 @@ from tvm.contrib import rpc, util ...@@ -9,44 +9,42 @@ from tvm.contrib import rpc, util
host = "pynq" host = "pynq"
port = 9091 port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf" target = "llvm -target=armv7-none-linux-gnueabihf"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
do_verify = True do_verify = True
print_ir = False print_ir = False
def test_save_load_out(): def test_save_load_out():
env = vta.get_env()
"""Test save/store output command""" """Test save/store output command"""
n = 4 n = 4
x = tvm.placeholder( x = tvm.placeholder(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (n, n, env.BATCH, env.BLOCK_OUT),
name="x", name="x",
dtype=out_dtype) dtype=env.acc_dtype)
x_buf = tvm.compute( x_buf = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x(*i), lambda *i: x(*i),
"x_buf") "x_buf")
# insert no-op that won't be optimized away # insert no-op that won't be optimized away
y_buf = tvm.compute( y_buf = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x_buf(*i)>>0, lambda *i: x_buf(*i)>>0,
"y_buf") "y_buf")
y = tvm.compute( y = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: y_buf(*i).astype(inp_dtype), lambda *i: y_buf(*i).astype(env.inp_dtype),
"y") "y")
# schedule # schedule
s = tvm.create_schedule(y.op) s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_OUT) s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY) s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(vta.SCOPE_OUT) s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], vta.ALU) s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], vta.DMA_COPY) s[y].pragma(y.op.axis[0], env.dma_copy)
def verify(): def verify():
# build # build
m = tvm.build(s, [x, y], "ext_dev", target) with vta.build_config(env.DEBUG_DUMP_INSN):
m = vta.build(s, [x, y], "ext_dev", target)
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
m.save(temp.relpath("load_act.o")) m.save(temp.relpath("load_act.o"))
...@@ -55,7 +53,7 @@ def test_save_load_out(): ...@@ -55,7 +53,7 @@ def test_save_load_out():
# verify # verify
ctx = remote.ext_dev(0) ctx = remote.ext_dev(0)
x_np = np.random.randint( x_np = np.random.randint(
1, 10, size=(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype) 1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
y_np = x_np.astype(y.dtype) y_np = x_np.astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx) x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
...@@ -65,38 +63,41 @@ def test_save_load_out(): ...@@ -65,38 +63,41 @@ def test_save_load_out():
if do_verify: if do_verify:
verify() verify()
def test_padded_load(): def test_padded_load():
"""Test padded load.""" """Test padded load."""
env = vta.get_env()
# declare # declare
n = 21 n = 21
m = 20 m = 20
pad_before = [0, 1, 0, 0] pad_before = [0, 1, 0, 0]
pad_after = [1, 3, 0, 0] pad_after = [1, 3, 0, 0]
x = tvm.placeholder( x = tvm.placeholder(
(n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (n, m, env.BATCH, env.BLOCK_OUT),
name="x", name="x",
dtype=out_dtype) dtype=env.acc_dtype)
x_buf = topi.nn.pad(x, pad_before, pad_after, name="y") x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
# insert no-op that won't be optimized away # insert no-op that won't be optimized away
y_buf = tvm.compute((n + pad_before[0] + pad_after[0], y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1], m + pad_before[1] + pad_after[1],
vta.VTA_BATCH, env.BATCH,
vta.VTA_BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf") env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute((n + pad_before[0] + pad_after[0], y = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1], m + pad_before[1] + pad_after[1],
vta.VTA_BATCH, env.BATCH,
vta.VTA_BLOCK_OUT), lambda *i: y_buf(*i).astype(inp_dtype), "y") env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
# schedule # schedule
s = tvm.create_schedule(y.op) s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_OUT) s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY) s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(vta.SCOPE_OUT) s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], vta.ALU) s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], vta.DMA_COPY) s[y].pragma(y.op.axis[0], env.dma_copy)
def verify(): def verify():
# build # build
mod = tvm.build(s, [x, y], "ext_dev", target) with vta.build_config(env.DEBUG_DUMP_INSN):
mod = vta.build(s, [x, y], "ext_dev", target)
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("padded_load.o")) mod.save(temp.relpath("padded_load.o"))
...@@ -105,11 +106,11 @@ def test_padded_load(): ...@@ -105,11 +106,11 @@ def test_padded_load():
# verify # verify
ctx = remote.ext_dev(0) ctx = remote.ext_dev(0)
x_np = np.random.randint(1, 2, size=( x_np = np.random.randint(1, 2, size=(
n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype) n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
y_np = np.zeros((n + pad_before[0] + pad_after[0], y_np = np.zeros((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1], m + pad_before[1] + pad_after[1],
vta.VTA_BATCH, env.BATCH,
vta.VTA_BLOCK_OUT)).astype(y.dtype) env.BLOCK_OUT)).astype(y.dtype)
y_np[pad_before[0]:pad_before[0] + n, y_np[pad_before[0]:pad_before[0] + n,
pad_before[1]:pad_before[1] + m, pad_before[1]:pad_before[1] + m,
:] = x_np :] = x_np
...@@ -118,51 +119,50 @@ def test_padded_load(): ...@@ -118,51 +119,50 @@ def test_padded_load():
f(x_nd, y_nd) f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...") print("\tFinished verification...")
if print_ir: if print_ir:
print(tvm.lower(s, [y, x], simple_mode=True)) print(vta.lower(s, [y, x], simple_mode=True))
if do_verify:
with tvm.build_config(add_lower_pass=vta.debug_mode(
vta.DEBUG_DUMP_INSN)):
verify()
def test_gemm(): def test_gemm():
"""Test GEMM.""" """Test GEMM."""
env = vta.get_env()
# declare # declare
o = 4 o = 4
n = 4 n = 4
m = 4 m = 4
x = tvm.placeholder((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), name="x", dtype=inp_dtype) x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
w = tvm.placeholder((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), name="w", dtype=wgt_dtype) w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
x_buf = tvm.compute((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), lambda *i: x(*i), "x_buf") x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
w_buf = tvm.compute((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), lambda *i: w(*i), "w_buf") w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
ko = tvm.reduce_axis((0, n), name="ko") ko = tvm.reduce_axis((0, n), name="ko")
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name="ki") ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
y_gem = tvm.compute( y_gem = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (o, m, env.BATCH, env.BLOCK_OUT),
lambda bo, co, bi, ci: lambda bo, co, bi, ci:
tvm.sum(x_buf[bo, ko, bi, ki].astype(out_dtype) * tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
w_buf[co, ko, ci, ki].astype(out_dtype), w_buf[co, ko, ci, ki].astype(env.acc_dtype),
axis=[ko, ki]), axis=[ko, ki]),
name="y_gem") name="y_gem")
y_shf = tvm.compute( y_shf = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_gem(*i)>>8, lambda *i: y_gem(*i)>>8,
name="y_shf") name="y_shf")
y_max = tvm.compute( y_max = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(y_shf(*i), 0), lambda *i: tvm.max(y_shf(*i), 0),
"y_max") #relu "y_max") #relu
y_min = tvm.compute( y_min = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(y_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1), lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
"y_min") #relu "y_min") #relu
y = tvm.compute( y = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_min(*i).astype(inp_dtype), lambda *i: y_min(*i).astype(env.inp_dtype),
name="y") name="y")
def verify(s): def verify(s):
mod = tvm.build(s, [x, w, y], "ext_dev", target) mod = vta.build(s, [x, w, y], "ext_dev", target)
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o")) mod.save(temp.relpath("gemm.o"))
...@@ -171,21 +171,21 @@ def test_gemm(): ...@@ -171,21 +171,21 @@ def test_gemm():
# verify # verify
ctx = remote.ext_dev(0) ctx = remote.ext_dev(0)
x_np = np.random.randint( x_np = np.random.randint(
-128, 128, size=(o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN)).astype(x.dtype) -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
w_np = np.random.randint( w_np = np.random.randint(
-128, 128, size=(m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)).astype(w.dtype) -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
y_np = np.zeros((o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(y.dtype) y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx) x_nd = tvm.nd.array(x_np, ctx)
w_nd = tvm.nd.array(w_np, ctx) w_nd = tvm.nd.array(w_np, ctx)
y_nd = tvm.nd.array(y_np, ctx) y_nd = tvm.nd.array(y_np, ctx)
y_np = y_np.astype(out_dtype) y_np = y_np.astype(env.acc_dtype)
for b in range(o): for b in range(o):
for i in range(m): for i in range(m):
for j in range(n): for j in range(n):
y_np[b,i,:] += np.dot(x_np[b,j,:].astype(out_dtype), y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
w_np[i,j].T.astype(out_dtype)) w_np[i,j].T.astype(env.acc_dtype))
y_np = np.right_shift(y_np, 8) y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(y.dtype) y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
f(x_nd, w_nd, y_nd) f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...") print("\tFinished verification...")
...@@ -194,19 +194,19 @@ def test_gemm(): ...@@ -194,19 +194,19 @@ def test_gemm():
# default schedule with no smt # default schedule with no smt
s = tvm.create_schedule(y.op) s = tvm.create_schedule(y.op)
# set the scope of the SRAM buffers # set the scope of the SRAM buffers
s[x_buf].set_scope(vta.SCOPE_INP) s[x_buf].set_scope(env.SCOPE_INP)
s[w_buf].set_scope(vta.SCOPE_WGT) s[w_buf].set_scope(env.SCOPE_WGT)
s[y_gem].set_scope(vta.SCOPE_OUT) s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(vta.SCOPE_OUT) s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(vta.SCOPE_OUT) s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(vta.SCOPE_OUT) s[y_min].set_scope(env.acc_scope)
# set pragmas for DMA transfer and ALU ops # set pragmas for DMA transfer and ALU ops
s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY) s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY) s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU) s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], vta.ALU) s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], vta.ALU) s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[y].pragma(s[y].op.axis[0], vta.DMA_COPY) s[y].pragma(s[y].op.axis[0], env.dma_copy)
# tensorization # tensorization
s[y_gem].reorder( s[y_gem].reorder(
ko, ko,
...@@ -215,23 +215,22 @@ def test_gemm(): ...@@ -215,23 +215,22 @@ def test_gemm():
s[y_gem].op.axis[2], s[y_gem].op.axis[2],
s[y_gem].op.axis[3], s[y_gem].op.axis[3],
ki) ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM) s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
if print_ir: if print_ir:
print(tvm.lower(s, [x, w, y], simple_mode=True)) print(vta.lower(s, [x, w, y], simple_mode=True))
if do_verify: if do_verify:
with tvm.build_config( with vta.build_config(env.DEBUG_DUMP_INSN):
add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
verify(s) verify(s)
def test_smt(): def test_smt():
# test smt schedule # test smt schedule
s = tvm.create_schedule(y.op) s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_INP) s[x_buf].set_scope(env.SCOPE_INP)
s[w_buf].set_scope(vta.SCOPE_WGT) s[w_buf].set_scope(env.SCOPE_WGT)
s[y_gem].set_scope(vta.SCOPE_OUT) s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(vta.SCOPE_OUT) s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(vta.SCOPE_OUT) s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(vta.SCOPE_OUT) s[y_min].set_scope(env.acc_scope)
abo, aco, abi, aci = s[y].op.axis abo, aco, abi, aci = s[y].op.axis
abo1, abo2 = s[y].split(abo, nparts=2) abo1, abo2 = s[y].split(abo, nparts=2)
s[y].bind(abo1, tvm.thread_axis("cthread")) s[y].bind(abo1, tvm.thread_axis("cthread"))
...@@ -246,20 +245,19 @@ def test_gemm(): ...@@ -246,20 +245,19 @@ def test_gemm():
s[y_gem].op.axis[2], s[y_gem].op.axis[2],
s[y_gem].op.axis[3], s[y_gem].op.axis[3],
ki) ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM) s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU) s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], vta.ALU) s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], vta.ALU) s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[x_buf].compute_at(s[y_gem], ko) s[x_buf].compute_at(s[y_gem], ko)
s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY) s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].compute_at(s[y_gem], ko) s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY) s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y].pragma(abo2, vta.DMA_COPY) s[y].pragma(abo2, env.dma_copy)
if print_ir: if print_ir:
print(tvm.lower(s, [x, y, w], simple_mode=True)) print(vta.lower(s, [x, y, w], simple_mode=True))
if do_verify: if do_verify:
with tvm.build_config( with vta.build_config(env.DEBUG_DUMP_INSN):
add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
verify(s) verify(s)
test_schedule1() test_schedule1()
...@@ -267,62 +265,64 @@ def test_gemm(): ...@@ -267,62 +265,64 @@ def test_gemm():
def test_alu(tvm_op, np_op=None, use_imm=False): def test_alu(tvm_op, np_op=None, use_imm=False):
"""Test ALU""" """Test ALU"""
env = vta.get_env()
m = 8 m = 8
n = 8 n = 8
imm = np.random.randint(1,5) imm = np.random.randint(1,5)
# compute # compute
a = tvm.placeholder( a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
name="a", name="a",
dtype=out_dtype) dtype=env.acc_dtype)
a_buf = tvm.compute( a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i), lambda *i: a(*i),
"a_buf") #DRAM->SRAM "a_buf") #DRAM->SRAM
if use_imm: if use_imm:
res_buf = tvm.compute( res_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), imm), lambda *i: tvm_op(a_buf(*i), imm),
"res_buf") #compute "res_buf") #compute
else: else:
b = tvm.placeholder( b = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
name="b", name="b",
dtype=out_dtype) dtype=env.acc_dtype)
b_buf = tvm.compute( b_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: b(*i), lambda *i: b(*i),
"b_buf") #DRAM->SRAM "b_buf") #DRAM->SRAM
res_buf = tvm.compute( res_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), b_buf(*i)), lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
"res_buf") #compute "res_buf") #compute
res = tvm.compute( res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_buf(*i).astype(inp_dtype), lambda *i: res_buf(*i).astype(env.inp_dtype),
"res") #SRAM->DRAM "res") #SRAM->DRAM
# schedule # schedule
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_buf].set_scope(vta.SCOPE_OUT) # SRAM s[res_buf].set_scope(env.acc_scope) # SRAM
s[res_buf].pragma(res_buf.op.axis[0], vta.ALU) # compute s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if use_imm: if use_imm:
if print_ir: if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True)) print(vta.lower(s, [a, res], simple_mode=True))
else: else:
s[b_buf].set_scope(vta.SCOPE_OUT) # SRAM s[b_buf].set_scope(env.acc_scope) # SRAM
s[b_buf].pragma(b_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
if print_ir: if print_ir:
print(tvm.lower(s, [a, b, res], simple_mode=True)) print(vta.lower(s, [a, b, res], simple_mode=True))
def verify(): def verify():
# build # build
with vta.build_config():
if use_imm: if use_imm:
mod = tvm.build(s, [a, res], "ext_dev", target) mod = vta.build(s, [a, res], "ext_dev", target)
else: else:
mod = tvm.build(s, [a, b, res], "ext_dev", target) mod = vta.build(s, [a, b, res], "ext_dev", target)
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o")) mod.save(temp.relpath("load_act.o"))
...@@ -331,17 +331,17 @@ def test_alu(tvm_op, np_op=None, use_imm=False): ...@@ -331,17 +331,17 @@ def test_alu(tvm_op, np_op=None, use_imm=False):
# verify # verify
ctx = remote.ext_dev(0) ctx = remote.ext_dev(0)
a_np = np.random.randint( a_np = np.random.randint(
-16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype) -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
if use_imm: if use_imm:
res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm) res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
else: else:
b_np = np.random.randint( b_np = np.random.randint(
-16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(b.dtype) -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np) res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
res_np = res_np.astype(res.dtype) res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx) a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array( res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx) np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if use_imm: if use_imm:
f(a_nd, res_nd) f(a_nd, res_nd)
else: else:
...@@ -355,44 +355,46 @@ def test_alu(tvm_op, np_op=None, use_imm=False): ...@@ -355,44 +355,46 @@ def test_alu(tvm_op, np_op=None, use_imm=False):
def test_relu(): def test_relu():
"""Test RELU on ALU""" """Test RELU on ALU"""
env = vta.get_env()
m = 8 m = 8
n = 8 n = 8
# compute # compute
a = tvm.placeholder( a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
name="a", name="a",
dtype=out_dtype) dtype=env.acc_dtype)
a_buf = tvm.compute( a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i), lambda *i: a(*i),
"a_buf") # DRAM->SRAM "a_buf") # DRAM->SRAM
max_buf = tvm.compute( max_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(a_buf(*i), 0), lambda *i: tvm.max(a_buf(*i), 0),
"res_buf") # relu "res_buf") # relu
min_buf = tvm.compute( min_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(max_buf(*i), (1<<(vta.VTA_INP_WIDTH-1))-1), lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
"max_buf") # relu "max_buf") # relu
res = tvm.compute( res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: min_buf(*i).astype(inp_dtype), lambda *i: min_buf(*i).astype(env.inp_dtype),
"min_buf") # SRAM->DRAM "min_buf") # SRAM->DRAM
# schedule # schedule
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[max_buf].set_scope(vta.SCOPE_OUT) # SRAM s[max_buf].set_scope(env.acc_scope) # SRAM
s[min_buf].set_scope(vta.SCOPE_OUT) # SRAM s[min_buf].set_scope(env.acc_scope) # SRAM
s[max_buf].pragma(max_buf.op.axis[0], vta.ALU) # compute s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
s[min_buf].pragma(min_buf.op.axis[0], vta.ALU) # compute s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if print_ir: if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True)) print(vta.lower(s, [a, res], simple_mode=True))
def verify(): def verify():
# build # build
mod = tvm.build(s, [a, res], "ext_dev", target) with vta.build_config(env.DEBUG_DUMP_INSN):
mod = vta.build(s, [a, res], "ext_dev", target)
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o")) mod.save(temp.relpath("load_act.o"))
...@@ -401,11 +403,11 @@ def test_relu(): ...@@ -401,11 +403,11 @@ def test_relu():
# verify # verify
ctx = remote.ext_dev(0) ctx = remote.ext_dev(0)
a_np = np.random.randint( a_np = np.random.randint(
-256, 256, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype) -256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
res_np = np.clip(a_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype) res_np = np.clip(a_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx) a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array( res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx) np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd) f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...") print("\tFinished verification...")
...@@ -415,45 +417,47 @@ def test_relu(): ...@@ -415,45 +417,47 @@ def test_relu():
def test_shift_and_scale(): def test_shift_and_scale():
"""Test shift and scale on ALU""" """Test shift and scale on ALU"""
env = vta.get_env()
m = 8 m = 8
n = 8 n = 8
imm_shift = np.random.randint(-10,10) imm_shift = np.random.randint(-10,10)
imm_scale = np.random.randint(1,5) imm_scale = np.random.randint(1,5)
# compute # compute
a = tvm.placeholder( a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
name="a", dtype=out_dtype) name="a", dtype=env.acc_dtype)
a_buf = tvm.compute( a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i), lambda *i: a(*i),
"a_buf") # DRAM->SRAM "a_buf") # DRAM->SRAM
res_shift = tvm.compute( res_shift = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a_buf(*i)+imm_shift, lambda *i: a_buf(*i)+imm_shift,
"res_shift") # compute "res_shift") # compute
res_scale = tvm.compute( res_scale = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_shift(*i)>>imm_scale, lambda *i: res_shift(*i)>>imm_scale,
"res_scale") # compute "res_scale") # compute
res = tvm.compute( res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT), (m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_scale(*i).astype(inp_dtype), lambda *i: res_scale(*i).astype(env.inp_dtype),
"res") # SRAM->DRAM "res") # SRAM->DRAM
# schedule # schedule
s = tvm.create_schedule(res.op) s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM s[a_buf].set_scope(env.acc_scope) # SRAM
s[res_shift].set_scope(vta.SCOPE_OUT) # SRAM s[res_shift].set_scope(env.acc_scope) # SRAM
s[res_scale].set_scope(vta.SCOPE_OUT) # SRAM s[res_scale].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_shift].pragma(res_shift.op.axis[0], vta.ALU) # compute s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
s[res_scale].pragma(res_scale.op.axis[0], vta.ALU) # compute s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if print_ir: if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True)) print(vta.lower(s, [a, res], simple_mode=True))
def verify(): def verify():
# build # build
mod = tvm.build(s, [a, res], "ext_dev", target) mod = vta.build(s, [a, res], "ext_dev", target)
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o")) mod.save(temp.relpath("load_act.o"))
...@@ -462,12 +466,12 @@ def test_shift_and_scale(): ...@@ -462,12 +466,12 @@ def test_shift_and_scale():
# verify # verify
ctx = remote.ext_dev(0) ctx = remote.ext_dev(0)
a_np = np.random.randint( a_np = np.random.randint(
-10, 10, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype) -10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
res_np = np.right_shift((a_np + imm_shift), imm_scale) res_np = np.right_shift((a_np + imm_shift), imm_scale)
res_np = res_np.astype(res.dtype) res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx) a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array( res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx) np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd) f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...") print("\tFinished verification...")
...@@ -476,10 +480,10 @@ def test_shift_and_scale(): ...@@ -476,10 +480,10 @@ def test_shift_and_scale():
verify() verify()
if __name__ == "__main__": if __name__ == "__main__":
print("Padded load test")
test_padded_load()
print("Load/store test") print("Load/store test")
test_save_load_out() test_save_load_out()
print("Padded load test")
test_padded_load()
# print("GEMM test") # print("GEMM test")
# test_gemm() # test_gemm()
print("Max immediate") print("Max immediate")
......
import vta
def test_env():
env = vta.get_env()
mock = env.mock
assert mock.alu == "skip_alu"
if __name__ == "__main__":
test_env()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment