Commit dea167a8 by Tianqi Chen

[COMPILER] Refactor compiler to enable configuration (#21)

parent f5960ba6
......@@ -39,8 +39,9 @@ vta.program_fpga(remote, BITSTREAM_FILE)
if verbose:
logging.basicConfig(level=logging.INFO)
# Change to -device=tcpu to run cpu only inference.
# Change to -device=vta-cpu to run cpu only inference.
target = "llvm -device=vta"
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
synset = eval(open(os.path.join(CATEG_FILE)).read())
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
......@@ -117,15 +118,16 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
sym = sym.apply("InferType")
with nnvm.compiler.build_config(opt_level=3):
bdict = {}
if "vta" not in target:
bdict = {"add_lower_pass": []}
else:
bdict = {"add_lower_pass": vta.debug_mode(0)}
with tvm.build_config(**bdict):
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params)
params=params, target_host=target_host)
else:
with vta.build_config():
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
......
"""TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs
from .hw_spec import *
from .environment import get_env, Environment
try:
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
from .intrin import GEVM, GEMM
from .build import debug_mode
from . import mock, ir_pass
# allow optional import in config mode.
from . import arm_conv2d, vta_conv2d
except AttributeError:
pass
from .rpc_client import reconfig_runtime, program_fpga
try:
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
from . import graph
except ImportError:
pass
"""Runtime function related hooks"""
"""VTA specific buildin for runtime."""
from __future__ import absolute_import as _abs
import tvm
from tvm import build_module
from . runtime import CB_HANDLE
from . import ir_pass
from .environment import get_env
def lift_coproc_scope(x):
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x
def early_rewrite(stmt):
"""Try to do storage rewrite in early pass."""
try:
return tvm.ir_pass.StorageRewrite(stmt)
except tvm.TVMError:
return stmt
def debug_mode(debug_flag):
"""Pass to enable vta debug mode.
def build_config(debug_flag=0, **kwargs):
"""Build a build config for VTA.
Parameters
----------
debug_flag : int
The dbeug flag to be passed.
kwargs : dict
Additional configurations.
Returns
-------
pass_list: list of function
The pass to set to build_config(add_lower_pass=vta.debug_mode(mode))
build_config: BuildConfig
The build config that can be used in TVM.
Example
--------
.. code-block:: python
# build a vta module.
with vta.build_config():
vta_module = tvm.build(s, ...)
"""
env = get_env()
def add_debug(stmt):
debug = tvm.call_extern(
"int32", "VTASetDebugMode", CB_HANDLE, debug_flag)
"int32", "VTASetDebugMode",
env.dev.command_handle,
debug_flag)
return tvm.make.stmt_seq(debug, stmt)
pass_list = [(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
......@@ -48,8 +64,38 @@ def debug_mode(debug_flag):
pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite))
return pass_list
return tvm.build_config(add_lower_pass=pass_list, **kwargs)
def lower(*args, **kwargs):
"""Thin wrapper of tvm.lower
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
# Add a lower pass to sync uop
build_module.current_build_config().add_lower_pass = debug_mode(0)
See Also
--------
tvm.lower : The original TVM's lower function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.lower(*args, **kwargs)
return tvm.lower(*args, **kwargs)
def build(*args, **kwargs):
"""Thin wrapper of tvm.build
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
See Also
--------
tvm.build : The original TVM's build function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.build(*args, **kwargs)
return tvm.build(*args, **kwargs)
"""Configurable VTA Hareware Environment scope."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import os
import copy
try:
# Allow missing import in config mode.
import tvm
from . import intrin
except ImportError:
pass
class DevContext(object):
"""Internal development context
This contains all the non-user facing compiler
internal context that is hold by the Environment.
Parameters
----------
env : Environment
The environment hosting the DevContext
Note
----
This class is introduced so we have a clear separation
of developer related stuffs and user facing attributes.
"""
# Memory id for DMA
MEM_ID_UOP = 0
MEM_ID_WGT = 1
MEM_ID_INP = 2
MEM_ID_ACC = 3
MEM_ID_OUT = 4
# VTA ALU Opcodes
ALU_OPCODE_MIN = 0
ALU_OPCODE_MAX = 1
ALU_OPCODE_ADD = 2
ALU_OPCODE_SHR = 3
# Task queue id (pipeline stage)
QID_LOAD_INP = 1
QID_LOAD_WGT = 1
QID_LOAD_OUT = 2
QID_STORE_OUT = 3
QID_COMPUTE = 2
QID_STORE_INP = 3
def __init__(self, env):
self.vta_axis = tvm.thread_axis("vta")
self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
self.command_handle = tvm.make.Call(
"handle", "tvm_thread_context", [ctx],
tvm.expr.Call.Intrinsic, None, 0)
self.DEBUG_NO_SYNC = False
env._dev_ctx = self
self.gemm = intrin.gemm(env, env.mock_mode)
self.gevm = intrin.gevm(env, env.mock_mode)
def get_task_qid(self, qid):
"""Get transformed queue index."""
return 1 if self.DEBUG_NO_SYNC else qid
class Environment(object):
"""Hareware configuration object.
This object contains all the information
needed for compiling to a specific VTA backend.
Parameters
----------
cfg : dict of str to value.
The configuration parameters.
Example
--------
.. code-block:: python
# the following code reconfigures the environment
# temporarily to attributes specified in new_cfg.json
new_cfg = json.load(json.load(open("new_cfg.json")))
with vta.Environment(new_cfg):
# env works on the new environment
env = vta.get_env()
"""
current = None
cfg_keys = [
"target",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
"LOG_BATCH",
"LOG_BLOCK_IN",
"LOG_BLOCK_OUT",
"LOG_UOP_BUFF_SIZE",
"LOG_INP_BUFF_SIZE",
"LOG_WGT_BUFF_SIZE",
"LOG_ACC_BUFF_SIZE",
]
# constants
MAX_XFER = 1 << 22
# debug flags
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
# memory scopes
inp_scope = "local.inp_buffer"
wgt_scope = "local.wgt_buffer"
acc_scope = "local.acc_buffer"
# initialization function
def __init__(self, cfg):
# Log of input/activation width in bits
self.__dict__.update(cfg)
for key in self.cfg_keys:
if key not in cfg:
raise ValueError("Expect key %s in cfg" % key)
self.LOG_OUT_WIDTH = self.LOG_INP_WIDTH
self.LOG_OUT_BUFF_SIZE = (
self.LOG_ACC_BUFF_SIZE +
self.LOG_OUT_WIDTH -
self.LOG_ACC_WIDTH)
# width
self.INP_WIDTH = 1 << self.LOG_INP_WIDTH
self.WGT_WIDTH = 1 << self.LOG_WGT_WIDTH
self.ACC_WIDTH = 1 << self.LOG_ACC_WIDTH
self.BATCH = 1 << self.LOG_BATCH
self.BLOCK_IN = 1 << self.LOG_BLOCK_IN
self.BLOCK_OUT = 1 << self.LOG_BLOCK_OUT
self.OUT_WIDTH = self.INP_WIDTH
# buffer size
self.UOP_BUFF_SIZE = 1 << self.LOG_UOP_BUFF_SIZE
self.INP_BUFF_SIZE = 1 << self.LOG_INP_BUFF_SIZE
self.WGT_BUFF_SIZE = 1 << self.LOG_WGT_BUFF_SIZE
self.ACC_BUFF_SIZE = 1 << self.LOG_ACC_BUFF_SIZE
self.OUT_BUFF_SIZE = 1 << self.LOG_OUT_BUFF_SIZE
# bytes per buffer
self.INP_ELEM_BITS = (self.BATCH *
self.BLOCK_IN *
self.INP_WIDTH)
self.WGT_ELEM_BITS = (self.BLOCK_OUT *
self.BLOCK_IN *
self.WGT_WIDTH)
self.ACC_ELEM_BITS = (self.BATCH *
self.BLOCK_IN *
self.ACC_WIDTH)
self.INP_ELEM_BYTES = self.INP_ELEM_BITS // 8
self.WGT_ELEM_BYTES = self.WGT_ELEM_BITS // 8
self.ACC_ELEM_BYTES = self.ACC_ELEM_BITS // 8
# dtypes
self.acc_dtype = "int%d" % self.ACC_WIDTH
self.inp_dtype = "int%d" % self.INP_WIDTH
self.wgt_dtype = "int%d" % self.WGT_WIDTH
# lazy cached members
self.mock_mode = False
self._mock_env = None
self._dev_ctx = None
@property
def dev(self):
"""Developer context"""
if self._dev_ctx is None:
self._dev_ctx = DevContext(self)
return self._dev_ctx
@property
def mock(self):
"""A mock version of the Environment
The ALU, dma_copy and intrinsics will be
mocked to be nop.
"""
if self.mock_mode:
return self
if self._mock_env is None:
self._mock_env = copy.copy(self)
self._mock_env._dev_ctx = None
self._mock_env.mock_mode = True
return self._mock_env
@property
def dma_copy(self):
"""DMA copy pragma"""
return ("dma_copy"
if not self.mock_mode
else "skip_dma_copy")
@property
def alu(self):
"""ALU pragma"""
return ("alu"
if not self.mock_mode
else "skip_alu")
@property
def gemm(self):
"""GEMM intrinsic"""
return self.dev.gemm
@property
def gevm(self):
"""GEMM intrinsic"""
return self.dev.gevm
def get_env():
"""Get the current VTA Environment.
Returns
-------
env : Environment
The current environment.
"""
return Environment.current
# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.INP_ELEM_BITS,
max_simd_bits=spec.INP_ELEM_BITS,
max_num_bits=spec.INP_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.WGT_ELEM_BITS,
max_simd_bits=spec.WGT_ELEM_BITS,
max_num_bits=spec.WGT_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_out_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
max_simd_bits=spec.ACC_ELEM_BITS,
max_num_bits=spec.ACC_BUFF_SIZE * 8,
head_address=None)
# TVM related registration
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op):
_ = op
return tvm.call_extern(
"int32", "VTASynchronize",
get_env().dev.command_handle, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush",
get_env().dev.command_handle,
op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
"int32", "VTADepPop",
get_env().dev.command_handle,
op.args[0], op.args[1])
def _init_env():
"""Iniitalize the default global env"""
python_vta_dir = os.path.dirname(__file__)
filename = os.path.join(python_vta_dir, '../../config.mk')
keys = set()
for k in Environment.cfg_keys:
keys.add("VTA_" + k)
if not os.path.isfile(filename):
raise RuntimeError(
"Error: {} not found.make sure you have config.mk in your vta root"
.format(filename))
cfg = {}
with open(filename) as f:
for line in f:
for k in keys:
if k +" =" in line:
val = line.split("=")[1].strip()
cfg[k[4:]] = int(val)
cfg["target"] = "pynq"
return Environment(cfg)
Environment.current = _init_env()
......@@ -9,11 +9,13 @@ import argparse
import os
import ctypes
import tvm
from tvm.contrib import rpc, util, cc
from tvm.contrib import rpc, cc
@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
def server_start():
"""callback when server starts."""
# pylint: disable=unused-variable
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath(
......
"""VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs
import os
python_vta_dir = os.path.dirname(__file__)
print python_vta_dir
filename = os.path.join(python_vta_dir, '../../config.mk')
VTA_PYNQ_BRAM_W = 32
VTA_PYNQ_BRAM_D = 1024
VTA_PYNQ_NUM_BRAM = 124
keys = ["VTA_LOG_INP_WIDTH", "VTA_LOG_WGT_WIDTH", "VTA_LOG_ACC_WIDTH",
"VTA_LOG_BATCH", "VTA_LOG_BLOCK_IN", "VTA_LOG_BLOCK_OUT",
"VTA_LOG_UOP_BUFF_SIZE", "VTA_LOG_INP_BUFF_SIZE",
"VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE"]
params = {}
if os.path.isfile(filename):
with open(filename) as f:
for line in f:
for k in keys:
if k+" =" in line:
val = line.split("=")[1].strip()
params[k] = int(val)
# print params
else:
print "Error: {} not found. Please make sure you have config.mk in your vta root".format(filename)
exit()
# The Constants
VTA_INP_WIDTH = 1 << params["VTA_LOG_INP_WIDTH"]
VTA_WGT_WIDTH = 1 << params["VTA_LOG_WGT_WIDTH"]
VTA_OUT_WIDTH = 1 << params["VTA_LOG_ACC_WIDTH"]
VTA_TARGET = "VTA_PYNQ_TARGET"
# Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1 << params["VTA_LOG_BATCH"]
VTA_BLOCK_IN = 1 << params["VTA_LOG_BLOCK_IN"]
VTA_BLOCK_OUT = 1 << params["VTA_LOG_BLOCK_OUT"]
# log-2 On-chip uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = params["VTA_LOG_UOP_BUFF_SIZE"]
# log-2 On-chip input buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = params["VTA_LOG_INP_BUFF_SIZE"]
# log-2 On-chip wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = params["VTA_LOG_WGT_BUFF_SIZE"]
# log-2 On-chip output buffer size in Bytes
VTA_LOG_OUT_BUFF_SIZE = params["VTA_LOG_ACC_BUFF_SIZE"]
# Uop buffer size
VTA_UOP_BUFF_SIZE = 1 << VTA_LOG_UOP_BUFF_SIZE
# Input buffer size
VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE
# On-chip wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE
# Output buffer size.
VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE
# Number of bytes per buffer
VTA_INP_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_IN*VTA_INP_WIDTH//8)
VTA_WGT_ELEM_BYTES = (VTA_BLOCK_OUT*VTA_BLOCK_IN*VTA_WGT_WIDTH//8)
VTA_OUT_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_OUT*VTA_OUT_WIDTH//8)
# Maximum external buffer size in bytes
VTA_MAX_XFER = 1 << 22
# Number of elements
VTA_INP_BUFF_DEPTH = VTA_INP_BUFF_SIZE//VTA_INP_ELEM_BYTES
VTA_WGT_BUFF_DEPTH = VTA_WGT_BUFF_SIZE//VTA_WGT_ELEM_BYTES
VTA_OUT_BUFF_DEPTH = VTA_OUT_BUFF_SIZE//VTA_OUT_ELEM_BYTES
# Memory id for DMA
VTA_MEM_ID_UOP = 0
VTA_MEM_ID_WGT = 1
VTA_MEM_ID_INP = 2
VTA_MEM_ID_ACC = 3
VTA_MEM_ID_OUT = 4
# VTA ALU Opcodes
VTA_ALU_OPCODE_MIN = 0
VTA_ALU_OPCODE_MAX = 1
VTA_ALU_OPCODE_ADD = 2
VTA_ALU_OPCODE_SHR = 3
# Task queue id (pipeline stage)
VTA_QID_LOAD_INP = 1
VTA_QID_LOAD_WGT = 1
VTA_QID_LOAD_OUT = 2
VTA_QID_STORE_OUT = 3
VTA_QID_COMPUTE = 2
VTA_QID_STORE_INP = 3
# Debug flags
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
......@@ -2,65 +2,46 @@
from __future__ import absolute_import as _abs
import tvm
from . import hw_spec as spec
from .runtime import VTA_AXIS, VTA_PUSH_UOP, get_task_qid
from .runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT
# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % SCOPE_INP)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_num_bits=spec.VTA_INP_BUFF_SIZE * 8,
head_address=None)
def gevm(env, mock=False):
"""Vector-matrix multiply intrinsic
@tvm.register_func("tvm.info.mem.%s" % SCOPE_WGT)
def mem_info_wgt_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_WGT_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_WGT_ELEM_BYTES * 8,
max_num_bits=spec.VTA_WGT_BUFF_SIZE * 8,
head_address=None)
Parameters
----------
env : Environment
The Environment
@tvm.register_func("tvm.info.mem.%s" % SCOPE_OUT)
def mem_info_out_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_OUT_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_OUT_ELEM_BYTES * 8,
max_num_bits=spec.VTA_OUT_BUFF_SIZE * 8,
head_address=None)
def intrin_gevm(mock=False):
"""Vector-matrix multiply intrinsic"""
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN)
mock : bool
Whether create a mock version.
"""
wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH
inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH,
name=SCOPE_WGT)
dtype="int%d" % env.WGT_WIDTH,
name=env.wgt_scope)
inp = tvm.placeholder((wgt_shape[1], ),
dtype="int%d" % spec.VTA_INP_WIDTH,
name=SCOPE_INP)
dtype="int%d" % env.INP_WIDTH,
name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH
out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((wgt_shape[0],),
lambda i: tvm.sum(inp[k].astype(out_dtype) *
wgt[i, k].astype(out_dtype),
axis=[k]),
name="out")
wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
wgt.shape, wgt.dtype, env.wgt_scope,
scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes)
inp.shape, inp.dtype, env.inp_scope,
scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes)
out.shape, out.dtype, env.acc_scope,
scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs):
"""Vector-matrix multiply intrinsic function"""
......@@ -69,8 +50,11 @@ def intrin_gevm(mock=False):
def instr(index):
"""Generate vector-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE))
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP)
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope",
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
......@@ -101,45 +85,54 @@ def intrin_gevm(mock=False):
out: out_layout})
def intrin_gemm(mock=False):
"""Matrix-matrix multiply intrinsic"""
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN)
def gemm(env, mock=False):
"""Matrix-matrix multiply intrinsic
Parameters
----------
env : Environment
The Environment
mock : bool
Whether create a mock version.
"""
wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH
assert inp_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_IN
inp_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_IN)
inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
assert inp_lanes == env.BATCH * env.BLOCK_IN
inp_shape = (env.BATCH, env.BLOCK_IN)
assert inp_shape[0] * inp_shape[1] == inp_lanes
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH
assert out_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_OUT
out_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_OUT)
out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
assert out_lanes == env.BATCH * env.BLOCK_OUT
out_shape = (env.BATCH, env.BLOCK_OUT)
assert out_shape[0] * out_shape[1] == out_lanes
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH,
name=SCOPE_WGT)
dtype="int%d" % env.WGT_WIDTH,
name=env.wgt_scope)
inp = tvm.placeholder((inp_shape[0], inp_shape[1]),
dtype="int%d" % spec.VTA_INP_WIDTH,
name=SCOPE_INP)
dtype="int%d" % env.INP_WIDTH,
name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH
out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((out_shape[0], out_shape[1]),
lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) *
wgt[j, k].astype(out_dtype),
axis=[k]),
name="out")
wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
wgt.shape, wgt.dtype, env.wgt_scope,
scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes)
inp.shape, inp.dtype, env.inp_scope,
scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes)
out.shape, out.dtype, env.acc_scope,
scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs):
"""Matrix-matrix multiply intrinsic function"""
......@@ -148,8 +141,11 @@ def intrin_gemm(mock=False):
def instr(index):
"""Generate matrix-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE))
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP)
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope",
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
......@@ -178,6 +174,3 @@ def intrin_gemm(mock=False):
binds={inp: inp_layout,
wgt: wgt_layout,
out: out_layout})
GEMM = intrin_gemm()
GEVM = intrin_gevm()
"""Additional IR Pass for VTA"""
# pylint: disable=len-as-condition
from __future__ import absolute_import as _abs
import tvm
from topi import util as util
from . import hw_spec as spec
from . runtime import CB_HANDLE, VTA_AXIS, VTA_PUSH_UOP
from . runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT, DMA_COPY, get_task_qid
from .environment import get_env
def fold_uop_loop(stmt_in):
"""Pass to fold uop loops"""
"""Detect and fold uop loop.
VTA support uop programming model
that recognizes loop structure.
This pass detect the loop structure
and extract that into uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Output statement.
"""
env = get_env()
def _fold_outermost_loop(body):
stmt = body
while not isinstance(stmt, tvm.stmt.For):
if isinstance(stmt, (tvm.stmt.ProducerConsumer, )):
if isinstance(stmt, (tvm.stmt.ProducerConsumer,)):
stmt = stmt.body
else:
return None, body, None
......@@ -30,7 +47,8 @@ def fold_uop_loop(stmt_in):
args = []
args += op.args[:base_args]
for i in range(3):
m = tvm.arith.DetectLinearEquation(op.args[i + base_args], [loop_var])
m = tvm.arith.DetectLinearEquation(
op.args[i + base_args], [loop_var])
if not m:
fail[0] = True
return op
......@@ -68,7 +86,7 @@ def fold_uop_loop(stmt_in):
def _do_fold(stmt):
if (stmt.attr_key == "coproc_uop_scope" and
isinstance(stmt.value, tvm.expr.StringImm) and
stmt.value.value == VTA_PUSH_UOP.value):
stmt.value.value == env.dev.vta_push_uop.value):
body = stmt.body
begins = []
ends = []
......@@ -98,7 +116,12 @@ def fold_uop_loop(stmt_in):
def cpu_access_rewrite(stmt_in):
"""Rewrite the code when there is CPU access happening.
"""Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that do not
correspond to address in CPU.
This pass detect CPU access and rewrite to use pointer
returned VTABufferCPUPtr for CPU access.
Parameters
----------
......@@ -110,6 +133,7 @@ def cpu_access_rewrite(stmt_in):
stmt_out : Stmt
Transformed statement
"""
env = get_env()
rw_info = {}
def _post_order(op):
if isinstance(op, tvm.stmt.Allocate):
......@@ -119,7 +143,8 @@ def cpu_access_rewrite(stmt_in):
new_var = rw_info[buffer_var]
let_stmt = tvm.make.LetStmt(
new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr", CB_HANDLE,
"handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), op.body)
alloc = tvm.make.Allocate(
buffer_var, op.dtype, op.extents,
......@@ -147,7 +172,8 @@ def cpu_access_rewrite(stmt_in):
for buffer_var, new_var in rw_info.items():
stmt = tvm.make.LetStmt(
new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr", CB_HANDLE,
"handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), stmt)
return stmt
......@@ -184,9 +210,6 @@ def lift_alloc_to_scope_begin(stmt_in):
else:
raise RuntimeError("unexpected op")
del slist[:]
# n = len(slist)
# for i in range(n):
# slist.pop()
return body
def _pre_order(op):
......@@ -220,7 +243,7 @@ def lift_alloc_to_scope_begin(stmt_in):
def inject_skip_copy(stmt_in):
"""Pass to inject skip copy stmt, used in debug.
"""Pass to inject skip copy stmt, used for debug purpose.
Parameters
----------
......@@ -233,7 +256,7 @@ def inject_skip_copy(stmt_in):
Transformed statement
"""
def _do_fold(stmt):
if stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy":
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy"):
return tvm.make.Evaluate(0)
return None
return tvm.ir_pass.IRTransform(
......@@ -286,6 +309,7 @@ def inject_dma_intrin(stmt_in):
stmt_out : Stmt
Transformed statement
"""
env = get_env()
def _check_compact(buf):
ndim = len(buf.shape)
size = tvm.const(1, buf.shape[0].dtype)
......@@ -321,8 +345,9 @@ def inject_dma_intrin(stmt_in):
x_stride = buf.strides[ndim - base]
next_base = base
if not util.equal_const_int(x_stride % elem_block, 0):
raise RuntimeError("scope %s need to have block=%d, shape=%s, strides=%s" % (
scope, elem_block, buf.shape, buf.strides))
raise RuntimeError(
"scope %s need to have block=%d, shape=%s, strides=%s" % (
scope, elem_block, buf.shape, buf.strides))
for i in range(base, ndim + 1):
k = ndim - i
if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
......@@ -338,7 +363,6 @@ def inject_dma_intrin(stmt_in):
shape = list(reversed(shape))
return shape, strides
def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
elem_block = elem_bytes * 8 // elem_width
if buf.dtype != dtype:
......@@ -425,47 +449,50 @@ def inject_dma_intrin(stmt_in):
def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored...
_ = pad_value
if dst.scope == "global":
# Store
if pad_before or pad_after:
raise RuntimeError("Do not support copy into DRAM with pad")
if src.scope == SCOPE_OUT:
elem_width = spec.VTA_INP_WIDTH # output compression to inp type
elem_bytes = spec.VTA_INP_ELEM_BYTES # output compression to inp type
mem_type = spec.VTA_MEM_ID_OUT
data_type = "int%d" % spec.VTA_INP_WIDTH
task_qid = spec.VTA_QID_STORE_OUT
if src.scope == env.acc_scope:
elem_width = env.INP_WIDTH # output compression to inp type
elem_bytes = env.INP_ELEM_BYTES # output compression to inp type
mem_type = env.dev.MEM_ID_OUT
data_type = "int%d" % env.INP_WIDTH
task_qid = env.dev.QID_STORE_OUT
else:
raise RuntimeError("Do not support copy %s->dram" % (src.scope))
_check_compact(src)
x_size, y_size, x_stride, offset = _get_2d_pattern(
dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid))
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(task_qid))
irb.emit(tvm.call_extern(
"int32", "VTAStoreBuffer2D", CB_HANDLE,
"int32", "VTAStoreBuffer2D",
env.dev.command_handle,
src.access_ptr("r", "int32"),
mem_type, dst.data, offset, x_size, y_size, x_stride))
return irb.get()
elif src.scope == "global":
if dst.scope == SCOPE_OUT:
elem_width = spec.VTA_OUT_WIDTH
elem_bytes = spec.VTA_OUT_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_ACC
data_type = "int%d" % spec.VTA_OUT_WIDTH
task_qid = spec.VTA_QID_LOAD_OUT
elif dst.scope == SCOPE_INP:
elem_width = spec.VTA_INP_WIDTH
elem_bytes = spec.VTA_INP_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_INP
data_type = "int%d" % spec.VTA_INP_WIDTH
task_qid = spec.VTA_QID_LOAD_INP
elif dst.scope == SCOPE_WGT:
elem_width = spec.VTA_WGT_WIDTH
elem_bytes = spec.VTA_WGT_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_WGT
data_type = "int%d" % spec.VTA_WGT_WIDTH
task_qid = spec.VTA_QID_LOAD_WGT
if dst.scope == env.acc_scope:
elem_width = env.ACC_WIDTH
elem_bytes = env.ACC_ELEM_BYTES
mem_type = env.dev.MEM_ID_ACC
data_type = "int%d" % env.ACC_WIDTH
task_qid = env.dev.QID_LOAD_OUT
elif dst.scope == env.inp_scope:
elem_width = env.INP_WIDTH
elem_bytes = env.INP_ELEM_BYTES
mem_type = env.dev.MEM_ID_INP
data_type = "int%d" % env.INP_WIDTH
task_qid = env.dev.QID_LOAD_INP
elif dst.scope == env.wgt_scope:
elem_width = env.WGT_WIDTH
elem_bytes = env.WGT_ELEM_BYTES
mem_type = env.dev.MEM_ID_WGT
data_type = "int%d" % env.WGT_WIDTH
task_qid = env.dev.QID_LOAD_WGT
else:
raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
# collect pad statistics
......@@ -498,13 +525,16 @@ def inject_dma_intrin(stmt_in):
_check_compact(dst)
x_size, y_size, x_stride, offset = _get_2d_pattern(
src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold)
src, elem_width, elem_bytes, data_type,
dst.scope, allow_fold=allow_fold)
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid))
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(task_qid))
irb.emit(tvm.call_extern(
"int32", "VTALoadBuffer2D", CB_HANDLE,
"int32", "VTALoadBuffer2D",
env.dev.command_handle,
src.data, offset, x_size, y_size, x_stride,
x_pad_before, y_pad_before,
x_pad_after, y_pad_after,
......@@ -514,7 +544,7 @@ def inject_dma_intrin(stmt_in):
else:
raise RuntimeError("Donot support copy %s->%s" % (src.scope, dst.scope))
return tvm.ir_pass.InjectCopyIntrin(stmt_in, DMA_COPY, _inject_copy)
return tvm.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy)
def annotate_alu_coproc_scope(stmt_in):
......@@ -530,11 +560,14 @@ def annotate_alu_coproc_scope(stmt_in):
stmt_out : Stmt
Transformed statement
"""
env = get_env()
def _do_fold(stmt):
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE))
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", tvm.make.StringImm("VTAPushALUOp"))
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(env.dev.QID_COMPUTE))
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt)
return irb.get()
elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"):
......@@ -560,6 +593,7 @@ def inject_alu_intrin(stmt_in):
stmt_out : Stmt
Transformed statement
"""
env = get_env()
def _do_fold(stmt):
def _equal(x, y):
return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
......@@ -618,32 +652,32 @@ def inject_alu_intrin(stmt_in):
tmp_body = tmp_body.body
# Derive opcode
if isinstance(loop_body.value, tvm.expr.Add):
alu_opcode = spec.VTA_ALU_OPCODE_ADD
alu_opcode = env.dev.ALU_OPCODE_ADD
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Sub):
alu_opcode = spec.VTA_ALU_OPCODE_SUB
alu_opcode = env.dev.ALU_OPCODE_SUB
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Mul):
alu_opcode = spec.VTA_ALU_OPCODE_MUL
alu_opcode = env.dev.ALU_OPCODE_MUL
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Min):
alu_opcode = spec.VTA_ALU_OPCODE_MIN
alu_opcode = env.dev.ALU_OPCODE_MIN
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Max):
alu_opcode = spec.VTA_ALU_OPCODE_MAX
alu_opcode = env.dev.ALU_OPCODE_MAX
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Call):
if loop_body.value.name == 'shift_left':
alu_opcode = spec.VTA_ALU_OPCODE_SHR
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = tvm.ir_pass.Simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right':
alu_opcode = spec.VTA_ALU_OPCODE_SHR
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
else:
......@@ -696,26 +730,26 @@ def inject_alu_intrin(stmt_in):
extents = list(extents)
assert len(src_coeff) > 1
assert len(dst_coeff) > 1
assert len(extents) > 0
assert len(extents) != 0
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
src_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0)
src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
dst_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0)
dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(src_coeff[-2], 1)
assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
if spec.VTA_BATCH > 1:
if env.BATCH > 1:
assert len(src_coeff) > 2
assert len(dst_coeff) > 2
assert len(extents) > 1
assert tvm.ir_pass.Equal(src_coeff[-3], spec.VTA_BLOCK_OUT)
assert tvm.ir_pass.Equal(dst_coeff[-3], spec.VTA_BLOCK_OUT)
assert tvm.ir_pass.Equal(src_coeff[-3], env.BLOCK_OUT)
assert tvm.ir_pass.Equal(dst_coeff[-3], env.BLOCK_OUT)
# Apply tensorization of the loop coefficients
src_offset = src_coeff[-1]
dst_offset = dst_coeff[-1]
if spec.VTA_BATCH == 1:
if env.BATCH == 1:
src_coeff = src_coeff[:-2]
dst_coeff = dst_coeff[:-2]
extents = extents[:-1]
......@@ -726,9 +760,9 @@ def inject_alu_intrin(stmt_in):
src_coeff.append(src_offset)
dst_coeff.append(dst_offset)
src_coeff = [
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in src_coeff]
tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
dst_coeff = [
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in dst_coeff]
tvm.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops
if extents:
......@@ -750,13 +784,25 @@ def inject_alu_intrin(stmt_in):
irb.emit(tvm.call_extern(
"int32", "VTAUopLoopEnd"))
return irb.get()
return stmt
stmt_out = tvm.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out
def debug_print(stmt):
print stmt
"""A debug pass that print the stmt
Parameters
----------
stmt : Stmt
The input statement
Returns
-------
stmt : Stmt
The
"""
print(stmt)
return stmt
"""Mock interface for skip part of compute """
from .intrin import intrin_gevm, intrin_gemm
GEMM = intrin_gemm(True)
GEVM = intrin_gevm(True)
DMA_COPY = "skip_dma_copy"
ALU = "skip_alu"
"""VTA RPC client function"""
import os
from . import hw_spec as spec
from .environment import get_env
def reconfig_runtime(remote):
"""Reconfigure remote runtime based on current hardware spec.
......@@ -10,6 +11,7 @@ def reconfig_runtime(remote):
remote : RPCSession
The TVM RPC session
"""
env = get_env()
keys = ["VTA_LOG_WGT_WIDTH",
"VTA_LOG_INP_WIDTH",
"VTA_LOG_ACC_WIDTH",
......@@ -22,9 +24,10 @@ def reconfig_runtime(remote):
"VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"]
cflags = ["-D%s" % spec.VTA_TARGET]
cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(spec, k)))]
cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
freconfig(" ".join(cflags))
......
"""Runtime function related hooks"""
from __future__ import absolute_import as _abs
import tvm
def thread_local_command_buffer():
"""Get thread local command buffer"""
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
return tvm.make.Call(
"handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0)
CB_HANDLE = thread_local_command_buffer()
VTA_AXIS = tvm.thread_axis("vta")
VTA_PUSH_UOP = tvm.make.StringImm("VTAPushGEMMOp")
SCOPE_INP = "local.inp_buffer"
SCOPE_OUT = "local.out_buffer"
SCOPE_WGT = "local.wgt_buffer"
DMA_COPY = "dma_copy"
ALU = "alu"
DEBUG_NO_SYNC = False
def get_task_qid(qid):
"""Get transformed queue index."""
return 1 if DEBUG_NO_SYNC else qid
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op):
return tvm.call_extern(
"int32", "VTASynchronize", CB_HANDLE, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush", CB_HANDLE, op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
"int32", "VTADepPop", CB_HANDLE, op.args[0], op.args[1])
......@@ -7,17 +7,13 @@ import tvm
import topi
from nnvm.top import registry as reg, OpPattern
from . import environment as vta
from . import intrin, runtime as vta
from intrin import GEVM
TARGET_BOARD = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
Workload = namedtuple("Conv2DWorkload",
['height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
def packed_conv2d(data,
kernel,
padding,
......@@ -58,10 +54,9 @@ def packed_conv2d(data,
def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev",
target_host=TARGET_BOARD)
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "tcpu":
return tvm.build(funcs, target=TARGET_BOARD)
return tvm.build(funcs, target="ext_dev", target_host=target_host)
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vta-cpu":
return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target)
......@@ -224,36 +219,39 @@ def schedule_packed_conv2d(outs):
wrkld = _get_workload(data, pad_data, kernel, output)
plan = _WL2PLAN[wrkld]
load_inp = load_wgt = load_out = store_out = "dma_copy"
alu = "alu"
gevm = GEVM
env = vta.get_env()
load_inp = load_wgt = load_out = store_out = env.dma_copy
alu = env.alu
gevm = env.gevm
# schedule1
oshape = topi.util.get_const_tuple(output.shape)
s = tvm.create_schedule(output.op)
# setup pad
if pad_data is not None:
cdata = pad_data
s[pad_data].set_scope(vta.SCOPE_INP)
s[pad_data].set_scope(env.inp_scope)
else:
cdata = s.cache_read(data, vta.SCOPE_INP, [conv2d_stage])
ckernel = s.cache_read(kernel, vta.SCOPE_WGT, [conv2d_stage])
s[conv2d_stage].set_scope(vta.SCOPE_OUT)
cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
s[conv2d_stage].set_scope(env.acc_scope)
# cache read input
cache_read_ewise = []
for consumer, tensor in ewise_inputs:
cache_read_ewise.append(
s.cache_read(tensor, vta.SCOPE_OUT, [consumer]))
s.cache_read(tensor, env.acc_scope, [consumer]))
# set ewise scope
for op in ewise_ops:
s[op].set_scope(vta.SCOPE_OUT)
s[op].set_scope(env.acc_scope)
s[op].pragma(s[op].op.axis[0], alu)
# tile
oc_factor = (plan.oc_factor if plan.oc_factor
else wrkld.out_filter // vta.VTA_BLOCK_OUT)
else wrkld.out_filter // vta.BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3])
......
......@@ -11,9 +11,6 @@ import pandas as pd
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
......@@ -46,12 +43,13 @@ class Conv2DSchedule(object):
Schedule = Conv2DSchedule
def get_insn_count(layer, sched):
env = vta.get_env()
b, h, w, ci, co = sched
b_factor = b
h_factor = layer.height / h
w_factor = layer.width / w
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * vta.VTA_BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * vta.VTA_BLOCK_OUT)))
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * env.BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * env.BLOCK_OUT)))
input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
output_xfers = b_factor * h_factor * w_factor * co_factor
......@@ -62,6 +60,7 @@ def get_insn_count(layer, sched):
return insn_count
def find_schedules(layer, mtOnly=False, bestOnly=False):
env = vta.get_env()
# Helper function to get factors
def find_factors(n):
factors = []
......@@ -70,11 +69,11 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
factors.append(i)
return factors
# Scheduling exploration
batch_factors = find_factors(int(np.ceil(float(layer.batch) / vta.VTA_BATCH)))
batch_factors = find_factors(int(np.ceil(float(layer.batch) / env.BATCH)))
height_factors = find_factors(layer.height / layer.hstride)
width_factors = find_factors(layer.width / layer.wstride)
cin_factors = find_factors(int(np.ceil(float(layer.in_filter) / vta.VTA_BLOCK_IN)))
cout_factors = find_factors(int(np.ceil(float(layer.out_filter) / vta.VTA_BLOCK_OUT)))
cin_factors = find_factors(int(np.ceil(float(layer.in_filter) / env.BLOCK_IN)))
cout_factors = find_factors(int(np.ceil(float(layer.out_filter) / env.BLOCK_OUT)))
ht_factors = [1, 2]
cot_factors = [1, 2]
# Explore schedules
......@@ -90,12 +89,12 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
if ci == 1:
schedules.append([b, h, w, ci, co])
# Filter the schedules that wouldn't work in the available BRAM sizes
input_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH
weight_elem_size_b = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH
output_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH
input_brams_capacity_b = vta.VTA_INP_BUFF_SIZE * 8
weight_brams_capacity_b = vta.VTA_WGT_BUFF_SIZE * 8
output_brams_capacity_b = vta.VTA_OUT_BUFF_SIZE * 8
input_elem_size_b = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
weight_elem_size_b = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
output_elem_size_b = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
input_brams_capacity_b = env.INP_BUFF_SIZE * 8
weight_brams_capacity_b = env.WGT_BUFF_SIZE * 8
output_brams_capacity_b = env.OUT_BUFF_SIZE * 8
fil_sched = []
xfer_size = []
for sched in schedules:
......@@ -125,7 +124,7 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
if input_tile_elems*input_elem_size_b <= input_brams_capacity_b/(cot*ht) and \
weight_tile_elems*weight_elem_size_b <= weight_brams_capacity_b and \
output_tile_elems*output_elem_size_b <= output_brams_capacity_b/(cot*ht) and \
insn_count <= vta.VTA_MAX_XFER / 16 and \
insn_count <= env.MAX_XFER / 16 and \
h > 2 and w > 2:
schedule = Schedule(oc_factor=co, ko_factor=ci, h_factor=h,
w_factor=w, oc_nthread=cot, h_nthread=ht)
......@@ -137,10 +136,11 @@ def find_schedules(layer, mtOnly=False, bestOnly=False):
return fil_sched
def get_data_movementB(sched, layer):
env = vta.get_env()
# Derive data movement
input_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH
weight_elem_size_b = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH
output_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH
input_elem_size_b = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
weight_elem_size_b = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
output_elem_size_b = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
b = sched.b_factor
h = sched.h_factor
w = sched.w_factor
......@@ -154,11 +154,11 @@ def get_data_movementB(sched, layer):
weight_tile_elems = layer.hkernel * layer.wkernel * ci
output_tile_elems = b * h * w * co
# Derive factors
b_factor = int(np.ceil(float(layer.batch) / (b * vta.VTA_BATCH)))
b_factor = int(np.ceil(float(layer.batch) / (b * env.BATCH)))
h_factor = (layer.height / layer.hstride) / h
w_factor = (layer.width / layer.wstride) / w
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * vta.VTA_BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * vta.VTA_BLOCK_OUT)))
ci_factor = int(np.ceil(float(layer.in_filter) / (ci * env.BLOCK_IN)))
co_factor = int(np.ceil(float(layer.out_filter) / (co * env.BLOCK_OUT)))
# Derive transfers
input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
......@@ -171,19 +171,20 @@ def get_data_movementB(sched, layer):
return total_xfer_B
def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True):
assert batch_size % vta.VTA_BATCH == 0
assert wl.in_filter % vta.VTA_BLOCK_IN == 0
assert wl.out_filter % vta.VTA_BLOCK_OUT == 0
data_shape = (batch_size//vta.VTA_BATCH, wl.in_filter//vta.VTA_BLOCK_IN,
wl.height, wl.width, vta.VTA_BATCH, vta.VTA_BLOCK_IN)
kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN,
wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)
env = vta.get_env()
assert batch_size % env.BATCH == 0
assert wl.in_filter % env.BLOCK_IN == 0
assert wl.out_filter % env.BLOCK_OUT == 0
data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN,
wl.height, wl.width, env.BATCH, env.BLOCK_IN)
kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
res_shape = (batch_size//vta.VTA_BATCH, wl.out_filter//vta.VTA_BLOCK_OUT,
fout_height, fout_width, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)
data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype)
res_shape = (batch_size//env.BATCH, wl.out_filter//env.BLOCK_OUT,
fout_height, fout_width, env.BATCH, env.BLOCK_OUT)
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
if wl.hpad or wl.wpad:
data_buf = topi.nn.pad(data, [0, 0, wl.hpad, wl.wpad, 0, 0], name="data_buf")
else:
......@@ -191,22 +192,22 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
kernel_buf = tvm.compute(kernel_shape, lambda *i: kernel(*i), "kernel_buf")
di = tvm.reduce_axis((0, wl.hkernel), name='di')
dj = tvm.reduce_axis((0, wl.wkernel), name='dj')
ko = tvm.reduce_axis((0, wl.in_filter//vta.VTA_BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki')
ko = tvm.reduce_axis((0, wl.in_filter//env.BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
res_cnv = tvm.compute(
res_shape,
lambda bo, co, i, j, bi, ci: tvm.sum(
data_buf[bo, ko, i*wl.hstride+di, j*wl.wstride+dj, bi, ki].astype(out_dtype) *
kernel_buf[co, ko, di, dj, ci, ki].astype(out_dtype),
data_buf[bo, ko, i*wl.hstride+di, j*wl.wstride+dj, bi, ki].astype(env.acc_dtype) *
kernel_buf[co, ko, di, dj, ci, ki].astype(env.acc_dtype),
axis=[ko, di, dj, ki]),
name="res_cnv")
res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf")
res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(inp_dtype), name="res")
res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(env.inp_dtype), name="res")
num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
total_xfer_B = get_data_movementB(sched, wl)
def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, res], "ext_dev", target, name="conv2d")
mod = vta.build(s, [data, kernel, res], "ext_dev", target, name="conv2d")
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("conv2d.o"))
......@@ -220,12 +221,12 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
kernel_orig = np.random.randint(
-128, 128, size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype)
data_packed = data_orig.reshape(
batch_size//vta.VTA_BATCH, vta.VTA_BATCH,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
batch_size//env.BATCH, env.BATCH,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
kernel_packed = kernel_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
......@@ -237,13 +238,13 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
(0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness:
res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)),
mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter,
no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype)
pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
res_ref = np.right_shift(res_ref, 8).astype(res.dtype)
np.testing.assert_allclose(res_unpack, res_ref)
print("Correctness check pass...")
......@@ -253,13 +254,13 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
print_ir, check_correctness):
# schedule1
s = tvm.create_schedule(res.op)
s[data_buf].set_scope(vta.SCOPE_INP)
s[kernel_buf].set_scope(vta.SCOPE_WGT)
s[res_cnv].set_scope(vta.SCOPE_OUT)
s[res_shf].set_scope(vta.SCOPE_OUT)
s[data_buf].set_scope(env.inp_scope)
s[kernel_buf].set_scope(env.wgt_scope)
s[res_cnv].set_scope(env.acc_scope)
s[res_shf].set_scope(env.acc_scope)
# tile
oc_factor = (sched.oc_factor if sched.oc_factor
else wl.out_filter // vta.VTA_BLOCK_OUT)
else wl.out_filter // env.BLOCK_OUT)
h_factor = (sched.h_factor if sched.h_factor else fout_height)
w_factor = (sched.w_factor if sched.w_factor else fout_width)
xbo, xco, xi, xj, xbi, xci = s[res].op.axis
......@@ -304,8 +305,8 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
def run_test(header, print_ir, check_correctness):
s = [1, sched.oc_factor, sched.ko_factor, sched.h_factor, sched.w_factor]
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, vta.ALU, vta.DMA_COPY,
env.dma_copy, env.dma_copy,
env.gemm, env.alu, env.dma_copy,
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
......@@ -316,15 +317,15 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
log_frame["total-gops"].append(gops)
log_frame["total-cost"].append(cost.mean)
log_frame["total-insn"].append(get_insn_count(wl, s))
log_frame["block-batch"].append(vta.VTA_BATCH)
log_frame["block-in"].append(vta.VTA_BLOCK_IN)
log_frame["block-out"].append(vta.VTA_BLOCK_OUT)
log_frame["inp-width"].append(vta.VTA_INP_WIDTH)
log_frame["wgt-width"].append(vta.VTA_WGT_WIDTH)
log_frame["uop-size"].append(vta.VTA_UOP_BUFF_SIZE)
log_frame["inp-size"].append(vta.VTA_INP_BUFF_SIZE)
log_frame["wgt-size"].append(vta.VTA_WGT_BUFF_SIZE)
log_frame["out-size"].append(vta.VTA_OUT_BUFF_SIZE)
log_frame["block-batch"].append(env.BATCH)
log_frame["block-in"].append(env.BLOCK_IN)
log_frame["block-out"].append(env.BLOCK_OUT)
log_frame["inp-width"].append(env.INP_WIDTH)
log_frame["wgt-width"].append(env.WGT_WIDTH)
log_frame["uop-size"].append(env.UOP_BUFF_SIZE)
log_frame["inp-size"].append(env.INP_BUFF_SIZE)
log_frame["wgt-size"].append(env.WGT_BUFF_SIZE)
log_frame["out-size"].append(env.OUT_BUFF_SIZE)
log_frame["oc-factor"].append(sched.oc_factor)
log_frame["ic-factor"].append(sched.ko_factor)
log_frame["h-factor"].append(sched.h_factor)
......@@ -333,16 +334,16 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
log_frame["h-threads"].append(sched.h_nthread)
log_frame["threaded"].append(True if sched.oc_nthread > 1 or sched.h_nthread > 1 else False)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir, True)
def skip_alu_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- Skip ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, mock.ALU, vta.DMA_COPY,
env.dma_copy, env.dma_copy,
env.gemm, mock.alu, env.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
......@@ -350,118 +351,118 @@ def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True)
log_frame["skip-alu-gops"].append(gops)
log_frame["skip-alu-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def gemm_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- GEMM Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY,
vta.GEMM, mock.ALU, mock.DMA_COPY,
mock.dma_copy, mock.dma_copy,
env.gemm, mock.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["gemm-gops"].append(gops)
log_frame["gemm-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def alu_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY,
mock.GEMM, vta.ALU, mock.DMA_COPY,
mock.dma_copy, mock.dma_copy,
env.gemm, env.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["alu-gops"].append(gops)
log_frame["alu-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_inp_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- LoadInp Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, mock.DMA_COPY,
mock.GEMM, mock.ALU, mock.DMA_COPY,
env.dma_copy, mock.dma_copy,
env.gemm, mock.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * wl.in_filter * wl.height *
wl.width * vta.INP_WIDTH / cost.mean) / float(10 ** 9)
wl.width * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith))
log_frame["ld-inp-gbits"].append(bandwith)
log_frame["ld-inp-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_wgt_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, vta.DMA_COPY,
mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir,
mock.dma_copy, env.dma_copy,
env.gemm, mock.alu, mock.dma_copy, print_ir,
False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (wl.out_filter * wl.in_filter * wl.hkernel *
wl.wkernel * vta.WGT_WIDTH / cost.mean) / float(10 ** 9)
wl.wkernel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith))
log_frame["ld-wgt-gbits"].append(bandwith)
log_frame["ld-wgt-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def store_out_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- StoreOut Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY,
mock.GEMM, mock.ALU, vta.DMA_COPY, print_ir,
mock.dma_copy, mock.dma_copy,
env.gemm, mock.alu, env.dma_copy, print_ir,
False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * wl.out_filter * fout_height *
fout_width * vta.OUT_WIDTH / cost.mean) / float(10 ** 9)
fout_width * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwith=%g gbits" % (
cost.mean, gops, bandwith))
log_frame["st-out-gbits"].append(bandwith)
log_frame["st-out-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def manual_unittest(print_ir):
# Manual section used to teak the components
mock = vta.mock
mock = env.mock
print("----- Manual Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, vta.ALU, mock.DMA_COPY, print_ir,
env.dma_copy, env.dma_copy,
env.gemm, env.alu, mock.dma_copy, print_ir,
False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (
cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
......
......@@ -8,34 +8,33 @@ from tvm.contrib import rpc, util
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
def test_gemm_packed(batch_size, channel, block):
data_shape = (batch_size//vta.VTA_BATCH,
channel//vta.VTA_BLOCK_IN,
vta.VTA_BATCH,
vta.VTA_BLOCK_IN)
weight_shape = (channel//vta.VTA_BLOCK_OUT,
channel//vta.VTA_BLOCK_IN,
vta.VTA_BLOCK_OUT,
vta.VTA_BLOCK_IN)
res_shape = (batch_size//vta.VTA_BATCH,
channel//vta.VTA_BLOCK_OUT,
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT)
env = vta.get_env()
data_shape = (batch_size // env.BATCH,
channel // env.BLOCK_IN,
env.BATCH,
env.BLOCK_IN)
weight_shape = (channel // env.BLOCK_OUT,
channel // env.BLOCK_IN,
env.BLOCK_OUT,
env.BLOCK_IN)
res_shape = (batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
env.BLOCK_OUT)
num_ops = channel * channel * batch_size
ko = tvm.reduce_axis((0, channel//vta.VTA_BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki')
ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
data = tvm.placeholder(data_shape,
name="data",
dtype=inp_dtype)
dtype=env.inp_dtype)
weight = tvm.placeholder(weight_shape,
name="weight",
dtype=wgt_dtype)
dtype=env.wgt_dtype)
data_buf = tvm.compute(data_shape,
lambda *i: data(*i),
"data_buf")
......@@ -44,8 +43,8 @@ def test_gemm_packed(batch_size, channel, block):
"weight_buf")
res_gem = tvm.compute(res_shape,
lambda bo, co, bi, ci: tvm.sum(
data_buf[bo, ko, bi, ki].astype(out_dtype) *
weight_buf[co, ko, ci, ki].astype(out_dtype),
data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
axis=[ko, ki]),
name="res_gem")
res_shf = tvm.compute(res_shape,
......@@ -55,14 +54,14 @@ def test_gemm_packed(batch_size, channel, block):
lambda *i: tvm.max(res_shf(*i), 0),
"res_max") #relu
res_min = tvm.compute(res_shape,
lambda *i: tvm.min(res_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1),
lambda *i: tvm.min(res_max(*i), (1<<(env.INP_WIDTH-1))-1),
"res_min") #relu
res = tvm.compute(res_shape,
lambda *i: res_min(*i).astype(inp_dtype),
lambda *i: res_min(*i).astype(env.inp_dtype),
name="res")
def verify(s, check_correctness=True):
mod = tvm.build(s, [data, weight, res], "ext_dev", target, name="gemm")
mod = vta.build(s, [data, weight, res], "ext_dev", target, name="gemm")
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o"))
......@@ -76,29 +75,29 @@ def test_gemm_packed(batch_size, channel, block):
weight_orig = np.random.randint(
-128, 128, size=(channel, channel)).astype(weight.dtype)
data_packed = data_orig.reshape(
batch_size//vta.VTA_BATCH, vta.VTA_BATCH,
channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3))
batch_size // env.BATCH, env.BATCH,
channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight_orig.reshape(
channel//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT,
channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3))
channel // env.BLOCK_OUT, env.BLOCK_OUT,
channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
weight_arr = tvm.nd.array(weight_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
res_ref = np.zeros(res_shape).astype(out_dtype)
for b in range(batch_size//vta.VTA_BATCH):
for i in range(channel//vta.VTA_BLOCK_OUT):
for j in range(channel//vta.VTA_BLOCK_IN):
res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(out_dtype),
weight_packed[i,j].T.astype(out_dtype))
res_ref = np.zeros(res_shape).astype(env.acc_dtype)
for b in range(batch_size // env.BATCH):
for i in range(channel // env.BLOCK_OUT):
for j in range(channel // env.BLOCK_IN):
res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(env.acc_dtype),
weight_packed[i,j].T.astype(env.acc_dtype))
res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype)
res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20)
cost = time_f(data_arr, weight_arr, res_arr)
res_unpack = res_arr.asnumpy().reshape(batch_size//vta.VTA_BATCH,
channel//vta.VTA_BLOCK_OUT,
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT)
res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
env.BLOCK_OUT)
if check_correctness:
np.testing.assert_allclose(res_unpack, res_ref)
return cost
......@@ -111,17 +110,17 @@ def test_gemm_packed(batch_size, channel, block):
print_ir,
check_correctness):
s = tvm.create_schedule(res.op)
s[data_buf].set_scope(vta.SCOPE_INP)
s[weight_buf].set_scope(vta.SCOPE_WGT)
s[res_gem].set_scope(vta.SCOPE_OUT)
s[res_shf].set_scope(vta.SCOPE_OUT)
s[res_min].set_scope(vta.SCOPE_OUT)
s[res_max].set_scope(vta.SCOPE_OUT)
s[data_buf].set_scope(env.inp_scope)
s[weight_buf].set_scope(env.wgt_scope)
s[res_gem].set_scope(env.acc_scope)
s[res_shf].set_scope(env.acc_scope)
s[res_min].set_scope(env.acc_scope)
s[res_max].set_scope(env.acc_scope)
if block:
bblock = block // vta.VTA_BATCH
iblock = block // vta.VTA_BLOCK_IN
oblock = block // vta.VTA_BLOCK_OUT
bblock = block // env.BATCH
iblock = block // env.BLOCK_IN
oblock = block // env.BLOCK_OUT
xbo, xco, xbi, xci = s[res].op.axis
xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)
store_pt = xb2
......@@ -162,91 +161,91 @@ def test_gemm_packed(batch_size, channel, block):
return verify(s, check_correctness)
def gemm_normal(print_ir):
mock = vta.mock
mock = env.mock
print("----- GEMM GFLOPS End-to-End Test-------")
def run_test(header, print_ir, check_correctness):
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY, vta.GEMM, vta.ALU, vta.DMA_COPY,
env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
with vta.build_config():
run_test("NORMAL", print_ir, True)
print("")
def gevm_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- GEMM Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, vta.GEMM, mock.ALU, mock.DMA_COPY,
mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def alu_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, vta.ALU, mock.DMA_COPY,
mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_inp_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- LoadInp Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False)
env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * vta.VTA_INP_WIDTH / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_wgt_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, vta.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False)
mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (channel * channel * vta.VTA_WGT_WIDTH / cost.mean) / float(10 ** 9)
bandwith = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def store_out_unittest(print_ir):
mock = vta.mock
mock = env.mock
print("----- StoreOut Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, vta.DMA_COPY,
mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * vta.VTA_OUT_WIDTH / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
......
......@@ -21,26 +21,26 @@ def my_clip(x, a_min, a_max):
host = "pynq"
port = 9091
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
print_ir = False
def test_vta_conv2d(key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN,
wl.height, wl.width, vta.VTA_BLOCK_IN)
kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN,
wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)
bias_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT)
env = vta.get_env()
data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
wl.height, wl.width, env.BLOCK_IN)
kernel_shape = (wl.out_filter // env.BLOCK_OUT,
wl.in_filter // env.BLOCK_IN,
wl.hkernel, wl.wkernel,
env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=out_dtype)
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
res_conv = vta_conv2d.packed_conv2d(
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
......@@ -73,14 +73,14 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
bias_orig = np.abs(bias_orig)
data_packed = data_orig.reshape(
batch_size, wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 1, 3, 4, 2))
kernel_packed = kernel_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_packed = bias_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT)
wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
res_shape = topi.util.get_const_tuple(res.shape)
res_np = np.zeros(res_shape).astype(res.dtype)
......@@ -94,13 +94,13 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
(0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness:
res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)),
mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter,
no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype)
pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
res_ref = res_ref >> 8
res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
res_ref = np.clip(res_ref, 0, 127).astype("int8")
......@@ -110,10 +110,10 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------")
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
with vta.build_config():
s = vta_conv2d.schedule_packed_conv2d([res])
if print_ir:
print(tvm.lower(s, [data, kernel, bias, res], simple_mode=True))
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
......
......@@ -3,14 +3,15 @@ import vta
import os
from tvm.contrib import rpc, util
env = vta.get_env()
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
bit = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns.bit".format(
vta.VTA_BATCH, vta.VTA_BLOCK_IN, vta.VTA_BLOCK_OUT,
vta.VTA_INP_WIDTH, vta.VTA_WGT_WIDTH,
vta.VTA_LOG_UOP_BUFF_SIZE, vta.VTA_LOG_INP_BUFF_SIZE,
vta.VTA_LOG_WGT_BUFF_SIZE, vta.VTA_LOG_OUT_BUFF_SIZE)
env.BATCH, env.BLOCK_IN, env.BLOCK_OUT,
env.INP_WIDTH, env.WGT_WIDTH,
env.LOG_UOP_BUFF_SIZE, env.LOG_INP_BUFF_SIZE,
env.LOG_WGT_BUFF_SIZE, env.LOG_ACC_BUFF_SIZE)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
bitstream = os.path.join(curr_path, "../../../../vta_bitstreams/bitstreams/", bit)
......
"""Unit test TPU's instructions """
"""Unit test VTA's instructions """
import tvm
import vta
import mxnet as mx
......@@ -9,44 +9,42 @@ from tvm.contrib import rpc, util
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
do_verify = True
print_ir = False
def test_save_load_out():
env = vta.get_env()
"""Test save/store output command"""
n = 4
x = tvm.placeholder(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(n, n, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=out_dtype)
dtype=env.acc_dtype)
x_buf = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x(*i),
"x_buf")
# insert no-op that won't be optimized away
y_buf = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x_buf(*i)>>0,
"y_buf")
y = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: y_buf(*i).astype(inp_dtype),
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: y_buf(*i).astype(env.inp_dtype),
"y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_OUT)
s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY)
s[y_buf].set_scope(vta.SCOPE_OUT)
s[y_buf].pragma(y_buf.op.axis[0], vta.ALU)
s[y].pragma(y.op.axis[0], vta.DMA_COPY)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
def verify():
# build
m = tvm.build(s, [x, y], "ext_dev", target)
with vta.build_config(env.DEBUG_DUMP_INSN):
m = vta.build(s, [x, y], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
m.save(temp.relpath("load_act.o"))
......@@ -55,7 +53,7 @@ def test_save_load_out():
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(
1, 10, size=(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype)
1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
y_np = x_np.astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
......@@ -65,38 +63,41 @@ def test_save_load_out():
if do_verify:
verify()
def test_padded_load():
"""Test padded load."""
env = vta.get_env()
# declare
n = 21
m = 20
pad_before = [0, 1, 0, 0]
pad_after = [1, 3, 0, 0]
x = tvm.placeholder(
(n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(n, m, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=out_dtype)
dtype=env.acc_dtype)
x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
# insert no-op that won't be optimized away
y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
env.BATCH,
env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT), lambda *i: y_buf(*i).astype(inp_dtype), "y")
env.BATCH,
env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_OUT)
s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY)
s[y_buf].set_scope(vta.SCOPE_OUT)
s[y_buf].pragma(y_buf.op.axis[0], vta.ALU)
s[y].pragma(y.op.axis[0], vta.DMA_COPY)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
def verify():
# build
mod = tvm.build(s, [x, y], "ext_dev", target)
with vta.build_config(env.DEBUG_DUMP_INSN):
mod = vta.build(s, [x, y], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("padded_load.o"))
......@@ -105,11 +106,11 @@ def test_padded_load():
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(1, 2, size=(
n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype)
n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
y_np = np.zeros((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT)).astype(y.dtype)
env.BATCH,
env.BLOCK_OUT)).astype(y.dtype)
y_np[pad_before[0]:pad_before[0] + n,
pad_before[1]:pad_before[1] + m,
:] = x_np
......@@ -118,51 +119,50 @@ def test_padded_load():
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
if print_ir:
print(tvm.lower(s, [y, x], simple_mode=True))
if do_verify:
with tvm.build_config(add_lower_pass=vta.debug_mode(
vta.DEBUG_DUMP_INSN)):
verify()
print(vta.lower(s, [y, x], simple_mode=True))
def test_gemm():
"""Test GEMM."""
env = vta.get_env()
# declare
o = 4
n = 4
m = 4
x = tvm.placeholder((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), name="x", dtype=inp_dtype)
w = tvm.placeholder((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), name="w", dtype=wgt_dtype)
x_buf = tvm.compute((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), lambda *i: x(*i), "x_buf")
w_buf = tvm.compute((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), lambda *i: w(*i), "w_buf")
x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
ko = tvm.reduce_axis((0, n), name="ko")
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name="ki")
ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
y_gem = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(o, m, env.BATCH, env.BLOCK_OUT),
lambda bo, co, bi, ci:
tvm.sum(x_buf[bo, ko, bi, ki].astype(out_dtype) *
w_buf[co, ko, ci, ki].astype(out_dtype),
tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
w_buf[co, ko, ci, ki].astype(env.acc_dtype),
axis=[ko, ki]),
name="y_gem")
y_shf = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_gem(*i)>>8,
name="y_shf")
y_max = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(y_shf(*i), 0),
"y_max") #relu
y_min = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm.min(y_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1),
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
"y_min") #relu
y = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: y_min(*i).astype(inp_dtype),
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_min(*i).astype(env.inp_dtype),
name="y")
def verify(s):
mod = tvm.build(s, [x, w, y], "ext_dev", target)
mod = vta.build(s, [x, w, y], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o"))
......@@ -171,21 +171,21 @@ def test_gemm():
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(
-128, 128, size=(o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN)).astype(x.dtype)
-128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
w_np = np.random.randint(
-128, 128, size=(m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)).astype(w.dtype)
y_np = np.zeros((o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(y.dtype)
-128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
w_nd = tvm.nd.array(w_np, ctx)
y_nd = tvm.nd.array(y_np, ctx)
y_np = y_np.astype(out_dtype)
y_np = y_np.astype(env.acc_dtype)
for b in range(o):
for i in range(m):
for j in range(n):
y_np[b,i,:] += np.dot(x_np[b,j,:].astype(out_dtype),
w_np[i,j].T.astype(out_dtype))
y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
w_np[i,j].T.astype(env.acc_dtype))
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(y.dtype)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
......@@ -194,19 +194,19 @@ def test_gemm():
# default schedule with no smt
s = tvm.create_schedule(y.op)
# set the scope of the SRAM buffers
s[x_buf].set_scope(vta.SCOPE_INP)
s[w_buf].set_scope(vta.SCOPE_WGT)
s[y_gem].set_scope(vta.SCOPE_OUT)
s[y_shf].set_scope(vta.SCOPE_OUT)
s[y_max].set_scope(vta.SCOPE_OUT)
s[y_min].set_scope(vta.SCOPE_OUT)
s[x_buf].set_scope(env.SCOPE_INP)
s[w_buf].set_scope(env.SCOPE_WGT)
s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(env.acc_scope)
# set pragmas for DMA transfer and ALU ops
s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY)
s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY)
s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU)
s[y_max].pragma(s[y_max].op.axis[0], vta.ALU)
s[y_min].pragma(s[y_min].op.axis[0], vta.ALU)
s[y].pragma(s[y].op.axis[0], vta.DMA_COPY)
s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[y].pragma(s[y].op.axis[0], env.dma_copy)
# tensorization
s[y_gem].reorder(
ko,
......@@ -215,23 +215,22 @@ def test_gemm():
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
if print_ir:
print(tvm.lower(s, [x, w, y], simple_mode=True))
print(vta.lower(s, [x, w, y], simple_mode=True))
if do_verify:
with tvm.build_config(
add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
with vta.build_config(env.DEBUG_DUMP_INSN):
verify(s)
def test_smt():
# test smt schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_INP)
s[w_buf].set_scope(vta.SCOPE_WGT)
s[y_gem].set_scope(vta.SCOPE_OUT)
s[y_shf].set_scope(vta.SCOPE_OUT)
s[y_max].set_scope(vta.SCOPE_OUT)
s[y_min].set_scope(vta.SCOPE_OUT)
s[x_buf].set_scope(env.SCOPE_INP)
s[w_buf].set_scope(env.SCOPE_WGT)
s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(env.acc_scope)
abo, aco, abi, aci = s[y].op.axis
abo1, abo2 = s[y].split(abo, nparts=2)
s[y].bind(abo1, tvm.thread_axis("cthread"))
......@@ -246,20 +245,19 @@ def test_gemm():
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM)
s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU)
s[y_max].pragma(s[y_max].op.axis[0], vta.ALU)
s[y_min].pragma(s[y_min].op.axis[0], vta.ALU)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[x_buf].compute_at(s[y_gem], ko)
s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY)
s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY)
s[y].pragma(abo2, vta.DMA_COPY)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y].pragma(abo2, env.dma_copy)
if print_ir:
print(tvm.lower(s, [x, y, w], simple_mode=True))
print(vta.lower(s, [x, y, w], simple_mode=True))
if do_verify:
with tvm.build_config(
add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
with vta.build_config(env.DEBUG_DUMP_INSN):
verify(s)
test_schedule1()
......@@ -267,62 +265,64 @@ def test_gemm():
def test_alu(tvm_op, np_op=None, use_imm=False):
"""Test ALU"""
env = vta.get_env()
m = 8
n = 8
imm = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
name="a",
dtype=out_dtype)
dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") #DRAM->SRAM
if use_imm:
res_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), imm),
"res_buf") #compute
else:
b = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
name="b",
dtype=out_dtype)
dtype=env.acc_dtype)
b_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: b(*i),
"b_buf") #DRAM->SRAM
res_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
"res_buf") #compute
res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: res_buf(*i).astype(inp_dtype),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_buf(*i).astype(env.inp_dtype),
"res") #SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
s[res_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[res_buf].pragma(res_buf.op.axis[0], vta.ALU) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM
s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_buf].set_scope(env.acc_scope) # SRAM
s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if use_imm:
if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True))
print(vta.lower(s, [a, res], simple_mode=True))
else:
s[b_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[b_buf].pragma(b_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
s[b_buf].set_scope(env.acc_scope) # SRAM
s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
if print_ir:
print(tvm.lower(s, [a, b, res], simple_mode=True))
print(vta.lower(s, [a, b, res], simple_mode=True))
def verify():
# build
if use_imm:
mod = tvm.build(s, [a, res], "ext_dev", target)
else:
mod = tvm.build(s, [a, b, res], "ext_dev", target)
with vta.build_config():
if use_imm:
mod = vta.build(s, [a, res], "ext_dev", target)
else:
mod = vta.build(s, [a, b, res], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
......@@ -331,17 +331,17 @@ def test_alu(tvm_op, np_op=None, use_imm=False):
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype)
-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
if use_imm:
res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
else:
b_np = np.random.randint(
-16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(b.dtype)
-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx)
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if use_imm:
f(a_nd, res_nd)
else:
......@@ -355,44 +355,46 @@ def test_alu(tvm_op, np_op=None, use_imm=False):
def test_relu():
"""Test RELU on ALU"""
env = vta.get_env()
m = 8
n = 8
# compute
a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
name="a",
dtype=out_dtype)
dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
max_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(a_buf(*i), 0),
"res_buf") # relu
min_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm.min(max_buf(*i), (1<<(vta.VTA_INP_WIDTH-1))-1),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
"max_buf") # relu
res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: min_buf(*i).astype(inp_dtype),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: min_buf(*i).astype(env.inp_dtype),
"min_buf") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
s[max_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[min_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[max_buf].pragma(max_buf.op.axis[0], vta.ALU) # compute
s[min_buf].pragma(min_buf.op.axis[0], vta.ALU) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM
s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[max_buf].set_scope(env.acc_scope) # SRAM
s[min_buf].set_scope(env.acc_scope) # SRAM
s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True))
print(vta.lower(s, [a, res], simple_mode=True))
def verify():
# build
mod = tvm.build(s, [a, res], "ext_dev", target)
with vta.build_config(env.DEBUG_DUMP_INSN):
mod = vta.build(s, [a, res], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
......@@ -401,11 +403,11 @@ def test_relu():
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-256, 256, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype)
res_np = np.clip(a_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype)
-256, 256, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
res_np = np.clip(a_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx)
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
......@@ -415,45 +417,47 @@ def test_relu():
def test_shift_and_scale():
"""Test shift and scale on ALU"""
env = vta.get_env()
m = 8
n = 8
imm_shift = np.random.randint(-10,10)
imm_scale = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
name="a", dtype=out_dtype)
(m, n, env.BATCH, env.BLOCK_OUT),
name="a", dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
res_shift = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a_buf(*i)+imm_shift,
"res_shift") # compute
res_scale = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_shift(*i)>>imm_scale,
"res_scale") # compute
res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: res_scale(*i).astype(inp_dtype),
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_scale(*i).astype(env.inp_dtype),
"res") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[res_shift].set_scope(vta.SCOPE_OUT) # SRAM
s[res_scale].set_scope(vta.SCOPE_OUT) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
s[res_shift].pragma(res_shift.op.axis[0], vta.ALU) # compute
s[res_scale].pragma(res_scale.op.axis[0], vta.ALU) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM
s[a_buf].set_scope(env.acc_scope) # SRAM
s[res_shift].set_scope(env.acc_scope) # SRAM
s[res_scale].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True))
print(vta.lower(s, [a, res], simple_mode=True))
def verify():
# build
mod = tvm.build(s, [a, res], "ext_dev", target)
mod = vta.build(s, [a, res], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
......@@ -462,12 +466,12 @@ def test_shift_and_scale():
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-10, 10, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype)
-10, 10, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
res_np = np.right_shift((a_np + imm_shift), imm_scale)
res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx)
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
......@@ -476,10 +480,10 @@ def test_shift_and_scale():
verify()
if __name__ == "__main__":
print("Padded load test")
test_padded_load()
print("Load/store test")
test_save_load_out()
print("Padded load test")
test_padded_load()
# print("GEMM test")
# test_gemm()
print("Max immediate")
......
import vta
def test_env():
env = vta.get_env()
mock = env.mock
assert mock.alu == "skip_alu"
if __name__ == "__main__":
test_env()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment