Commit 3818b2a2 by Thierry Moreau, committed by Tianqi Chen

[VTA][Relay] Relay Compilation + AutoTVM compatible operator libraries for VTA (#3135)

parent 813a3d52
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
PROJROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../../" && pwd )"
export PYTHONPATH=${PYTHONPATH}:${PROJROOT}/python:${PROJROOT}/vta/python
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/pynq
python3 -m vta.exec.rpc_server --tracker fleet:9190 --key pynq
......@@ -14,24 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Reuse conv2d schedule from ARM CPU"""
import tvm
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic

@conv2d.register(["vtacpu", "vta"])
def compute(*args, **kwargs):
    with tvm.target.arm_cpu("vtacpu"):
        return conv2d(*args, **kwargs)

@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
def schedule(*args, **kwargs):
    with tvm.target.arm_cpu("vtacpu"):
        return generic.schedule_conv2d_nchw(*args, **kwargs)

@conv2d_alter_layout.register(["vtacpu", "vta"])
def alter(*args, **kwargs):
    with tvm.target.arm_cpu("vtacpu"):
        return conv2d_alter_layout(*args, **kwargs)
......@@ -215,7 +215,10 @@ subsection_order = ExplicitOrder(
'../tutorials/autotvm',
'../tutorials/dev',
'../tutorials/topi',
'../tutorials/deployment'])
'../tutorials/deployment',
'../vta/tutorials/frontend',
'../vta/tutorials/optimize',
'../vta/tutorials/autotvm'])
def generate_doxygen_xml(app):
"""Run the doxygen make commands if we're on the ReadTheDocs server"""
......
......@@ -78,7 +78,7 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1], inputs[2])
return topi.nn.dense(inputs[0], inputs[1])
@reg.register_schedule("dense")
......@@ -114,25 +114,25 @@ def compute_conv2d(attrs, inputs, _):
if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
dilation, layout, out_dtype)
# pylint: enable=assignment-from-no-return
elif groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)
elif layout == "NCHW" and \
groups == get_const_int(inputs[0].shape[1]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ["NCHW", "NCHW4c"]:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and \
groups == get_const_int(inputs[0].shape[3]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
else:
raise ValueError("not support arbitrary group number for now")
......
......@@ -65,18 +65,19 @@ def expr2graph(expr, target_ops, node_dict, node_list):
% op_name)
topi_funcs += OP2COMPUTE[op_name]
env.reset(topi_funcs)
_expr2graph_impl(expr, target_ops, node_dict, node_list)
task_pos = 0
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args,
target="llvm",
target_host=None,
template_key='direct')
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
with env:
_expr2graph_impl(expr, target_ops, node_dict, node_list)
task_pos = 0
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args,
target="llvm",
target_host=None,
template_key='direct')
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
def _expr2graph_impl(expr, target_ops, node_dict, node_list):
......
......@@ -86,7 +86,6 @@ class LocalBuilder(Builder):
build_func = ndk.create_shared
else:
raise ValueError("Invalid build_func" + build_func)
self.build_func = _wrap_build_func(build_func)
self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp()
......@@ -360,8 +359,14 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
if cuda_arch:
set_cuda_target_arch(cuda_arch)
with build_config(**opts):
func = build(s, args, target_host=task.target_host)
# if target is vta, we need to use vta build
if hasattr(measure_input.target, 'device_name') and \
measure_input.target.device_name == 'vta':
import vta
func = vta.build(s, args, target_host=task.target_host)
else:
with build_config(**opts):
func = build(s, args, target_host=task.target_host)
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
......@@ -452,6 +457,12 @@ def run_through_rpc(measure_input, build_result,
try:
# upload built module
remote = request_remote(*remote_args)
# Program the FPGA every single time when targeting VTA
if hasattr(measure_input.target, 'device_name') and \
measure_input.target.device_name == 'vta':
from vta import program_fpga, reconfig_runtime
program_fpga(remote, None)
reconfig_runtime(remote)
remote.upload(build_result.filename)
func = remote.load_module(os.path.split(build_result.filename)[1])
ctx = remote.context(str(measure_input.target), 0)
......
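For reference, a minimal standalone sketch of the same FPGA (re)programming step, assuming a tracker with a "pynq" device key is reachable (the host name and port below mirror the RPC server script at the top of this commit and are purely illustrative):

from tvm import autotvm
from vta import program_fpga, reconfig_runtime

# request a board from the RPC tracker, then reprogram it
remote = autotvm.measure.request_remote("pynq", "fleet", 9190, timeout=10000)
program_fpga(remote, None)   # None is expected to select the default bitstream for the current VTA config
reconfig_runtime(remote)     # rebuild/reload the VTA runtime on the board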
......@@ -19,23 +19,22 @@
Decorator and utilities for the integration with TOPI and NNVM
"""
import threading
import warnings
import logging
from ... import target as _target
from .task import create
from .topi_integration import TaskExtractEnv
logger = logging.getLogger('autotvm')
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
def extract_from_graph(graph, shape, dtype, target, symbols, params=None, target_host=None):
""" Extract tuning tasks from a nnvm graph.
This function collects tuning tasks by building the graph
with a "tracing" target and tracing all the calls to topi.
and trace all the calls to topi.
Parameters
----------
......@@ -49,6 +48,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols to be tuned
params : dict of str to NDArray
The parameter dictionary.
target_host: tvm.target.Target
The host compilation target
......@@ -63,8 +64,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
env = TaskExtractEnv.get()
#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
# NOTE: To add more symbols, you only need to change the following lists
# nnvm symbol -> topi compute
SYMBOL2TOPI = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
......@@ -81,29 +82,40 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
logger.disabled = old_state
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
nnvm.compiler.engine.clear_cache()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=nnvm.compiler.build,
args=(graph,
target,
shape,
dtype,
params,
target_host))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation")
return tasks
def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, params, target_host=None):
""" Extract tuning tasks from multiple nnvm graphs.
This function is the multiple graph version of extract_from_graph
......@@ -120,6 +132,8 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols to be tuned
params : dict of str to NDArray
The parameter dictionary.
target_host: tvm.target.Target
The host compilation target
......@@ -152,25 +166,35 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
logger.disabled = old_state
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.engine.clear_cache()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=nnvm.compiler.build,
args=(graph,
target,
shape,
dtype,
params,
target_host))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation")
return tasks
......@@ -25,14 +25,30 @@ import warnings
import logging
from ... import target as _target
from .task import create
from .topi_integration import TaskExtractEnv
logger = logging.getLogger('autotvm')
# TODO(moreau89) find a more elegant way to build for VTAs
def _build(func,
target,
target_host,
params):
""" Helper to build VTA properly.
"""
from tvm import relay
if hasattr(target, 'device_name') and target.device_name == "vta":
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
import vta
with vta.build_config():
return relay.build(func, target, target_host, params)
# default case
return relay.build(func, target, target_host, params)
def extract_from_program(func, params, ops, target, target_host=None):
""" Extract tuning tasks from a relay program.
......@@ -57,11 +73,12 @@ def extract_from_program(func, params, ops, target, target_host=None):
task: Array of autotvm.task.Task
collected tasks
"""
env = TaskExtractEnv.get()
import tvm.relay.op
from tvm import relay
import topi
env = TaskExtractEnv.get()
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
......@@ -81,30 +98,33 @@ def extract_from_program(func, params, ops, target, target_host=None):
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=relay.build, args=(func,
tracing_target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=_build,
args=(func,
target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
warnings.warn("Invalid shape during AutoTVM task creation")
return tasks
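A hedged usage sketch of the VTA-aware task extraction above (relay_func and relay_params are placeholders; env.target and env.target_host come from the Environment properties added later in this commit):

import vta
from tvm import autotvm, relay

env = vta.get_env()
tasks = autotvm.task.extract_from_program(relay_func,           # placeholder Relay function
                                          params=relay_params,  # placeholder params dict
                                          ops=(relay.op.nn.conv2d,),
                                          target=env.target,
                                          target_host=env.target_host)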
......@@ -155,30 +175,33 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
for func, param in zip(funcs, params):
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=relay.build, args=(func,
tracing_target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
for func, param in zip(funcs, params):
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=_build,
args=(func,
target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation")
return tasks
......@@ -27,7 +27,7 @@ tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
from ... import _api_internal, tensor, placeholder, create_schedule
from ... import _api_internal, tensor, placeholder
from .task import args_to_workload, dispatcher, register
from ..util import get_const_tuple
......@@ -73,6 +73,7 @@ def deserialize_args(args):
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
registered = None
def __init__(self, allow_duplicate=False):
import topi
......@@ -106,47 +107,65 @@ class TaskExtractEnv:
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
# function reflection for tracing
self.func_to_reflection = {
topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
}
self.allow_duplicate = allow_duplicate
self._register_tracing()
self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())
self.modified_funcs = []
def __enter__(self):
self.task_collection = []
self.modified_funcs = []
def _register_tracing(self):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for topi_compute in self.topi_to_task:
for topi_compute in self.wanted_topi_funcs:
def _local_scope(compute_func):
"""start a scope to hold the local function in for loop"""
@compute_func.register("tracing", )
def _tracing_topi_compute(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
def _tracing_wrapper(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when " \
"kwargs is used in TOPI function call. " \
"Please modify it to use only positional args."
if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
key = (self.topi_to_task[compute_func], serialize_args(args))
if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key)
return compute_func(*args, **kwargs)
self.func_to_reflection[compute_func](_tracing_wrapper)
self.modified_funcs.append(compute_func)
_local_scope(topi_compute)
# register topi schedule for "tracing" target
for topi_compute in self.topi_to_task:
for topi_schedule in self.topi_to_schedule[topi_compute]:
def _local_scope_(schedule_func):
"""start a scope to hold the local function in for loop"""
return self
@schedule_func.register("tracing", )
def _tracing_topi_compute(outs):
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return create_schedule([x.op for x in outs])
_local_scope_(topi_schedule)
def __exit__(self, exc_type, exc_val, exc_tb):
# revert modification
for func in self.modified_funcs:
self.func_to_reflection[func](func)
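In effect TaskExtractEnv now behaves as a context manager: entering it swaps every wanted TOPI compute for a recording wrapper through the func_to_reflection setattr table, and exiting restores the originals. A minimal, TVM-independent sketch of the same pattern:

class TracingScope:
    """Toy stand-in for TaskExtractEnv: patch a module attribute on enter,
    record every call, and revert the modification on exit."""
    def __init__(self, module, attr_name, records):
        self._module, self._attr_name, self._records = module, attr_name, records
    def __enter__(self):
        self._original = getattr(self._module, self._attr_name)
        def _wrapper(*args):
            self._records.append((self._attr_name, args))  # record this call
            return self._original(*args)                   # then run the real function
        setattr(self._module, self._attr_name, _wrapper)
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        setattr(self._module, self._attr_name, self._original)  # revert modification

# usage: calls made inside the scope are recorded, the original is restored afterwards
import math
calls = []
with TracingScope(math, "sqrt", calls):
    math.sqrt(4.0)
assert calls == [("sqrt", (4.0,))]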
def _register_topi_task(self):
"""register tuning wrapper for topi function"""
import topi
# Avoid double registration for certain targets
if TaskExtractEnv.registered:
return
TaskExtractEnv.registered = True
# Tuning wrapper for topi functions
@register("topi_nn_conv2d")
def _topi_nn_conv2d(*args, **kwargs):
......@@ -190,7 +209,11 @@ class TaskExtractEnv:
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias, _ = args
if len(args) > 2:
data, weight, bias = args[:3]
else:
data, weight = args
bias = None
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
......
......@@ -44,7 +44,7 @@ PACKAGE_VERSION = {
'opencl': "v0.02",
'mali': "v0.05",
'vta': "v0.04",
'vta': "v0.05",
}
logger = logging.getLogger('autotvm')
......
......@@ -56,7 +56,7 @@ def compute_dense(attrs, inputs, out_type, target):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
return [topi.nn.dense(inputs[0], inputs[1], out_dtype=out_dtype)]
return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
@reg.register_schedule("nn.dense")
......@@ -119,21 +119,21 @@ def compute_conv2d(attrs, inputs, out_type, target):
if groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
dilation, layout, out_dtype)
elif layout == "NCHW" and \
get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ['NCHW', 'NCHW4c']:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
out_dtype)
else:
raise ValueError("not support arbitrary group number for now")
return [out]
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
import warnings
......@@ -171,7 +171,7 @@ def conv2d_rewrite(ref_call, new_args, ctx):
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
assert rhs_kind is None
......@@ -191,7 +191,8 @@ def check_to_skip():
return False
@register_annotate_function("nn.dense")
# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
# @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field."""
......@@ -201,13 +202,14 @@ def dense_rewrite(ref_call, new_args, ctx):
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
......@@ -222,6 +224,7 @@ def multiply_rewrite(ref_call, new_args, ctx):
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None:
# quantize lhs to INPUT field
if lhs_kind == QAnnotateKind.ACTIVATION:
......@@ -230,6 +233,7 @@ def multiply_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError
......@@ -244,22 +248,35 @@ def add_rewrite(ref_call, new_args, ctx):
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
assert rhs_kind == QAnnotateKind.INPUT
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
assert lhs_kind == QAnnotateKind.ACTIVATION
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError
if lhs_kind is not None and rhs_kind is not None:
if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@register_annotate_function("stop_fusion")
......@@ -294,6 +311,7 @@ register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
register_annotate_function("annotation.stop_fusion", identity_rewrite)
def pool2d_rewrite(ref_call, new_args, ctx):
......@@ -307,6 +325,7 @@ def pool2d_rewrite(ref_call, new_args, ctx):
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
......@@ -314,6 +333,23 @@ def pool2d_rewrite(ref_call, new_args, ctx):
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
@register_annotate_function("annotation.force_cast")
def force_cast_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
if check_to_skip():
return None
expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return new_args[0]
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
......@@ -333,3 +369,71 @@ def concatenate_rewrite(ref_call, new_args, ctx):
expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION)
expr = _forward_op(ref_call, [_expr.Tuple(expr_list)])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
# Graph rewrite function registration for VTA target
def register_vta_rewrite(op_name, frewrite=None, level=10):
def _register(func):
return _op.op._Register(op_name, "FQVTARewrite", func, level)
return _register(frewrite) if frewrite is not None else _register
@register_relay_node
class QVTAExpr(_expr.TempExpr):
def __init__(self, expr):
self.__init_handle_by_constructor__(
_quantize.make_vta_expr, expr)
def realize(self):
return _quantize.temp_expr_realize(self)
def vta_expr_check(expr):
if isinstance(expr, QVTAExpr):
return True, expr.expr
return False, expr
@register_vta_rewrite("nn.conv2d")
def conv2d_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d for VTA target"""
actx = annotate_context()
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if actx.conv2d_counter() in skipped_indices:
actx.count_conv2d()
return None
actx.count_conv2d()
data_cond, data = vta_expr_check(new_args[0])
kernel_cond, kernel = vta_expr_check(new_args[1])
assert not kernel_cond
if data_cond:
data = new_args[0].realize()
ret = _forward_op(ref_call, [data, kernel])
return QVTAExpr(ret)
def identity_vta_rewrite(ref_call, new_args, ctx):
cond, expr = vta_expr_check(new_args[0])
if cond:
return QVTAExpr(_forward_op(ref_call, [expr]))
return None
register_vta_rewrite("nn.relu", identity_vta_rewrite)
register_vta_rewrite("nn.max_pool2d", identity_vta_rewrite)
@register_vta_rewrite("add")
def add_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for ewise add for VTA target"""
lhs_cond, lhs = vta_expr_check(new_args[0])
rhs_cond, rhs = vta_expr_check(new_args[1])
if lhs_cond and rhs_cond:
lhs = new_args[0].realize()
rhs = new_args[1].realize()
return _forward_op(ref_call, [lhs, rhs])
elif lhs_cond and not rhs_cond:
return QVTAExpr(_forward_op(ref_call, [lhs, rhs]))
return None
......@@ -124,7 +124,9 @@ def current_qconfig():
"""Get the current quantization configuration."""
return _quantize._GetCurrentQConfig()
# TODO(tmoreau89, ZihengJiang) the skip parameters are
# hacky - we should explore a more future-proof way to
# skip operators based on pattern matching
def qconfig(**kwargs):
"""Configure the quantization behavior by setting config variables.
......@@ -279,6 +281,17 @@ def realize():
return _quantize.QuantizeRealize()
def rewrite_for_vta():
"""Performs rewriting for VTA target.
Returns
-------
ret: tvm.relay.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizeRewriteForVTA()
def _bind_params(func, params):
"""Bind the params to the expression.
"""
......@@ -337,15 +350,19 @@ def quantize(graph, params=None, dataset=None):
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
quantize_seq = _transform.Sequential([annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()])
with annotate_context():
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
# Quantize pass list
quant_passes = [annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()]
if current_qconfig().store_lowbit_output:
quant_passes = [rewrite_for_vta()] + quant_passes
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint]
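A hedged usage sketch of the updated quantization pipeline (relay_func and relay_params are placeholders): with store_lowbit_output=True in the qconfig, the rewrite_for_vta() pass above is prepended to the annotate/calibrate/realize sequence.

from tvm import relay

with relay.quantize.qconfig(global_scale=8.0, store_lowbit_output=True):
    qfunc = relay.quantize.quantize(relay_func, params=relay_params)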
......@@ -344,7 +344,7 @@ def generic_func(fdefault):
The function to be registered.
override : bool
Whether override existing registeration.
Whether override existing registration.
Returns
-------
......@@ -489,6 +489,13 @@ def rasp(options=None):
return arm_cpu('rasp3b', options)
def vta(model='unknown', options=None):
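"""Return a VTA target: an "ext_dev" target tagged with -device=vta, -keys=cpu and the given model string."""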
opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
opts = _merge_opts(opts, options)
ret = _api_internal._TargetCreate("ext_dev", *opts)
return ret
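A hedged usage note on the helper above (the exact option order in the printed string is illustrative):

import tvm

tgt = tvm.target.vta(model="pynq")
# expected to resolve to an "ext_dev" target carrying "-device=vta -keys=cpu -model=pynq";
# the CreateTarget changes below then set device_name to "vta" and add "vta"/"cpu" to the keys
print(tgt)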
def create(target_str):
"""Get a target given target string.
......
......@@ -58,6 +58,7 @@ Target CreateTarget(const std::string& target_name,
std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImm::make(item));
......@@ -69,6 +70,13 @@ Target CreateTarget(const std::string& target_name,
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImm::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImm::make(key_item));
}
}
}
......@@ -572,6 +580,9 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
Map<Target, Array<LoweredFunc>> updated_input;
for (const auto& it : inputs) {
auto target = Target::Create(it.first);
if (target->device_name == "vta") {
target = Target::Create("ext_dev");
}
updated_input.Set(target, it.second);
}
return build(updated_input, target_host, config);
......
......@@ -184,7 +184,7 @@ LoweredFunc MakeAPI(Stmt body,
for (Var v : undefined) {
os << " \'" << v->name_hint << "\' ";
}
os << " does not appeared in api_args";
os << " does not appear in api_args";
LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
}
return f;
......
......@@ -30,6 +30,7 @@
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
#include <utility>
#include <limits>
#include <mutex>
......@@ -170,7 +171,7 @@ class ScheduleGetter :
LOG(FATAL) << "not handled";
return tvm::Expr();
}
});
}, "compile_engine_const", topi::kBroadcast);
scalars_.push_back(value->op);
return {value};
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -83,5 +83,61 @@ TVM_ADD_FILELINE)
return {topi::identity(inputs[0])};
});
Expr ForceCast(Expr data) {
static const Op& op = Op::Get("annotation.force_cast");
return CallNode::make(op, {data}, Attrs{}, {});
}
RELAY_REGISTER_OP("annotation.force_cast")
.describe(R"code(Annotate an expression to force a cast.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
.add_type_rel("Identity", IdentityRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
RELAY_REGISTER_OP("annotation.bitpack_start")
.describe(R"code(
Mark the start of bitpacking.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
RELAY_REGISTER_OP("annotation.bitpack_end")
.describe(R"code(
Mark the end of bitpacking.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
} // namespace relay
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -379,6 +379,8 @@ Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array
Expr StopFusion(Expr data);
Expr ForceCast(Expr data);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
......@@ -90,7 +90,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
.set_attrs_type_key("relay.attrs.SimulatedQuantizeAttrs")
.set_support_level(10)
.set_support_level(11)
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_API("relay._quantize.simulated_quantize")
......@@ -133,6 +133,23 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast<QAnnotateKind>(args[1].operator int()));
});
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});
// =============
// realize pass
......@@ -371,7 +388,6 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();
std::vector<const QRealizeIntExprNode*> nptrs;
......@@ -385,19 +401,15 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
// unify the data type
CHECK_EQ(ref_args.size(), args.size());
DataType dtype = cfg->dtype_activation;
for (size_t i = 0; i < ret.size(); ++i) {
auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
auto new_arg = Cast(ret[i], cfg->dtype_input);
if (cfg->store_lowbit_output) {
new_arg = StopFusion(new_arg);
}
ret.Set(i, Cast(new_arg, dtype));
}
DataType dtype;
if (nptrs[0]->dtype == cfg->dtype_activation) {
dtype = cfg->dtype_activation;
ret.Set(1, Cast(ret[1], dtype));
} else if (nptrs[1]->dtype == cfg->dtype_input) {
dtype = cfg->dtype_input;
ret.Set(0, Cast(ret[0], dtype));
} else {
LOG(FATAL) << "should not touch here.";
}
// unify the dom_scale
......@@ -504,10 +516,13 @@ RELAY_REGISTER_OP("nn.relu")
RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
Expr MaxPoolRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
/* \brief For unary operators that requantize their input to dtype_input */
Expr CastDtypeInputRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
......@@ -520,7 +535,7 @@ Expr MaxPoolRealize(const Call& ref_call,
}
RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MaxPoolRealize);
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
Expr AvgPoolRealize(const Call& ref_call,
......@@ -543,6 +558,29 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
Expr ForceCastRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr ret = Cast(n->data, cfg->dtype_input);
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("annotation.force_cast")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ForceCastRealize);
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});
// =============
// qconfig
......@@ -646,6 +684,51 @@ Pass QuantizeRealizePass() {
TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);
Pass QuantizeRewriteForVTAPass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQVTARewrite", nullptr, nullptr));
};
return CreateFunctionPass(pass_func, 1, "QuantizeRewriteForVTA", {});
}
TVM_REGISTER_API("relay._quantize.QuantizeRewriteForVTA")
.set_body_typed(QuantizeRewriteForVTAPass);
// =============
// Insert stop_fusion for vta.
Expr QVTAExprNode::Realize() const {
Expr ret = ForceCast(this->expr);
return StopFusion(ret);
}
QVTAExpr QVTAExprNode::make(Expr expr) {
auto rnode = make_node<QVTAExprNode>();
rnode->expr = expr;
return QVTAExpr(rnode);
}
TVM_REGISTER_API("relay._quantize.make_vta_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QVTAExprNode::make(args[0]);
});
TVM_REGISTER_API("relay._quantize.make_stop_fusion")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
return StopFusion(expr);
});
TVM_REGISTER_API("relay._quantize.temp_expr_realize")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
const QVTAExprNode* n = expr.as<QVTAExprNode>();
CHECK(n);
return n->Realize();
});
} // namespace quantize
} // namespace relay
} // namespace tvm
......@@ -72,6 +72,33 @@ class QAnnotateExprNode : public TempExprNode {
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
/*!
* \brief TempExpr used to insert `force_cast` for VTA.
*/
class QVTAExpr;
/*!
* \brief TempExprNode used to insert `force_cast` for VTA.
*/
class QVTAExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
}
TVM_DLL static QVTAExpr make(Expr expr);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QVTAExpr";
TVM_DECLARE_NODE_TYPE_INFO(QVTAExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QVTAExpr, QVTAExprNode, TempExpr);
/*! \brief TempExpr used during realize forward rewrite. */
class QRealizeExpr;
/*! \brief TempExpr representing integer. */
......
......@@ -47,7 +47,7 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
LOG(FATAL) << "Operate on iter var " << v
<< "that has already been splitted";
<< "that has already been split";
} else {
LOG(FATAL) << "Operate on iter var " << v
<< "that is not part of the schedule";
......
......@@ -35,6 +35,8 @@ from . import vision
from . import image
from . import sparse
from . import hls
# error reporting
from .util import InvalidShapeError
# not import testing by default
# because testing can have extra deps that are not necessary
# we can import them from test cases explicitly
......
......@@ -23,6 +23,10 @@ import tvm
from tvm.api import layout, bijective_layout
from . import tag
class InvalidShapeError(ValueError):
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
pass
def traverse_inline(s, final_op, callback):
"""Traverse computation graph and do auto inline
......
......@@ -42,7 +42,7 @@ extern "C" {
/*! \brief Physically contiguous buffer size limit */
#ifndef VTA_MAX_XFER
#define VTA_MAX_XFER (1<<22)
#define VTA_MAX_XFER (1<<25)
#endif
/*! PAGE SIZE */
......
......@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""VTA specific buildin for runtime."""
from __future__ import absolute_import as _abs
......
......@@ -235,6 +235,10 @@ class Environment(object):
return self.dev.gemm
@property
def target(self):
return tvm.target.vta(model=self.TARGET)
@property
def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
......@@ -243,6 +247,9 @@ class Environment(object):
return "llvm"
raise ValueError("Unknown target %s" % self.TARGET)
@property
def target_vta_cpu(self):
return tvm.target.arm_cpu(model=self.TARGET)
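A hedged sketch of how these new Environment target properties are meant to be consumed when compiling a Relay function for VTA, mirroring the _build helper added in relay_integration.py (relay_func and relay_params are placeholders):

import vta
from tvm import relay

env = vta.get_env()
with vta.build_config():
    graph, lib, params = relay.build(relay_func,            # placeholder Relay function
                                     target=env.target,
                                     params=relay_params,   # placeholder params dict
                                     target_host=env.target_host)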
def get_env():
"""Get the current VTA Environment.
......
......@@ -66,6 +66,9 @@ def server_start():
@tvm.register_func("tvm.contrib.vta.init", override=True)
def program_fpga(file_name):
from pynq import xlnk
# Reset xilinx driver
xlnk.Xlnk().xlnk_reset()
path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name)
env = get_env()
program_bitstream.bitstream_program(env.TARGET, path)
......
......@@ -77,8 +77,6 @@ class PkgConfig(object):
if self.target == "pynq":
self.ldflags = [
"-L/usr/lib",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
"-l:libcma.so"]
else:
self.ldflags = []
......
......@@ -84,4 +84,17 @@ def tsim_cycles():
"""
return tvm.get_global_func("tvm.vta.tsim.cycles")()
# debug flag to skip execution.
DEBUG_SKIP_EXEC = 1
def debug_mode(flag):
"""Set debug mode
Parameters
----------
flag : int
The debug flag, 0 means clear all flags.
"""
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)
LIBS = _load_lib()
......@@ -18,7 +18,7 @@
from __future__ import absolute_import as _abs
import os
from tvm import rpc
from tvm import rpc, autotvm
from ..environment import get_env
from . import simulator
......@@ -42,7 +42,7 @@ def run(run_func):
# the port it's listening to, e.g. 9090
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
remote = rpc.connect("localhost", local_rpc)
remote = rpc.connect("127.0.0.1", local_rpc)
run_func(env, remote)
else:
# Make sure simulation library exists
......@@ -54,12 +54,22 @@ def run(run_func):
elif env.TARGET == "pynq":
# Run on PYNQ if env variable exists
host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None))
if host and port:
remote = rpc.connect(host, port)
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "0"))
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
pynq_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "0"))
# Run the device from the tracker node if the tracker env variables are defined
if tracker_host and tracker_port:
remote = autotvm.measure.request_remote(env.TARGET,
tracker_host,
tracker_port,
timeout=10000)
run_func(env, remote)
else:
raise RuntimeError(
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")
# Next, run on PYNQ if env variables are defined
if pynq_host and pynq_port:
remote = rpc.connect(pynq_host, pynq_port)
run_func(env, remote)
else:
raise RuntimeError(
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")
"""TVM TOPI connector, eventually most of these should go to TVM repo"""
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import bitpack
from .graphpack import graph_pack
from . import op
from . import vta_conv2d
from . import arm_conv2d
from . import vta_dense
# NNVM is deprecated for VTA
# from . import nnvm_bitpack
# from .nnvm_graphpack import nnvm_graph_pack
# from . import nnvm_op
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=ungrouped-imports
"""Bit packing operators"""
from __future__ import absolute_import as _abs
import tvm
from topi import util
from tvm.relay.op.op import register_compute, register_schedule
from tvm.relay.op.op import register_pattern, OpPattern
from tvm.relay.op.op import schedule_injective
def bitpack(data, bits, pack_type="int8", name="bitpack"):
"""Packs lowest dimension into format needed by VTA
Parameters
----------
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data
Returns
-------
packed : Tensor
The packed tensor.
"""
shape_vec = list(data.shape)
if pack_type == 'int8':
data_width = 8
elif pack_type == 'int16':
data_width = 16
elif pack_type == 'int32':
data_width = 32
else:
raise RuntimeError("Unknown pack type %s" % pack_type)
assert data_width % bits == 0
lanes = data_width // bits
# Data must be in multiples of the data_width
assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size"
shape_vec[-1] = shape_vec[-1] // lanes
oshape = tuple(shape_vec)
def _bitpack(*indices):
ret = None
mask = tvm.const((1 << bits) - 1, pack_type)
for k in range(lanes):
idx = list(indices)
idx[-1] = idx[-1] * lanes + k
elem = data(*idx).astype(pack_type)
if k == 0:
ret = elem & mask
else:
val = (elem & mask) << tvm.const(k * bits, pack_type)
ret = ret | val
return ret
return tvm.compute(
oshape, _bitpack, name=name, tag='bitpack')
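For intuition, a small hedged example of the lane arithmetic above: with pack_type="int8" and bits=2, lanes = 8 // 2 = 4, so the last dimension shrinks by a factor of four and each output byte holds four 2-bit values.

import tvm

w = tvm.placeholder((1, 16), dtype="int8", name="w")
packed = bitpack(w, bits=2)  # packed shape: (1, 4), using the bitpack defined above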
@register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs):
lanes = attrs.lanes
dtype = inputs[0].dtype
assert dtype == "int8"
width = 8
assert width % lanes == 0
bits = 8 // lanes
return bitpack(inputs[0], bits, dtype)
register_schedule("bitpack", schedule_injective)
register_pattern("bitpack", OpPattern.INJECTIVE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Bit packing operators"""
from __future__ import absolute_import as _abs
import tvm
from topi import util
from nnvm.top import registry as reg, OpPattern
from nnvm.top.tensor import _fschedule_broadcast
def bitpack(data, bits, pack_type="int8", name="bitpack"):
"""Packs lowest dimension into format needed by VTA
Parameters
----------
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data
Returns
-------
packed : Tensor
The packed tensor.
"""
shape_vec = list(data.shape)
if pack_type == 'int8':
data_width = 8
elif pack_type == 'int16':
data_width = 16
elif pack_type == 'int32':
data_width = 32
else:
raise RuntimeError("Unknown pack type %s" % pack_type)
assert data_width % bits == 0
lanes = data_width // bits
# Data must be in multiples of the data_width
assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size"
shape_vec[-1] = shape_vec[-1] // lanes
oshape = tuple(shape_vec)
def _bitpack(*indices):
ret = None
mask = tvm.const((1 << bits) - 1, pack_type)
for k in range(lanes):
idx = list(indices)
idx[-1] = idx[-1] * lanes + k
elem = data(*idx).astype(pack_type)
if k == 0:
ret = elem & mask
else:
val = (elem & mask) << tvm.const(k * bits, pack_type)
ret = ret | val
return ret
return tvm.compute(
oshape, _bitpack, name=name, tag='bitpack')
@reg.register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs, out):
lanes = attrs.get_int("lanes")
dtype = inputs[0].dtype
assert dtype == "int8"
width = 8
assert width % lanes == 0
bits = 8 // lanes
return bitpack(inputs[0], bits, dtype)
reg.register_schedule("bitpack", _fschedule_broadcast)
reg.register_pattern("bitpack", OpPattern.INJECTIVE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""An NNVM implementation of graph packing."""
import nnvm
from nnvm.compiler import graph_attr, graph_util
def _pack_batch_channel(data, dshape, bfactor, cfactor):
"""Pack the data channel dimension.
"""
assert dshape[0] % bfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // bfactor, bfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _unpack_batch_channel(data, old_shape):
"""Unpack the data channel dimension.
"""
data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
data = nnvm.sym.reshape(data, shape=old_shape)
return data
def _pack_weight(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _pack_weight_conv2d_transpose(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(2, 0, 4, 5, 3, 1))
return data
def _pack_bias(data, dshape, bfactor, cfactor):
"""Pack the bias parameter.
"""
assert len(dshape) == 3
assert dshape[0] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor,
cfactor, dshape[1],
dshape[2], 1))
data = nnvm.sym.transpose(
data, axes=(0, 2, 3, 4, 1))
# broadcast batch dimension to bfactor
data = nnvm.sym.broadcast_to(
data,
shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
return data
def _get_shape(sym, shape_dict):
"""Get the shape of a node.
"""
return graph_util.infer_shape(
nnvm.graph.create(sym), **shape_dict)[1][0]
def nnvm_graph_pack(graph,
shape_dict,
bfactor,
cfactor,
weight_bits,
start_name="max_pool2d0",
stop_name="global_avg_pool2d0"):
"""Pack the graph into batch&channel packed format.
Parameters
----------
graph : Graph
The input graph.
shape_dict : dict of str to shape
The input shape.
bfactor : int
The packing factor in batch
cfactor : int
The packing factor in channel
weight_bits : int
Bit width of the conv2d weights; weights are bit-packed when this is less than 8.
start_name: str, optional
Start packing from a certain known node.
stop_name: str, optional
Stop packing at a certain known node.
Returns
-------
graph : Graph
The transformed graph.
"""
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
shape = graph.json_attr("shape")
gidx = graph.index
node_map = {}
dset = set()
start_pack = False
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]]
oshape = shape[gidx.entry_id(nid, 0)]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
if op_name == "null":
new_node = nnvm.symbol.Variable(node_name)
if start_name and node_name == start_name:
start_pack = True
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
if start_pack and "_begin_state_" in node_name: # RNN -> CNN, pack
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif node_name == start_name:
assert not start_pack
start_pack = True
new_node = get_clone(children, op_name, node_name, attrs)
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif node_name == stop_name:
if start_pack:
start_pack = False
children[0] = _unpack_batch_channel(children[0], ishape[0])
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "conv2d" and attrs.get("out_dtype", None) == "int32":
assert 8 % weight_bits == 0
w_lanes = 8 // weight_bits
if start_pack:
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "OIHW%do%di%dp" % (cfactor, cfactor, w_lanes)
data, weight = children
weight = _pack_weight(weight, ishape[1], cfactor)
# insert bit packing when necessary
if w_lanes != 1:
assert 8 % w_lanes == 0
weight = nnvm.sym.bitpack(weight, lanes=w_lanes)
new_node = nnvm.sym.conv2d(
data, weight, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "conv2d_transpose" and attrs.get("out_dtype", None) == "int32":
assert 8 % weight_bits == 0
w_lanes = 8 // weight_bits
if start_pack:
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "IOHW%di%do%dp" % (cfactor, cfactor, w_lanes)
data, weight = children
weight = _pack_weight_conv2d_transpose(weight, ishape[1], cfactor)
new_node = nnvm.sym.conv2d_transpose(
data, weight, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast_") and tuple(ishape[0]) == tuple(ishape[1]):
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast") and len(ishape[1]) == 3:
if start_pack:
children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("elementwise_add"):
new_node = get_clone(children, op_name, node_name, attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
dset.add(op_name)
node_map[nid] = new_node
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]]
if start_pack:
oshape = shape[graph.index.output_entries[0][0]]
ret = _unpack_batch_channel(ret, oshape)
graph = nnvm.graph.create(ret)
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
return graph
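As a usage reference, a minimal sketch of how this packing pass could be invoked (the import location of nnvm_graph_pack, the packing factors, and the boundary node names are assumptions; the real values come from the VTA environment and the quantized model):

import nnvm
import nnvm.graph
import vta

env = vta.get_env()
graph = nnvm.graph.create(sym)           # sym: quantized NNVM symbol (assumed given)
shape_dict = {"data": (1, 3, 224, 224)}  # plus one entry per parameter tensor
packed = nnvm_graph_pack(graph, shape_dict,
                         bfactor=env.BATCH,
                         cfactor=env.BLOCK_OUT,
                         weight_bits=env.WGT_WIDTH,
                         start_name="max_pool2d0",
                         stop_name="global_avg_pool2d0")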
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
import logging
import tvm
import topi
from nnvm.top import registry as reg, OpPattern
from nnvm.top import nn as _nn
from .vta_conv2d import is_packed_layout
from ..environment import get_env
@tvm.register_func("nnvm.compiler.build_target", override=True)
def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", target_host=target_host)
if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target)
@tvm.register_func("nnvm.compiler.lower", override=True)
def _lower(sch, inputs, func_name, graph):
import traceback
# pylint: disable=broad-except
try:
f = tvm.lower(sch, inputs, name=func_name)
if "quantized_conv2d" in func_name:
logging.info(graph.ir(join_entry_attrs=["shape"]))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile graph\n"
msg += "--------------------------\n"
msg += graph.ir(join_entry_attrs=["shape"])
raise RuntimeError(msg)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]
# override to force partition at copy
reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
@reg.register_compute("clip", level=15)
def compute_clip(attrs, inputs, _):
""" Clip operator. """
x = inputs[0]
a_min = attrs.get_float("a_min")
a_max = attrs.get_float("a_max")
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
with tvm.tag_scope(topi.tag.ELEMWISE):
x = tvm.compute(
x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(
x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
@reg.register_compute("conv2d", level=15)
def compute_conv2d(attrs, inputs, out):
""" Compute definition of conv2d """
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs["layout"]
out_dtype = attrs['out_dtype']
    assert dilation == (1, 1), "support for dilation limited to (1, 1)"
if is_packed_layout(layout):
if groups == 1:
assert groups == 1
env = get_env()
assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
            assert env.LOG_OUT_WIDTH == 3, "only support 8bit out for now"
inputs = list(inputs)
assert inputs[1].dtype == "int8"
return topi.nn.conv2d(inputs[0], inputs[1], strides,
padding, dilation, layout, out_dtype)
return topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides,
padding, dilation, groups, out_dtype)
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_conv2d(attrs, inputs, out)
@reg.register_schedule("conv2d", level=15)
def schedule_conv2d(attrs, outs, target):
""" Schedule definition of conv2d """
layout = attrs["layout"]
groups = attrs.get_int('groups')
if is_packed_layout(layout):
target = tvm.target.create(target)
if target.device_name == "vta":
if groups == 1:
return topi.generic.schedule_conv2d_nchw(outs)
return topi.generic.schedule_group_conv2d_nchw(outs)
elif str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target)
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())
@reg.register_alter_op_layout("conv2d", level=15)
def alter_conv2d_layout(attrs, inputs, out):
layout = attrs['layout']
if is_packed_layout(layout):
return None
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.alter_conv2d_layout(attrs, inputs, out)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, ungrouped-imports
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
import tvm
import topi
from tvm.relay.op import op as reg
from tvm.relay.op.op import OpPattern
from tvm.relay.op.nn import _nn
from .vta_conv2d import is_packed_layout
from ..environment import get_env
# override to force partition at copy
reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
@reg.register_compute("clip", level=15)
def compute_clip(attrs, inputs, output_type, target):
""" Clip operator. """
x = inputs[0]
a_min = attrs.a_min
a_max = attrs.a_max
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
with tvm.tag_scope(topi.tag.ELEMWISE):
x = tvm.compute(
x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(
x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return [x]
@reg.register_compute("nn.conv2d", level=15)
def compute_conv2d(attrs, inputs, output_type, target):
""" Compute definition of conv2d """
padding = topi.util.get_const_tuple(attrs.padding)
strides = topi.util.get_const_tuple(attrs.strides)
dilation = tuple([int(d) for d in attrs.dilation])
groups = attrs.groups
layout = attrs.data_layout
out_dtype = attrs.out_dtype
if target.device_name == "vta":
assert dilation == (1, 1), "support for dilation limited to (1, 1)"
if is_packed_layout(layout):
if groups == 1:
assert groups == 1
env = get_env()
assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
inputs = list(inputs)
assert inputs[1].dtype == "int8"
return [topi.nn.conv2d(inputs[0],
inputs[1],
strides,
padding,
dilation,
layout,
out_dtype)]
return [topi.nn.group_conv2d_nchw(inputs[0],
inputs[1],
strides,
padding,
dilation,
groups,
out_dtype)]
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_conv2d(attrs, inputs, output_type, target)
# If VTA is not the target, default to _nn def
return _nn.compute_conv2d(attrs, inputs, output_type, target)
@reg.register_schedule("nn.conv2d", level=15)
def schedule_conv2d(attrs, outs, target):
""" Schedule definition of conv2d """
groups = attrs.groups
layout = attrs.data_layout
if target.device_name == "vta":
if is_packed_layout(layout):
target = tvm.target.create(target)
assert target.device_name == "vta"
if groups == 1:
return topi.generic.schedule_conv2d_nchw(outs)
return topi.generic.schedule_group_conv2d_nchw(outs)
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())
# If VTA is not the target, default to _nn def
return _nn.schedule_conv2d(attrs, outs, target)
@reg.register_compute("nn.dense", level=15)
def compute_dense(attrs, inputs, out_type, target):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
if target.device_name == "vta":
        if len(inputs[0].shape) == 4: # this implies the layout is packed
target = tvm.target.create(target)
return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_dense(attrs, inputs, out_type, target)
# If VTA is not the target, default to _nn def
return _nn.compute_dense(attrs, inputs, out_type, target)
@reg.register_schedule("nn.dense", level=15)
def schedule_dense(attrs, outs, target):
"""Schedule definition of dense"""
if target.device_name == "vta":
        if len(outs[0].shape) == 4: # this implies the layout is packed
target = tvm.target.create(target)
assert target.device_name == "vta"
return topi.generic.schedule_dense(outs)
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.schedule_dense(attrs, outs, tvm.target.current_target())
# If VTA is not the target, default to _nn def
return _nn.schedule_dense(attrs, outs, target)
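These level=15 registrations take effect when a Relay program is compiled for the VTA target; a minimal sketch under assumed inputs (relay_prog and params stand in for a quantized, graph-packed Relay function and its parameters):

import tvm
from tvm import relay
import vta

env = vta.get_env()
with relay.build_config(opt_level=3):
    with vta.build_config():
        graph, lib, params = relay.build(relay_prog,
                                         target=tvm.target.vta(),
                                         params=params,
                                         target_host=env.target_host)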
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Dense operator declaration and schedule registration for VTA."""
import numpy as np
import tvm
from tvm import autotvm
import topi
from ..environment import get_env
def is_packed_layout(layout):
"""Check if layout is packed layout"""
if layout == "NCHW":
return False
if "n" in layout and "c" in layout:
return True
return False
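A couple of illustrative cases (the packed layout string is an example of the NCHW%dn%dc form used elsewhere in this patch):

assert not is_packed_layout("NCHW")       # plain NCHW falls back to the ARM CPU path
assert is_packed_layout("NCHW1n16c")      # batch/channel packed layout is handled by VTA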
@autotvm.register_topi_compute(topi.nn.dense, 'vta', 'direct')
def _declaration_dense(cfg,
data,
weight,
bias=None,
out_dtype=None):
"""Dense function declaration."""
# Make sure that the dense operator is packed
if len(data.shape) != 4 or len(weight.shape) != 4:
raise topi.InvalidShapeError()
# Derive shapes
ishape = topi.util.get_const_tuple(data.shape)
wshape = topi.util.get_const_tuple(weight.shape)
oshape = (data.shape[0], weight.shape[0], data.shape[2], weight.shape[2])
# Reduction axes (input channel)
assert ishape[1] == wshape[1]
assert ishape[3] == wshape[3]
k_o = tvm.reduce_axis((0, ishape[1]), name='k_o')
k_i = tvm.reduce_axis((0, ishape[3]), name='k_i')
res = tvm.compute(
oshape,
lambda b_o, c_o, b_i, c_i: tvm.sum(
data[b_o, k_o, b_i, k_i].astype(out_dtype) *
weight[c_o, k_o, c_i, k_i].astype(out_dtype),
axis=[k_o, k_i]),
name="res", tag="dense_pack")
cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
ishape[1] * ishape[3])
return res
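For reference, a minimal sketch of the packed 4-D operand shapes this declaration expects (the packing factors below are assumed stand-ins for env.BATCH, env.BLOCK_IN, and env.BLOCK_OUT):

import tvm

BATCH, BLOCK_IN, BLOCK_OUT = 1, 16, 16
batch, in_feat, out_feat = 16, 512, 1024
data = tvm.placeholder((batch // BATCH, in_feat // BLOCK_IN, BATCH, BLOCK_IN),
                       name="data", dtype="int8")
weight = tvm.placeholder((out_feat // BLOCK_OUT, in_feat // BLOCK_IN, BLOCK_OUT, BLOCK_IN),
                         name="weight", dtype="int8")
# The computed result then has shape
# (batch // BATCH, out_feat // BLOCK_OUT, BATCH, BLOCK_OUT).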
@autotvm.register_topi_schedule(topi.generic.schedule_dense, 'vta', 'direct')
def _schedule_dense(cfg, outs):
"""Packed dense schedule."""
assert len(outs) == 1
output = outs[0]
const_ops = []
ewise_inputs = []
ewise_ops = []
dense_res = []
assert "int" in output.op.input_tensors[0].dtype
def _traverse(op):
if topi.tag.is_broadcast(op.tag):
if not op.same_as(output.op):
if not op.axis:
const_ops.append(op)
else:
ewise_ops.append(op)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
ewise_inputs.append((op, tensor))
else:
_traverse(tensor.op)
else:
assert op.tag == "dense_pack"
dense_res.append(op)
_traverse(output.op)
assert len(dense_res) == 1
dense_stage = dense_res[0].output(0)
s = tvm.create_schedule(output.op)
##### space definition begin #####
b, c_o, _, _ = s[dense_stage].op.axis
c_i, _ = s[dense_stage].op.reduce_axis
cfg.define_split('tile_b', b, num_outputs=2)
cfg.define_split('tile_ci', c_i, num_outputs=2)
cfg.define_split('tile_co', c_o, num_outputs=2)
cfg.define_knob('oc_nthread', [1, 2])
###### space definition end ######
data, weight = dense_stage.op.input_tensors
env = get_env()
cdata = s.cache_read(data, env.inp_scope, [dense_stage])
cweight = s.cache_read(weight, env.wgt_scope, [dense_stage])
s[dense_stage].set_scope(env.acc_scope)
# cache read input
cache_read_ewise = []
for consumer, tensor in ewise_inputs:
cache_read_ewise.append(
s.cache_read(tensor, env.acc_scope, [consumer]))
# set ewise scope
for op in ewise_ops:
s[op].set_scope(env.acc_scope)
s[op].pragma(s[op].op.axis[0], env.alu)
for op in const_ops:
s[op].compute_inline()
# apply tiling for SRAM reuse
x_b, x_c, _, _ = s[output].op.axis
x_bo, x_bi = cfg['tile_b'].apply(s, output, x_b)
x_co, x_ci = cfg['tile_co'].apply(s, output, x_c)
s[output].reorder(x_bo, x_co, x_bi, x_ci)
store_pt = x_co
# set all compute scopes
s[dense_stage].compute_at(s[output], store_pt)
for op in ewise_ops:
s[op].compute_at(s[output], store_pt)
for tensor in cache_read_ewise:
s[tensor].compute_at(s[output], store_pt)
s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)
# virtual threading along output channel axes
if cfg['oc_nthread'].val > 1:
_, v_t = s[output].split(x_co, factor=cfg['oc_nthread'].val)
s[output].reorder(v_t, x_bo)
s[output].bind(v_t, tvm.thread_axis("cthread"))
x_bo, x_co, x_bi, _ = s[dense_stage].op.axis
k_o, _ = s[dense_stage].op.reduce_axis
s[dense_stage].reorder(x_bo, k_o, x_co)
k_o, _ = cfg['tile_ci'].apply(s, dense_stage, k_o)
s[cdata].compute_at(s[dense_stage], k_o)
s[cweight].compute_at(s[dense_stage], k_o)
# Use VTA instructions
s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy)
s[cweight].pragma(s[cweight].op.axis[0], env.dma_copy)
s[dense_stage].tensorize(x_bi, env.gemm)
s[output].pragma(x_ci, env.dma_copy)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuning a single conv2d operator"""
from collections import namedtuple
import logging
import os
import tvm
from tvm import autotvm
from tvm.contrib.util import get_lower_ir
import topi
import vta
import vta.testing
env = vta.get_env()
Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet
('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype):
data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
with tvm.target.vta():
res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides, dilation=dilation,
layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), out_dtype='int32')
res = topi.add(res, bias)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_conv2d_nchw([res])
else:
s = tvm.create_schedule([res.op])
return s, [data, kernel, bias, res]
if __name__ == '__main__':
# Logging config (for printing tuning log to the screen)
logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Get tracker info from env
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
        exit()
    tracker_port = int(tracker_port)
for wl_name, wl in resnet_wkls:
# Workload parameters
N = wl.batch
CI = wl.in_filter
H = wl.height
W = wl.width
CO = wl.out_filter
KH = wl.hkernel
KW = wl.wkernel
strides = (wl.hstride, wl.wstride)
padding = (wl.hpad, wl.wpad)
dilation = (1, 1)
in_dtype = 'int8'
out_dtype = 'int32'
task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype),
target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space)
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
            runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port, number=4, repeat=3, timeout=10000,
check_correctness=True))
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space),
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('conv2d.log')])
print("\nBest tuner config:")
print(tuner.best_config)
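As a hedged follow-up that could sit at the end of the loop body above, the best configuration found by the tuner can be instantiated and built directly (task.instantiate and vta.build as exposed by this TVM version):

with tvm.target.vta():
    schedule, arg_bufs = task.instantiate(tuner.best_config)
    # Build a deployable module for the tuned workload.
    mod = vta.build(schedule, arg_bufs, target=tvm.target.vta(),
                    target_host=env.target_host, name="conv2d")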
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuning a single dense operator"""
from collections import namedtuple
import logging
import os
import tvm
from tvm import autotvm
from tvm.contrib.util import get_lower_ir
import topi
import vta
import vta.testing
env = vta.get_env()
Workload = namedtuple("DenseWorkload",
['batch', 'in_filter', 'out_filter'])
resnet_wkls = [
# Workloads of resnet18 on imagenet
('resnet-18.dense', Workload(16, 512, 1024)),
]
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
def dense(N, CI, CO):
data_shape = (N//env.BATCH, CI//env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN)
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
with tvm.target.vta():
res = topi.nn.dense(data, kernel, None, 'int32')
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_dense([res])
else:
s = tvm.create_schedule([res.op])
return s, [data, kernel, res]
if __name__ == '__main__':
# Logging config (for printing tuning log to the screen)
logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Get tracker info from env
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
        exit()
    tracker_port = int(tracker_port)
for wl_name, wl in resnet_wkls:
# Workload parameters
N = wl.batch
CI = wl.in_filter
CO = wl.out_filter
task = autotvm.task.create(dense, args=(N, CI, CO),
target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space)
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
            runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port, number=4, repeat=3, timeout=10000,
check_correctness=True))
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space),
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('dense.log')])
print("\nBest tuner config:")
print(tuner.best_config)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Perform ResNet autoTVM tuning on VTA using NNVM."""
import argparse
import os
import time
import numpy as np
import tvm
from tvm import rpc, autotvm
from tvm.autotvm.measure.measure_methods import request_remote
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib import graph_runtime, util
from tvm.contrib.download import download
import topi
import nnvm.compiler
import vta
import vta.testing
env = vta.get_env()
def register_vta_tuning_tasks():
from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
# init autotvm env to register VTA operator
TaskExtractEnv()
@autotvm.task.register("topi_nn_conv2d", override=True)
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
with tvm.target.vta():
res = topi.nn.conv2d(*args, **kwargs)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_conv2d_nchw([res])
else:
s = tvm.create_schedule([res.op])
return s, [A, W, res]
def generate_graph(sym, params, target, target_host):
# Populate the shape and data type dictionary
shape_dict = {"data": (1, 3, 224, 224)}
dtype_dict = {"data": 'float32'}
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Apply NNVM graph optimization passes
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
assert env.BLOCK_IN == env.BLOCK_OUT
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
# Compile NNVM graph
with nnvm.compiler.build_config(opt_level=3):
with vta.build_config():
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
return graph, lib, params
def extract_tasks(sym, params, target, target_host):
# Populate the shape and data type dictionary
shape_dict = {"data": (1, 3, 224, 224)}
dtype_dict = {"data": 'float32'}
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Apply NNVM graph optimization passes
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
assert env.BLOCK_IN == env.BLOCK_OUT
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
with vta.build_config():
tasks = autotvm.task.extract_from_graph(graph=sym, shape=shape_dict, dtype=dtype_dict, target=target,
params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host)
return tasks
def download_model():
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'
data_dir = '_data'
if not os.path.exists(data_dir):
os.makedirs(data_dir)
for file in [categ_fn, graph_fn, params_fn]:
        if not os.path.isfile(os.path.join(data_dir, file)):
download(os.path.join(url, file), os.path.join(data_dir, file))
sym = nnvm.graph.load_json(open(os.path.join(data_dir, graph_fn)).read())
params = nnvm.compiler.load_param_dict(open(os.path.join(data_dir, params_fn), 'rb').read())
return sym, params
def tune_tasks(tasks,
measure_option,
tuner='xgb',
n_trial=1000,
early_stopping=None,
log_filename='tuning.log',
use_transfer_learning=True,
try_winograd=True):
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(tsk, loss_type='rank')
elif tuner == 'ga':
tuner_obj = GATuner(tsk, pop_size=50)
elif tuner == 'random':
tuner_obj = RandomTuner(tsk)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
n_trial_ = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial_,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial_, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
if __name__ == '__main__':
# Get tracker info from env
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
        exit()
    tracker_port = int(tracker_port)
# Download model
sym, params = download_model()
# Register VTA tuning tasks
register_vta_tuning_tasks()
# Extract tasks
print("Extracting tasks...")
target = tvm.target.vta()
target_host = env.target_host
tasks = extract_tasks(sym, params, target, target_host)
# Perform Autotuning
print("Tuning...")
tuning_opt = {
'log_filename': 'resnet-18.log',
'tuner': 'random',
'n_trial': 1e9,
'early_stopping': None,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port,
number=4, repeat=3, timeout=60,
check_correctness=True))
}
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.tophub.context(target, extra_files=[tuning_opt['log_filename']]):
# ResNet parameters
input_shape = (1, 3, 224, 224)
        dtype = 'float32'
# Compile network
print("Compiling network with best tuning parameters...")
graph, lib, params = generate_graph(sym, params, target, target_host)
# Export library
tmp = util.tempdir()
filename = "net.tar"
lib.export_library(tmp.relpath(filename))
# Upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)
# Upload parameters to device
ctx = remote.context(str(target), 0)
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module = graph_runtime.create(graph, rlib, ctx)
module.set_input('data', data_tvm)
module.set_input(**rparams)
# Evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
......@@ -908,12 +908,10 @@ class CommandQueue {
insn_queue_.InitSpace();
device_ = VTADeviceAlloc();
CHECK(device_ != nullptr);
printf("Initialize VTACommandHandle...\n");
}
~CommandQueue() {
VTADeviceFree(device_);
printf("Close VTACommandhandle...\n");
}
uint32_t GetElemBytes(uint32_t memory_id) {
......
......@@ -35,6 +35,11 @@
namespace vta {
namespace sim {
/*! \brief debug flag for skipping computation */
enum DebugFlagMask {
kSkipExec = 1
};
/*!
* \brief Helper class to pack and unpack bits
* Applies truncation when pack to low level bits.
......@@ -253,8 +258,12 @@ class SRAM {
return &(data_[index]);
}
// Execute the load instruction on this SRAM
void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
void Load(const VTAMemInsn* op,
DRAM* dram,
uint64_t* load_counter,
bool skip_exec) {
load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
if (skip_exec) return;
DType* sram_ptr = data_ + op->sram_base;
uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
op->dram_base * kElemBytes));
......@@ -325,6 +334,8 @@ class Profiler {
uint64_t gemm_counter{0};
/*! \brief instr counter for ALU ops */
uint64_t alu_counter{0};
/*! \brief set debug mode */
int64_t debug_flag{0};
/*! \brief clear the profiler */
void Clear() {
inp_load_nbytes = 0;
......@@ -335,6 +346,10 @@ class Profiler {
gemm_counter = 0;
alu_counter = 0;
}
/*! \return Whether we should skip execution. */
bool SkipExec() const {
return (debug_flag & DebugFlagMask::kSkipExec) != 0;
}
std::string AsJSON() {
std::ostringstream os;
......@@ -398,13 +413,15 @@ class Device {
void RunLoad(const VTAMemInsn* op) {
if (op->x_size == 0) return;
if (op->memory_type == VTA_MEM_ID_INP) {
inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
inp_.Load(op, dram_, &(prof_->inp_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_WGT) {
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_ACC) {
acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
acc_.Load(op, dram_, &(prof_->acc_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_UOP) {
uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
// always load in uop, since uop is stateful
// subsequent non-debug mode exec can depend on it.
uop_.Load(op, dram_, &(prof_->uop_load_nbytes), false);
} else {
LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
}
......@@ -416,7 +433,9 @@ class Device {
op->memory_type == VTA_MEM_ID_UOP) {
prof_->out_store_nbytes += (
op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
if (!prof_->SkipExec()) {
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
}
} else {
LOG(FATAL) << "Store do not support memory_type="
<< op->memory_type;
......@@ -425,7 +444,8 @@ class Device {
void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in;
prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
......@@ -459,6 +479,7 @@ class Device {
}
}
} else {
if (prof_->SkipExec()) return;
// reset
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
......@@ -477,7 +498,6 @@ class Device {
}
void RunALU(const VTAAluInsn* op) {
prof_->alu_counter += op->iter_out * op->iter_in;
if (op->use_imm) {
RunALU_<true>(op);
} else {
......@@ -520,6 +540,8 @@ class Device {
template<bool use_imm, typename F>
void RunALULoop(const VTAAluInsn* op, F func) {
prof_->alu_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (int y = 0; y < op->iter_out; ++y) {
for (int x = 0; x < op->iter_in; ++x) {
for (int k = op->uop_bgn; k < op->uop_end; ++k) {
......@@ -566,6 +588,10 @@ TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Profiler::ThreadLocal()->AsJSON();
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_debug_mode")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->debug_flag = args[0];
});
} // namespace sim
} // namespace vta
......
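From Python, the new skip-exec debug flag can be toggled through the packed function registered above; a minimal sketch (the set_debug wrapper name is ours; kSkipExec corresponds to the value 1):

import tvm

set_debug = tvm.get_global_func("vta.simulator.profiler_debug_mode")
set_debug(1)   # skip compute in the simulator while keeping the profiler counters
set_debug(0)   # restore normal execution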
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Testing topi gemm operator for VTA"""
import os
import json
from collections import namedtuple
import numpy as np
import tvm
from tvm import autotvm
from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize
import topi
import topi.testing
import vta
from vta import program_fpga, reconfig_runtime
import vta.testing
from vta.testing import simulator
# FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
def run_gemm(env, remote, target,
batch_size, in_feat, out_feat,
check_correctness=True, print_ir=True,
samples=4):
# Perform packing only if we are targeting the accelerator
if "arm_cpu" in target.keys:
data_pack = False
elif "vta" in target.keys:
data_pack = True
# Derive shapes depending upon packing
a_shape = (batch_size, in_feat)
w_shape = (out_feat, in_feat)
if data_pack:
data_shape = (batch_size//env.BATCH, in_feat//env.BLOCK_IN,
env.BATCH, env.BLOCK_IN)
kernel_shape = (out_feat//env.BLOCK_OUT, in_feat//env.BLOCK_IN,
env.BLOCK_OUT, env.BLOCK_IN)
else:
data_shape = a_shape
kernel_shape = w_shape
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
# Define base computation schedule
with target:
res = topi.nn.dense(
data, kernel, out_dtype=env.acc_dtype)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
res = topi.cast(res, env.out_dtype)
# Derive base schedule
s = topi.generic.schedule_dense([res])
if print_ir:
print(vta.lower(s, [data, kernel, res], simple_mode=True))
# Derive number of ops
num_ops = 2 * batch_size * in_feat * out_feat
# @memoize("vta.tests.test_benchmark_topi.dense.verify")
def get_ref_data():
# derive min max for act, wgt types (max non inclusive)
a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
r_np = np.dot(a_np.astype(env.acc_dtype), w_np.T.astype(env.acc_dtype)).astype(env.acc_dtype)
return a_np, w_np, r_np
# Data in original format
data_np, kernel_np, res_ref = get_ref_data()
if data_pack:
data_np = data_np.reshape(
batch_size//env.BATCH, env.BATCH,
in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
kernel_np = kernel_np.reshape(
out_feat//env.BLOCK_OUT, env.BLOCK_OUT,
in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
# Build
if "vta" in target.keys:
mod = vta.build(s, [data, kernel, res],
target=target,
target_host=env.target_host,
name="dense")
else:
mod = tvm.build(s, [data, kernel, res],
target=target,
target_host=env.target_host,
name="dense")
temp = util.tempdir()
mod.save(temp.relpath("dense.o"))
remote.upload(temp.relpath("dense.o"))
f = remote.load_module("dense.o")
ctx = remote.context(str(target))
res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype)
data_arr = tvm.nd.array(data_np, ctx)
kernel_arr = tvm.nd.array(kernel_np, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("dense", ctx, number=samples)
# In vta sim mode, collect simulator runtime statistics
stats = {}
cost = None
if env.TARGET == "sim":
# Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
remote.get_function("vta.simulator.profiler_clear")()
cost = time_f(data_arr, kernel_arr, res_arr)
stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
else:
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, res_arr)
stats = simulator.stats()
else:
cost = time_f(data_arr, kernel_arr, res_arr)
# Check correctness
correct = False
if check_correctness:
res_orig = res_arr.asnumpy()
if data_pack:
res_orig = res_orig.reshape(batch_size, out_feat)
res_ref = res_ref >> 8
res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
res_ref = res_ref.astype(env.out_dtype)
correct = np.allclose(res_orig, res_ref)
gops = (num_ops / cost.mean) / float(10 ** 9)
status = "PASSED" if correct else "FAILED"
if "arm_cpu" in target.keys:
device = "CPU"
elif "vta" in target.keys:
device = "VTA"
print("%s DENSE TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops))
return correct, cost, stats
def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128):
def _run(env, remote):
if device == "vta":
target = env.target
if env.TARGET != "sim":
assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None)
reconfig_runtime(remote)
elif device == "arm_cpu":
target = env.target_vta_cpu
with autotvm.tophub.context(target): # load pre-tuned schedule parameters
run_gemm(env, remote, target, batch, in_feat, out_feat)
vta.testing.run(_run)
if __name__ == "__main__":
test_gemm("vta", 16, 512, 1008)
VTA Tutorials
=============
This page contains tutorials about VTA and how to use TVM/Relay to target VTA.
Auto tuning
-------------
.. _tutorial-frontend:
Compile Deep Learning Models
----------------------------
Optimize Tensor Operators
-------------------------