Commit 3818b2a2 by Thierry Moreau Committed by Tianqi Chen

[VTA][Relay] Relay Compilation + AutoTVM compatible operator libraries for VTA (#3135)

parent 813a3d52
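Editor's note, a hedged end-to-end sketch of the compilation flow this change enables. It is illustrative only: `func` and `params` stand in for a frontend-imported Relay function and its parameters, and the start/stop operator names are example choices; the APIs used (`vta.get_env`, `relay.quantize.qconfig`/`quantize`, `vta.top.graph_pack`, `vta.build_config`, `relay.build`) are the ones added or exercised in the diffs below.

# Hedged sketch of Relay-to-VTA compilation; `func` and `params` are placeholders.
import vta
import vta.top
from tvm import relay

env = vta.get_env()  # carries env.target (ext_dev/vta), env.target_host, packing factors

# 1. Quantize to int8, keeping low-bit intermediate outputs for VTA.
with relay.quantize.qconfig(store_lowbit_output=True):
    func = relay.quantize.quantize(func, params=params)

# 2. Pack batch/channel dimensions into VTA's NCHWnc / OIHWoi layouts.
func = vta.top.graph_pack(func,
                          env.BATCH,
                          env.BLOCK_OUT,
                          env.WGT_WIDTH,
                          start_name="nn.max_pool2d",
                          stop_name="nn.global_avg_pool2d")

# 3. Build with VTA's lowering passes, mirroring the _build helper added below.
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
    with vta.build_config():
        graph, lib, build_params = relay.build(func, target=env.target,
                                               params=params,
                                               target_host=env.target_host)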
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
PROJROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../../" && pwd )"
export PYTHONPATH=${PYTHONPATH}:${PROJROOT}/python:${PROJROOT}/vta/python
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/pynq
python3 -m vta.exec.rpc_server --tracker fleet:9190 --key pynq
......@@ -14,24 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Reuse conv2d schedule from ARM CPU"""
import tvm
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic

@conv2d.register(["vtacpu", "vta"])
def compute(*args, **kwargs):
    with tvm.target.arm_cpu("vtacpu"):
        return conv2d(*args, **kwargs)

@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
def schedule(*args, **kwargs):
    with tvm.target.arm_cpu("vtacpu"):
        return generic.schedule_conv2d_nchw(*args, **kwargs)

@conv2d_alter_layout.register(["vtacpu", "vta"])
def alter(*args, **kwargs):
    with tvm.target.arm_cpu("vtacpu"):
        return conv2d_alter_layout(*args, **kwargs)
......@@ -215,7 +215,10 @@ subsection_order = ExplicitOrder(
'../tutorials/autotvm',
'../tutorials/dev',
'../tutorials/topi',
'../tutorials/deployment'])
'../tutorials/deployment',
'../vta/tutorials/frontend',
'../vta/tutorials/optimize',
'../vta/tutorials/autotvm'])
def generate_doxygen_xml(app):
"""Run the doxygen make commands if we're on the ReadTheDocs server"""
......
......@@ -78,7 +78,7 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1], inputs[2])
return topi.nn.dense(inputs[0], inputs[1])
@reg.register_schedule("dense")
......@@ -114,25 +114,25 @@ def compute_conv2d(attrs, inputs, _):
if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
dilation, layout, out_dtype)
# pylint: enable=assignment-from-no-return
elif groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)
elif layout == "NCHW" and \
groups == get_const_int(inputs[0].shape[1]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ["NCHW", "NCHW4c"]:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and \
groups == get_const_int(inputs[0].shape[3]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
else:
raise ValueError("not support arbitrary group number for now")
......
......@@ -65,18 +65,19 @@ def expr2graph(expr, target_ops, node_dict, node_list):
% op_name)
topi_funcs += OP2COMPUTE[op_name]
env.reset(topi_funcs)
_expr2graph_impl(expr, target_ops, node_dict, node_list)
task_pos = 0
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args,
target="llvm",
target_host=None,
template_key='direct')
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
with env:
_expr2graph_impl(expr, target_ops, node_dict, node_list)
task_pos = 0
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args,
target="llvm",
target_host=None,
template_key='direct')
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
def _expr2graph_impl(expr, target_ops, node_dict, node_list):
......
......@@ -86,7 +86,6 @@ class LocalBuilder(Builder):
build_func = ndk.create_shared
else:
raise ValueError("Invalid build_func" + build_func)
self.build_func = _wrap_build_func(build_func)
self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp()
......@@ -360,8 +359,14 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
if cuda_arch:
set_cuda_target_arch(cuda_arch)
with build_config(**opts):
func = build(s, args, target_host=task.target_host)
# if target is vta, we need to use vta build
if hasattr(measure_input.target, 'device_name') and \
measure_input.target.device_name == 'vta':
import vta
func = vta.build(s, args, target_host=task.target_host)
else:
with build_config(**opts):
func = build(s, args, target_host=task.target_host)
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
......@@ -452,6 +457,12 @@ def run_through_rpc(measure_input, build_result,
try:
# upload built module
remote = request_remote(*remote_args)
# Program the FPGA every single time when targeting VTA
if hasattr(measure_input.target, 'device_name') and \
measure_input.target.device_name == 'vta':
from vta import program_fpga, reconfig_runtime
program_fpga(remote, None)
reconfig_runtime(remote)
remote.upload(build_result.filename)
func = remote.load_module(os.path.split(build_result.filename)[1])
ctx = remote.context(str(measure_input.target), 0)
......
......@@ -19,23 +19,22 @@
Decorator and utilities for the integration with TOPI and NNVM
"""
import threading
import warnings
import logging
from ... import target as _target
from .task import create
from .topi_integration import TaskExtractEnv
logger = logging.getLogger('autotvm')
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
def extract_from_graph(graph, shape, dtype, target, symbols, params=None, target_host=None):
""" Extract tuning tasks from a nnvm graph.
This function collects tuning tasks by building the graph
with a "tracing" target and tracing all the calls to topi.
and trace all the calls to topi.
Parameters
----------
......@@ -49,6 +48,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols to be tuned
params : dict of str to NDArray
The parameter dictionary.
target_host: tvm.target.Target
The host compilation target
......@@ -63,8 +64,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
env = TaskExtractEnv.get()
#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
# NOTE: To add more symbols, you only need to change the following lists
# nnvm symbol -> topi compute
SYMBOL2TOPI = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
......@@ -81,29 +82,40 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
logger.disabled = old_state
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
nnvm.compiler.engine.clear_cache()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=nnvm.compiler.build,
args=(graph,
target,
shape,
dtype,
params,
target_host))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation")
return tasks
def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, params, target_host=None):
""" Extract tuning tasks from multiple nnvm graphs.
This function is the multiple graph version of extract_from_graph
......@@ -120,6 +132,8 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols to be tuned
params : dict of str to NDArray
The parameter dictionary.
target_host: tvm.target.Target
The host compilation target
......@@ -152,25 +166,35 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
logger.disabled = old_state
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.engine.clear_cache()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=nnvm.compiler.build,
args=(graph,
target,
shape,
dtype,
params,
target_host))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation")
return tasks
......@@ -25,14 +25,30 @@ import warnings
import logging
from ... import target as _target
from .task import create
from .topi_integration import TaskExtractEnv
logger = logging.getLogger('autotvm')
# TODO(moreau89) find a more elegant way to build for VTAs
def _build(func,
target,
target_host,
params):
""" Helper to build VTA properly.
"""
from tvm import relay
if hasattr(target, 'device_name') and target.device_name == "vta":
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
import vta
with vta.build_config():
return relay.build(func, target, target_host, params)
# default case
return relay.build(func, target, target_host, params)
def extract_from_program(func, params, ops, target, target_host=None):
""" Extract tuning tasks from a relay program.
......@@ -57,11 +73,12 @@ def extract_from_program(func, params, ops, target, target_host=None):
task: Array of autotvm.task.Task
collected tasks
"""
env = TaskExtractEnv.get()
import tvm.relay.op
from tvm import relay
import topi
env = TaskExtractEnv.get()
# NOTE: To add more ops, you only need to change the following lists
# relay op -> topi compute
OP2TOPI = {
......@@ -81,30 +98,33 @@ def extract_from_program(func, params, ops, target, target_host=None):
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=relay.build, args=(func,
tracing_target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=_build,
args=(func,
target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
warnings.warn("Invalid shape during AutoTVM task creation")
return tasks
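For reference, a hedged sketch of how the extraction above might be driven for the VTA target; the Relay function `func`, its `params`, and the op list are illustrative placeholders, not part of this diff.

# Sketch only: collect AutoTVM tuning tasks for conv2d workloads on VTA.
import tvm.relay.op
from tvm import autotvm
import vta

env = vta.get_env()
tasks = autotvm.task.extract_from_program(func, params=params,
                                          ops=(tvm.relay.op.nn.conv2d,),
                                          target=env.target,
                                          target_host=env.target_host)
print("%d tunable tasks found" % len(tasks))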
......@@ -155,30 +175,33 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
for func, param in zip(funcs, params):
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=relay.build, args=(func,
tracing_target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
for func, param in zip(funcs, params):
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=_build,
args=(func,
target,
target_host,
params))
build_thread.start()
build_thread.join()
logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid shape during AutoTVM task creation")
return tasks
......@@ -27,7 +27,7 @@ tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
from ... import _api_internal, tensor, placeholder, create_schedule
from ... import _api_internal, tensor, placeholder
from .task import args_to_workload, dispatcher, register
from ..util import get_const_tuple
......@@ -73,6 +73,7 @@ def deserialize_args(args):
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
registered = None
def __init__(self, allow_duplicate=False):
import topi
......@@ -106,47 +107,65 @@ class TaskExtractEnv:
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
# function reflection for tracing
self.func_to_reflection = {
topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
}
self.allow_duplicate = allow_duplicate
self._register_tracing()
self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())
self.modified_funcs = []
def __enter__(self):
self.task_collection = []
self.modified_funcs = []
def _register_tracing(self):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for topi_compute in self.topi_to_task:
for topi_compute in self.wanted_topi_funcs:
def _local_scope(compute_func):
"""start a scope to hold the local function in for loop"""
@compute_func.register("tracing", )
def _tracing_topi_compute(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
def _tracing_wrapper(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when " \
"kwargs is used in TOPI function call. " \
"Please modify it to use only positional args."
if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
key = (self.topi_to_task[compute_func], serialize_args(args))
if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key)
return compute_func(*args, **kwargs)
self.func_to_reflection[compute_func](_tracing_wrapper)
self.modified_funcs.append(compute_func)
_local_scope(topi_compute)
# register topi schedule for "tracing" target
for topi_compute in self.topi_to_task:
for topi_schedule in self.topi_to_schedule[topi_compute]:
def _local_scope_(schedule_func):
"""start a scope to hold the local function in for loop"""
return self
@schedule_func.register("tracing", )
def _tracing_topi_compute(outs):
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return create_schedule([x.op for x in outs])
_local_scope_(topi_schedule)
def __exit__(self, exc_type, exc_val, exc_tb):
# revert modification
for func in self.modified_funcs:
self.func_to_reflection[func](func)
def _register_topi_task(self):
"""register tuning wrapper for topi function"""
import topi
# Avoid double registration for certain targets
if TaskExtractEnv.registered:
return
TaskExtractEnv.registered = True
# Tuning wrapper for topi functions
@register("topi_nn_conv2d")
def _topi_nn_conv2d(*args, **kwargs):
......@@ -190,7 +209,11 @@ class TaskExtractEnv:
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias, _ = args
if len(args) > 2:
data, weight, bias = args[:3]
else:
data, weight = args
bias = None
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
......
......@@ -44,7 +44,7 @@ PACKAGE_VERSION = {
'opencl': "v0.02",
'mali': "v0.05",
'vta': "v0.04",
'vta': "v0.05",
}
logger = logging.getLogger('autotvm')
......
......@@ -56,7 +56,7 @@ def compute_dense(attrs, inputs, out_type, target):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
return [topi.nn.dense(inputs[0], inputs[1], out_dtype=out_dtype)]
return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
@reg.register_schedule("nn.dense")
......@@ -119,21 +119,21 @@ def compute_conv2d(attrs, inputs, out_type, target):
if groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
dilation, layout, out_dtype)
elif layout == "NCHW" and \
get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ['NCHW', 'NCHW4c']:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
out_dtype)
else:
raise ValueError("not support arbitrary group number for now")
return [out]
......
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
import warnings
......@@ -171,7 +171,7 @@ def conv2d_rewrite(ref_call, new_args, ctx):
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
assert rhs_kind is None
......@@ -191,7 +191,8 @@ def check_to_skip():
return False
@register_annotate_function("nn.dense")
# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
# @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field."""
......@@ -201,13 +202,14 @@ def dense_rewrite(ref_call, new_args, ctx):
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
......@@ -222,6 +224,7 @@ def multiply_rewrite(ref_call, new_args, ctx):
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None:
# quantize lhs to INPUT field
if lhs_kind == QAnnotateKind.ACTIVATION:
......@@ -230,6 +233,7 @@ def multiply_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError
......@@ -244,22 +248,35 @@ def add_rewrite(ref_call, new_args, ctx):
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
assert rhs_kind == QAnnotateKind.INPUT
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
assert lhs_kind == QAnnotateKind.ACTIVATION
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError
if lhs_kind is not None and rhs_kind is not None:
if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@register_annotate_function("stop_fusion")
......@@ -294,6 +311,7 @@ register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
register_annotate_function("annotation.stop_fusion", identity_rewrite)
def pool2d_rewrite(ref_call, new_args, ctx):
......@@ -307,6 +325,7 @@ def pool2d_rewrite(ref_call, new_args, ctx):
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
......@@ -314,6 +333,23 @@ def pool2d_rewrite(ref_call, new_args, ctx):
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
@register_annotate_function("annotation.force_cast")
def force_cast_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
if check_to_skip():
return None
expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return new_args[0]
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
......@@ -333,3 +369,71 @@ def concatenate_rewrite(ref_call, new_args, ctx):
expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION)
expr = _forward_op(ref_call, [_expr.Tuple(expr_list)])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
# Graph rewrite function registration for VTA target
def register_vta_rewrite(op_name, frewrite=None, level=10):
def _register(func):
return _op.op._Register(op_name, "FQVTARewrite", func, level)
return _register(frewrite) if frewrite is not None else _register
@register_relay_node
class QVTAExpr(_expr.TempExpr):
def __init__(self, expr):
self.__init_handle_by_constructor__(
_quantize.make_vta_expr, expr)
def realize(self):
return _quantize.temp_expr_realize(self)
def vta_expr_check(expr):
if isinstance(expr, QVTAExpr):
return True, expr.expr
return False, expr
@register_vta_rewrite("nn.conv2d")
def conv2d_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d for VTA target"""
actx = annotate_context()
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if actx.conv2d_counter() in skipped_indices:
actx.count_conv2d()
return None
actx.count_conv2d()
data_cond, data = vta_expr_check(new_args[0])
kernel_cond, kernel = vta_expr_check(new_args[1])
assert not kernel_cond
if data_cond:
data = new_args[0].realize()
ret = _forward_op(ref_call, [data, kernel])
return QVTAExpr(ret)
def identity_vta_rewrite(ref_call, new_args, ctx):
cond, expr = vta_expr_check(new_args[0])
if cond:
return QVTAExpr(_forward_op(ref_call, [expr]))
return None
register_vta_rewrite("nn.relu", identity_vta_rewrite)
register_vta_rewrite("nn.max_pool2d", identity_vta_rewrite)
@register_vta_rewrite("add")
def add_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for ewise add for VTA target"""
lhs_cond, lhs = vta_expr_check(new_args[0])
rhs_cond, rhs = vta_expr_check(new_args[1])
if lhs_cond and rhs_cond:
lhs = new_args[0].realize()
rhs = new_args[1].realize()
return _forward_op(ref_call, [lhs, rhs])
elif lhs_cond and not rhs_cond:
return QVTAExpr(_forward_op(ref_call, [lhs, rhs]))
return None
......@@ -124,7 +124,9 @@ def current_qconfig():
"""Get the current quantization configuration."""
return _quantize._GetCurrentQConfig()
# TODO(tmoreau89, ZihengJiang) the skip parameters are
# hacky - we should explore a more future-proof way to
# skip operators based on pattern matching
def qconfig(**kwargs):
"""Configure the quantization behavior by setting config variables.
......@@ -279,6 +281,17 @@ def realize():
return _quantize.QuantizeRealize()
def rewrite_for_vta():
"""Performs rewriting for VTA target.
Returns
-------
ret: tvm.relay.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizeRewriteForVTA()
def _bind_params(func, params):
"""Bind the params to the expression.
"""
......@@ -337,15 +350,19 @@ def quantize(graph, params=None, dataset=None):
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
quantize_seq = _transform.Sequential([annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()])
with annotate_context():
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
# Quantize pass list
quant_passes = [annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()]
if current_qconfig().store_lowbit_output:
quant_passes = [rewrite_for_vta()] + quant_passes
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint]
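A hedged usage sketch of the pipeline above; `func` and `params` are placeholders and the config values are illustrative. The VTA rewrite pass is prepended automatically when `store_lowbit_output` is set in the quantization config.

# Sketch only: drive the quantize pass list with VTA-oriented settings.
from tvm import relay

with relay.quantize.qconfig(global_scale=8.0,
                            skip_conv_layers=[0],
                            store_lowbit_output=True):
    qfunc = relay.quantize.quantize(func, params=params)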
......@@ -344,7 +344,7 @@ def generic_func(fdefault):
The function to be registered.
override : bool
Whether override existing registeration.
Whether override existing registration.
Returns
-------
......@@ -489,6 +489,13 @@ def rasp(options=None):
return arm_cpu('rasp3b', options)
def vta(model='unknown', options=None):
opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
opts = _merge_opts(opts, options)
ret = _api_internal._TargetCreate("ext_dev", *opts)
return ret
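Illustrative use of the new target helper; the model string is an assumption.

import tvm

tgt = tvm.target.vta(model="pynq")
# Creates an "ext_dev" target carrying -device=vta -keys=cpu -model=pynq,
# so downstream code can check tgt.device_name == "vta".
print(tgt)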
def create(target_str):
"""Get a target given target string.
......
......@@ -58,6 +58,7 @@ Target CreateTarget(const std::string& target_name,
std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImm::make(item));
......@@ -69,6 +70,13 @@ Target CreateTarget(const std::string& target_name,
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImm::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImm::make(key_item));
}
}
}
......@@ -572,6 +580,9 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
Map<Target, Array<LoweredFunc>> updated_input;
for (const auto& it : inputs) {
auto target = Target::Create(it.first);
if (target->device_name == "vta") {
target = Target::Create("ext_dev");
}
updated_input.Set(target, it.second);
}
return build(updated_input, target_host, config);
......
......@@ -184,7 +184,7 @@ LoweredFunc MakeAPI(Stmt body,
for (Var v : undefined) {
os << " \'" << v->name_hint << "\' ";
}
os << " does not appeared in api_args";
os << " does not appear in api_args";
LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
}
return f;
......
......@@ -30,6 +30,7 @@
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
#include <utility>
#include <limits>
#include <mutex>
......@@ -170,7 +171,7 @@ class ScheduleGetter :
LOG(FATAL) << "not handled";
return tvm::Expr();
}
});
}, "compile_engine_const", topi::kBroadcast);
scalars_.push_back(value->op);
return {value};
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -83,5 +83,61 @@ TVM_ADD_FILELINE)
return {topi::identity(inputs[0])};
});
Expr ForceCast(Expr data) {
static const Op& op = Op::Get("annotation.force_cast");
return CallNode::make(op, {data}, Attrs{}, {});
}
RELAY_REGISTER_OP("annotation.force_cast")
.describe(R"code(Annotate an expression to force a cast.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
.add_type_rel("Identity", IdentityRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
RELAY_REGISTER_OP("annotation.bitpack_start")
.describe(R"code(
Mark the start of bitpacking.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
RELAY_REGISTER_OP("annotation.bitpack_end")
.describe(R"code(
Mark the end of bitpacking.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
return {topi::identity(inputs[0])};
});
} // namespace relay
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -379,6 +379,8 @@ Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array
Expr StopFusion(Expr data);
Expr ForceCast(Expr data);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
......@@ -90,7 +90,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
.set_attrs_type_key("relay.attrs.SimulatedQuantizeAttrs")
.set_support_level(10)
.set_support_level(11)
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_API("relay._quantize.simulated_quantize")
......@@ -133,6 +133,23 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast<QAnnotateKind>(args[1].operator int()));
});
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});
// =============
// realize pass
......@@ -371,7 +388,6 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();
std::vector<const QRealizeIntExprNode*> nptrs;
......@@ -385,19 +401,15 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
// unify the data type
CHECK_EQ(ref_args.size(), args.size());
DataType dtype = cfg->dtype_activation;
for (size_t i = 0; i < ret.size(); ++i) {
auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
auto new_arg = Cast(ret[i], cfg->dtype_input);
if (cfg->store_lowbit_output) {
new_arg = StopFusion(new_arg);
}
ret.Set(i, Cast(new_arg, dtype));
}
DataType dtype;
if (nptrs[0]->dtype == cfg->dtype_activation) {
  dtype = cfg->dtype_activation;
  ret.Set(1, Cast(ret[1], dtype));
} else if (nptrs[1]->dtype == cfg->dtype_input) {
  dtype = cfg->dtype_input;
  ret.Set(0, Cast(ret[0], dtype));
} else {
  LOG(FATAL) << "should not touch here.";
}
// unify the dom_scale
......@@ -504,10 +516,13 @@ RELAY_REGISTER_OP("nn.relu")
RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
Expr MaxPoolRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
/*! \brief For unary operators that requantize their input to dtype_nbit */
Expr CastDtypeInputRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
......@@ -520,7 +535,7 @@ Expr MaxPoolRealize(const Call& ref_call,
}
RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MaxPoolRealize);
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
Expr AvgPoolRealize(const Call& ref_call,
......@@ -543,6 +558,29 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
Expr ForceCastRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr ret = Cast(n->data, cfg->dtype_input);
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("annotation.force_cast")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ForceCastRealize);
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});
// =============
// qconfig
......@@ -646,6 +684,51 @@ Pass QuantizeRealizePass() {
TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);
Pass QuantizeRewriteForVTAPass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQVTARewrite", nullptr, nullptr));
};
return CreateFunctionPass(pass_func, 1, "QuantizeRewriteForVTA", {});
}
TVM_REGISTER_API("relay._quantize.QuantizeRewriteForVTA")
.set_body_typed(QuantizeRewriteForVTAPass);
// =============
// Insert stop_fusion for vta.
Expr QVTAExprNode::Realize() const {
Expr ret = ForceCast(this->expr);
return StopFusion(ret);
}
QVTAExpr QVTAExprNode::make(Expr expr) {
auto rnode = make_node<QVTAExprNode>();
rnode->expr = expr;
return QVTAExpr(rnode);
}
TVM_REGISTER_API("relay._quantize.make_vta_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QVTAExprNode::make(args[0]);
});
TVM_REGISTER_API("relay._quantize.make_stop_fusion")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
return StopFusion(expr);
});
TVM_REGISTER_API("relay._quantize.temp_expr_realize")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
const QVTAExprNode* n = expr.as<QVTAExprNode>();
CHECK(n);
return n->Realize();
});
} // namespace quantize
} // namespace relay
} // namespace tvm
......@@ -72,6 +72,33 @@ class QAnnotateExprNode : public TempExprNode {
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
/*!
* \brief TempExpr used to insert `force_cast` for VTA.
*/
class QVTAExpr;
/*!
* \brief TempExprNode used to insert `force_cast` for VTA.
*/
class QVTAExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
}
TVM_DLL static QVTAExpr make(Expr expr);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QVTAExpr";
TVM_DECLARE_NODE_TYPE_INFO(QVTAExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QVTAExpr, QVTAExprNode, TempExpr);
/*! \brief TempExpr used during realize forward rewrite. */
class QRealizeExpr;
/*! \brief TempExpr representing integer. */
......
......@@ -47,7 +47,7 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
if (FindNodeRef(all_vars, v) < all_vars->data.size()) {
LOG(FATAL) << "Operate on iter var " << v
<< "that has already been splitted";
<< "that has already been split";
} else {
LOG(FATAL) << "Operate on iter var " << v
<< "that is not part of the schedule";
......
......@@ -35,6 +35,8 @@ from . import vision
from . import image
from . import sparse
from . import hls
# error reporting
from .util import InvalidShapeError
# not import testing by default
# because testing can have extra deps that are not necessary
# we can import them from test cases explicitly
......
......@@ -23,6 +23,10 @@ import tvm
from tvm.api import layout, bijective_layout
from . import tag
class InvalidShapeError(ValueError):
    """Invalid shape for a topi function (e.g., calling a winograd template for a non-3x3 kernel)."""
    pass
def traverse_inline(s, final_op, callback):
"""Traverse computation graph and do auto inline
......
......@@ -42,7 +42,7 @@ extern "C" {
/*! \brief Physically contiguous buffer size limit */
#ifndef VTA_MAX_XFER
#define VTA_MAX_XFER (1<<22)
#define VTA_MAX_XFER (1<<25)
#endif
/*! PAGE SIZE */
......
......@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""VTA specific buildin for runtime."""
from __future__ import absolute_import as _abs
......
......@@ -235,6 +235,10 @@ class Environment(object):
return self.dev.gemm
@property
def target(self):
return tvm.target.vta(model=self.TARGET)
@property
def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
......@@ -243,6 +247,9 @@ class Environment(object):
return "llvm"
raise ValueError("Unknown target %s" % self.TARGET)
@property
def target_vta_cpu(self):
return tvm.target.arm_cpu(model=self.TARGET)
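Illustrative reads of the new Environment properties; the concrete strings depend on the configured TARGET.

import vta

env = vta.get_env()
print(env.target)          # ext_dev target with -device=vta -model=<TARGET>
print(env.target_host)     # host compilation target (llvm flavour depends on TARGET)
print(env.target_vta_cpu)  # arm_cpu fallback target for the same TARGET model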
def get_env():
"""Get the current VTA Environment.
......
......@@ -66,6 +66,9 @@ def server_start():
@tvm.register_func("tvm.contrib.vta.init", override=True)
def program_fpga(file_name):
from pynq import xlnk
# Reset xilinx driver
xlnk.Xlnk().xlnk_reset()
path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name)
env = get_env()
program_bitstream.bitstream_program(env.TARGET, path)
......
......@@ -77,8 +77,6 @@ class PkgConfig(object):
if self.target == "pynq":
self.ldflags = [
"-L/usr/lib",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
"-l:libcma.so"]
else:
self.ldflags = []
......
......@@ -84,4 +84,17 @@ def tsim_cycles():
"""
return tvm.get_global_func("tvm.vta.tsim.cycles")()
# debug flag to skip execution.
DEBUG_SKIP_EXEC = 1
def debug_mode(flag):
"""Set debug mode
Parameters
----------
flag : int
The debug flag, 0 means clear all flags.
"""
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)
LIBS = _load_lib()
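Illustrative calls; the flag constant is defined just above.

# Skip instruction execution in the functional simulator, then clear all debug flags.
debug_mode(DEBUG_SKIP_EXEC)
debug_mode(0)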
......@@ -18,7 +18,7 @@
from __future__ import absolute_import as _abs
import os
from tvm import rpc
from tvm import rpc, autotvm
from ..environment import get_env
from . import simulator
......@@ -42,7 +42,7 @@ def run(run_func):
# the port it's listening to, e.g. 9090
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
remote = rpc.connect("localhost", local_rpc)
remote = rpc.connect("127.0.0.1", local_rpc)
run_func(env, remote)
else:
# Make sure simulation library exists
......@@ -54,12 +54,22 @@ def run(run_func):
elif env.TARGET == "pynq":
# Run on PYNQ if env variable exists
host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None))
if host and port:
remote = rpc.connect(host, port)
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", None))
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
pynq_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None))
# Run device from the tracker (fleet) node if env variables are defined
if tracker_host and tracker_port:
remote = autotvm.measure.request_remote(env.TARGET,
tracker_host,
tracker_port,
timeout=10000)
run_func(env, remote)
else:
raise RuntimeError(
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")
# Next, run on PYNQ if env variables are defined
if pynq_host and pynq_port:
remote = rpc.connect(pynq_host, pynq_port)
run_func(env, remote)
else:
raise RuntimeError(
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")
"""TVM TOPI connector, eventually most of these should go to TVM repo"""
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import bitpack
from .graphpack import graph_pack
from . import op
from . import vta_conv2d
from . import arm_conv2d
from . import vta_dense
# NNVM is deprecated for VTA
# from . import nnvm_bitpack
# from .nnvm_graphpack import nnvm_graph_pack
# from . import nnvm_op
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=ungrouped-imports
"""Bit packing operators"""
from __future__ import absolute_import as _abs
import tvm
from topi import util
from tvm.relay.op.op import register_compute, register_schedule
from tvm.relay.op.op import register_pattern, OpPattern
from tvm.relay.op.op import schedule_injective
def bitpack(data, bits, pack_type="int8", name="bitpack"):
"""Packs lowest dimension into format needed by VTA
Parameters
----------
data : Tensor
The tensor to pack along its lowest (innermost) dimension.
bits : int
Number of bits each element occupies after packing.
pack_type : str, optional
Storage dtype holding the packed lanes ("int8", "int16" or "int32").
name : str, optional
Name of the resulting compute op.
Returns
-------
packed : Tensor
The packed tensor.
"""
shape_vec = list(data.shape)
if pack_type == 'int8':
data_width = 8
elif pack_type == 'int16':
data_width = 16
elif pack_type == 'int32':
data_width = 32
else:
raise RuntimeError("Unknown pack type %s" % pack_type)
assert data_width % bits == 0
lanes = data_width // bits
# Data must be in multiples of the data_width
assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size"
shape_vec[-1] = shape_vec[-1] // lanes
oshape = tuple(shape_vec)
def _bitpack(*indices):
ret = None
mask = tvm.const((1 << bits) - 1, pack_type)
for k in range(lanes):
idx = list(indices)
idx[-1] = idx[-1] * lanes + k
elem = data(*idx).astype(pack_type)
if k == 0:
ret = elem & mask
else:
val = (elem & mask) << tvm.const(k * bits, pack_type)
ret = ret | val
return ret
return tvm.compute(
oshape, _bitpack, name=name, tag='bitpack')
@register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs):
lanes = attrs.lanes
dtype = inputs[0].dtype
assert dtype == "int8"
width = 8
assert width % lanes == 0
bits = 8 // lanes
return bitpack(inputs[0], bits, dtype)
register_schedule("bitpack", schedule_injective)
register_pattern("bitpack", OpPattern.INJECTIVE)
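A small, hedged example of the `bitpack` compute defined above; the tensor shape and the 2-bit width are illustrative.

import tvm

# 8-bit storage / 2-bit weights => 4 lanes per byte, so the innermost
# dimension shrinks by a factor of 4 (16 -> 4).
w = tvm.placeholder((4, 4, 3, 3, 16, 16), dtype="int8", name="w")
w_packed = bitpack(w, bits=2, pack_type="int8")
print(w_packed.shape)  # (4, 4, 3, 3, 16, 4)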
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""A Relay implementation of graph packing."""
from tvm import relay
from tvm.relay import op
from tvm.relay import ExprMutator
def _to_shape(shape):
return tuple(int(sh) for sh in shape)
def _pack_batch_channel(data, dshape, bfactor, cfactor):
"""Pack the data channel dimension.
"""
assert int(dshape[0]) % bfactor == 0
assert int(dshape[1]) % cfactor == 0
data = op.reshape(data,
newshape=(int(dshape[0]) // bfactor, bfactor,
int(dshape[1]) // cfactor, cfactor,
int(dshape[2]), int(dshape[3])))
data = op.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _unpack_batch_channel(data, old_shape):
"""Unpack the data channel dimension.
"""
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
data = op.reshape(data, newshape=old_shape)
return data
def _pack_weight(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert int(dshape[0]) % cfactor == 0
assert int(dshape[1]) % cfactor == 0
data = op.reshape(data,
newshape=(int(dshape[0]) // cfactor, cfactor,
int(dshape[1]) // cfactor, cfactor,
int(dshape[2]), int(dshape[3])))
data = op.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
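For concreteness, a NumPy mirror of the reshape/transpose performed by `_pack_batch_channel` and `_pack_weight` above; bfactor=1 and cfactor=16 are illustrative VTA packing factors, and the tensor shapes are arbitrary examples.

import numpy as np

bfactor, cfactor = 1, 16

data = np.zeros((1, 64, 56, 56), "int8")                        # NCHW
packed_data = data.reshape(1 // bfactor, bfactor,
                           64 // cfactor, cfactor, 56, 56).transpose(0, 2, 4, 5, 1, 3)
print(packed_data.shape)                                        # (1, 4, 56, 56, 1, 16) -> NCHW1n16c

weight = np.zeros((64, 64, 3, 3), "int8")                       # OIHW
packed_weight = weight.reshape(64 // cfactor, cfactor,
                               64 // cfactor, cfactor, 3, 3).transpose(0, 2, 4, 5, 1, 3)
print(packed_weight.shape)                                      # (4, 4, 3, 3, 16, 16) -> OIHW16o16i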
def _pack_weight_conv2d_transpose(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
dshape = _to_shape(dshape)
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = op.reshape(data,
newshape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = op.transpose(
data, axes=(2, 0, 4, 5, 3, 1))
return data
def _pack_bias(data, dshape, dtype, bfactor, cfactor):
"""Pack the bias parameter.
"""
dshape = _to_shape(dshape)
assert len(dshape) == 3
assert dshape[0] % cfactor == 0
data = op.reshape(data,
newshape=(dshape[0] // cfactor,
cfactor, dshape[1],
dshape[2], 1))
data = op.transpose(
data, axes=(0, 2, 3, 4, 1))
# broadcast batch dimension to bfactor
data = op.broadcast_to(
data,
shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
return data
def _get_shape(node):
"""Get the shape of a node.
"""
return _to_shape(node.checked_type.shape)
class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST.
"""
def __init__(self, bfactor, cfactor, weight_bits):
self.bfactor = bfactor
self.cfactor = cfactor
self.weight_bits = weight_bits
self.start_pack = False
# Cache Operator the algorithm matches against.
self.bitpack_start = op.op.get('annotation.bitpack_start')
self.bitpack_end = op.op.get('annotation.bitpack_end')
self.conv2d = op.op.get("nn.conv2d")
self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
self.add = op.op.get("add")
self.bias_add = op.op.get("nn.bias_add")
self.number_of_conv2d = 0
super().__init__()
def visit_call(self, call):
# First visit the children.
oshape = _get_shape(call)
odtype = call.checked_type.dtype
input_types = [arg.checked_type for arg in call.args]
args = [self.visit(arg) for arg in call.args]
# Start and stop cases.
if call.op == self.bitpack_start:
assert not self.start_pack
self.start_pack = True
return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor)
elif call.op == self.bitpack_end:
if self.start_pack:
self.start_pack = False
data = args[0]
data_shape = _get_shape(call.args[0])
return _unpack_batch_channel(data, data_shape)
else:
pass
if self.start_pack:
# Operator cases
if call.op == self.conv2d and odtype == 'int32':
self.number_of_conv2d += 1
assert 8 % self.weight_bits == 0
w_lanes = 8 // self.weight_bits
data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
data, weight = args
data_shape = _to_shape(input_types[0].shape)
kernel_shape = _to_shape(input_types[1].shape)
kernel = _pack_weight(weight, kernel_shape, self.cfactor)
# insert bit packing when necessary
if w_lanes != 1:
assert 8 % w_lanes == 0
kernel = op.bitpack(kernel, lanes=w_lanes)
conv2d = op.nn.conv2d(
data,
kernel,
strides=call.attrs.strides,
padding=call.attrs.padding,
dilation=call.attrs.dilation,
groups=call.attrs.groups,
channels=call.attrs.channels,
kernel_size=call.attrs.kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=call.attrs.out_dtype)
return conv2d
elif call.op == self.conv2d_transpose and odtype == 'int32':
self.number_of_conv2d += 1
assert 8 % self.weight_bits == 0
w_lanes = 8 // self.weight_bits
if self.start_pack:
data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor)
data, weight = args
data_shape = _to_shape(input_types[0].shape)
kernel_shape = _to_shape(input_types[1].shape)
kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor)
conv2d = op.nn.conv2d_transpose(
data,
kernel,
strides=call.attrs.strides,
padding=call.attrs.padding,
dilation=call.attrs.dilation,
groups=call.attrs.groups,
channels=call.attrs.channels,
kernel_size=call.attrs.kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
output_padding=call.attrs.output_padding,
out_dtype=call.attrs.out_dtype)
return conv2d
elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
pass
elif call.op == self.add and len(input_types[1].shape) == 3:
data, bias = args
bias = _pack_bias(bias,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
return relay.Call(self.add, [data, bias])
elif self.start_pack and call.op == self.bias_add:
data, bias = args
bias = _pack_bias(bias,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
return relay.Call(self.add, [data, bias])
elif self.start_pack and call.op == op.op.get('cast') and \
input_types[0].dtype == 'int32':
cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
return relay.Call(op.op.get('copy'), [cast])
return relay.Call(
self.visit(call.op),
args,
call.attrs)
class BT(Exception):
pass
def get_subgraph(expr, start_name, stop_name):
""" We assume stop_name only appears once for simplicity.
This constraint will be lifted in the future.
bitpack_start and bitpack_end are both inclusive.
"""
bitpack_start = op.op.get('annotation.bitpack_start')
bitpack_end = op.op.get('annotation.bitpack_end')
anf = relay.ir_pass.to_a_normal_form(expr)
def _recursion(anf, start_found, stop_found):
""" Helper to obtain the subgraph.
"""
if isinstance(anf, relay.expr.Function):
return relay.expr.Function(anf.params,
_recursion(anf.body, start_found, stop_found),
anf.ret_type, anf.type_params, anf.attrs)
elif isinstance(anf, relay.expr.Let):
value = anf.value
if isinstance(value, relay.expr.Call):
if isinstance(value.op, relay.op.Op):
if value.op.name == start_name and not start_found:
value = relay.expr.Call(bitpack_start, [value])
start_found = True
elif value.op.name == stop_name:
raise BT()
try:
return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found))
except BT:
assert start_found
assert not stop_found
stop_found = True
value = relay.expr.Call(bitpack_end, [value])
# TODO: check that anf.body contains no further stop_name occurrences besides this one
return relay.expr.Let(anf.var, value, anf.body)
else:
assert start_found
assert stop_found
return anf
annotated = _recursion(anf, False, False)
return relay.ir_pass.infer_type(relay.ir_pass.to_graph_normal_form(annotated))
def graph_pack(expr,
bfactor,
cfactor,
weight_bits,
start_name="nn.max_pool2d",
stop_name="nn.global_avg_pool2d"):
"""Pack the graph into batch&channel packed format.
Parameters
----------
expr : relay.Expr
The input program.
bfactor : int
The packing factor in batch
cfactor : int
The packing factor in channel
weight_bits: int
The bit-width of the weights.
start_name: str, optional
Start packing from certain known node.
stop_name: str, optional
Stop packing from certain known node.
Returns
-------
expr : Expr
The transformed expression.
"""
assert isinstance(expr, relay.Function)
expr = get_subgraph(expr, start_name, stop_name)
expr = relay.ir_pass.infer_type(expr)
packer = ExprPack(
bfactor, cfactor,
weight_bits)
expr = packer.visit(expr)
assert not packer.start_pack
return relay.ir_pass.infer_type(expr)
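# --- Illustration (not part of this change) --------------------------------
# A NumPy-only sketch of the batch/channel packing that graph_pack applies to
# activations. bfactor=1 and cfactor=16 are assumed here, matching the default
# VTA BATCH/BLOCK_OUT configuration; the reshape/transpose mirrors the packing
# helpers used by this pass.
import numpy as np
bfactor, cfactor = 1, 16
n, c, h, w = 1, 32, 14, 14
x = np.random.randint(-128, 128, size=(n, c, h, w), dtype="int8")    # NCHW
# NCHW -> NCHW1n16c: split batch/channel into (outer, block) pairs, then move
# the block axes to the end.
packed = x.reshape(n // bfactor, bfactor,
                   c // cfactor, cfactor, h, w).transpose(0, 2, 4, 5, 1, 3)
assert packed.shape == (1, 2, 14, 14, 1, 16)
# Unpacking is the inverse transpose/reshape.
restored = packed.transpose(0, 4, 1, 5, 2, 3).reshape(n, c, h, w)
assert (restored == x).all()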
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Bit packing operators"""
from __future__ import absolute_import as _abs
import tvm
from topi import util
from nnvm.top import registry as reg, OpPattern
from nnvm.top.tensor import _fschedule_broadcast
def bitpack(data, bits, pack_type="int8", name="bitpack"):
"""Packs lowest dimension into format needed by VTA
Parameters
----------
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data
Returns
-------
packed : Tensor
The packed tensor.
"""
shape_vec = list(data.shape)
if pack_type == 'int8':
data_width = 8
elif pack_type == 'int16':
data_width = 16
elif pack_type == 'int32':
data_width = 32
else:
raise RuntimeError("Unknown pack type %s" % pack_type)
assert data_width % bits == 0
lanes = data_width // bits
# The last dimension must be divisible by the number of lanes packed per word
assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size"
shape_vec[-1] = shape_vec[-1] // lanes
oshape = tuple(shape_vec)
def _bitpack(*indices):
ret = None
mask = tvm.const((1 << bits) - 1, pack_type)
for k in range(lanes):
idx = list(indices)
idx[-1] = idx[-1] * lanes + k
elem = data(*idx).astype(pack_type)
if k == 0:
ret = elem & mask
else:
val = (elem & mask) << tvm.const(k * bits, pack_type)
ret = ret | val
return ret
return tvm.compute(
oshape, _bitpack, name=name, tag='bitpack')
@reg.register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs, out):
lanes = attrs.get_int("lanes")
dtype = inputs[0].dtype
assert dtype == "int8"
width = 8
assert width % lanes == 0
bits = 8 // lanes
return bitpack(inputs[0], bits, dtype)
reg.register_schedule("bitpack", _fschedule_broadcast)
reg.register_pattern("bitpack", OpPattern.INJECTIVE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""An NNVM implementation of graph packing."""
import nnvm
from nnvm.compiler import graph_attr, graph_util
def _pack_batch_channel(data, dshape, bfactor, cfactor):
"""Pack the data channel dimension.
"""
assert dshape[0] % bfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // bfactor, bfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _unpack_batch_channel(data, old_shape):
"""Unpack the data channel dimension.
"""
data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
data = nnvm.sym.reshape(data, shape=old_shape)
return data
def _pack_weight(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _pack_weight_conv2d_transpose(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(2, 0, 4, 5, 3, 1))
return data
def _pack_bias(data, dshape, bfactor, cfactor):
"""Pack the bias parameter.
"""
assert len(dshape) == 3
assert dshape[0] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor,
cfactor, dshape[1],
dshape[2], 1))
data = nnvm.sym.transpose(
data, axes=(0, 2, 3, 4, 1))
# broadcast batch dimension to bfactor
data = nnvm.sym.broadcast_to(
data,
shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
return data
def _get_shape(sym, shape_dict):
"""Get the shape of a node.
"""
return graph_util.infer_shape(
nnvm.graph.create(sym), **shape_dict)[1][0]
def nnvm_graph_pack(graph,
shape_dict,
bfactor,
cfactor,
weight_bits,
start_name="max_pool2d0",
stop_name="global_avg_pool2d0"):
"""Pack the graph into batch&channel packed format.
Parameters
----------
graph : Graph
The input graph.
shape_dict : dict of str to shape
The input shape.
bfactor : int
The packing factor in batch
cfactor : int
The packing factor in channel
weight_bits : int
The bit-width of the weights.
start_name: str, optional
Start packing from certain known node.
stop_name: str, optional
Stop packing from certain known node.
Returns
-------
graph : Graph
The transformed graph.
"""
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
shape = graph.json_attr("shape")
gidx = graph.index
node_map = {}
dset = set()
start_pack = False
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]]
oshape = shape[gidx.entry_id(nid, 0)]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
if op_name == "null":
new_node = nnvm.symbol.Variable(node_name)
if start_name and node_name == start_name:
start_pack = True
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
if start_pack and "_begin_state_" in node_name: # RNN -> CNN, pack
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif node_name == start_name:
assert not start_pack
start_pack = True
new_node = get_clone(children, op_name, node_name, attrs)
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif node_name == stop_name:
if start_pack:
start_pack = False
children[0] = _unpack_batch_channel(children[0], ishape[0])
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "conv2d" and attrs.get("out_dtype", None) == "int32":
assert 8 % weight_bits == 0
w_lanes = 8 // weight_bits
if start_pack:
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "OIHW%do%di%dp" % (cfactor, cfactor, w_lanes)
data, weight = children
weight = _pack_weight(weight, ishape[1], cfactor)
# insert bit packing when necessary
if w_lanes != 1:
assert 8 % w_lanes == 0
weight = nnvm.sym.bitpack(weight, lanes=w_lanes)
new_node = nnvm.sym.conv2d(
data, weight, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "conv2d_transpose" and attrs.get("out_dtype", None) == "int32":
assert 8 % weight_bits == 0
w_lanes = 8 // weight_bits
if start_pack:
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "IOHW%di%do%dp" % (cfactor, cfactor, w_lanes)
data, weight = children
weight = _pack_weight_conv2d_transpose(weight, ishape[1], cfactor)
new_node = nnvm.sym.conv2d_transpose(
data, weight, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast_") and tuple(ishape[0]) == tuple(ishape[1]):
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast") and len(ishape[1]) == 3:
if start_pack:
children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("elementwise_add"):
new_node = get_clone(children, op_name, node_name, attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
dset.add(op_name)
node_map[nid] = new_node
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]]
if start_pack:
oshape = shape[graph.index.output_entries[0][0]]
ret = _unpack_batch_channel(ret, oshape)
graph = nnvm.graph.create(ret)
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
return graph
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
import logging
import tvm
import topi
from nnvm.top import registry as reg, OpPattern
from nnvm.top import nn as _nn
from .vta_conv2d import is_packed_layout
from ..environment import get_env
@tvm.register_func("nnvm.compiler.build_target", override=True)
def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", target_host=target_host)
if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target)
@tvm.register_func("nnvm.compiler.lower", override=True)
def _lower(sch, inputs, func_name, graph):
import traceback
# pylint: disable=broad-except
try:
f = tvm.lower(sch, inputs, name=func_name)
if "quantized_conv2d" in func_name:
logging.info(graph.ir(join_entry_attrs=["shape"]))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile graph\n"
msg += "--------------------------\n"
msg += graph.ir(join_entry_attrs=["shape"])
raise RuntimeError(msg)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]
# override to force partition at copy
reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
@reg.register_compute("clip", level=15)
def compute_clip(attrs, inputs, _):
""" Clip operator. """
x = inputs[0]
a_min = attrs.get_float("a_min")
a_max = attrs.get_float("a_max")
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
with tvm.tag_scope(topi.tag.ELEMWISE):
x = tvm.compute(
x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(
x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
@reg.register_compute("conv2d", level=15)
def compute_conv2d(attrs, inputs, out):
""" Compute definition of conv2d """
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs["layout"]
out_dtype = attrs['out_dtype']
assert dilation == (1, 1), "dilation is not supported for now"
if is_packed_layout(layout):
if groups == 1:
assert groups == 1
env = get_env()
assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
assert env.LOG_OUT_WIDTH == 3, "only support 8bit out for now"
inputs = list(inputs)
assert inputs[1].dtype == "int8"
return topi.nn.conv2d(inputs[0], inputs[1], strides,
padding, dilation, layout, out_dtype)
return topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides,
padding, dilation, groups, out_dtype)
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_conv2d(attrs, inputs, out)
@reg.register_schedule("conv2d", level=15)
def schedule_conv2d(attrs, outs, target):
""" Schedule definition of conv2d """
layout = attrs["layout"]
groups = attrs.get_int('groups')
if is_packed_layout(layout):
target = tvm.target.create(target)
if target.device_name == "vta":
if groups == 1:
return topi.generic.schedule_conv2d_nchw(outs)
return topi.generic.schedule_group_conv2d_nchw(outs)
elif str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target)
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())
@reg.register_alter_op_layout("conv2d", level=15)
def alter_conv2d_layout(attrs, inputs, out):
layout = attrs['layout']
if is_packed_layout(layout):
return None
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.alter_conv2d_layout(attrs, inputs, out)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, ungrouped-imports
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
import tvm
import topi
from tvm.relay.op import op as reg
from tvm.relay.op.op import OpPattern
from tvm.relay.op.nn import _nn
from .vta_conv2d import is_packed_layout
from ..environment import get_env
# override to force partition at copy
reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
@reg.register_compute("clip", level=15)
def compute_clip(attrs, inputs, output_type, target):
""" Clip operator. """
x = inputs[0]
a_min = attrs.a_min
a_max = attrs.a_max
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
with tvm.tag_scope(topi.tag.ELEMWISE):
x = tvm.compute(
x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(
x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return [x]
@reg.register_compute("nn.conv2d", level=15)
def compute_conv2d(attrs, inputs, output_type, target):
""" Compute definition of conv2d """
padding = topi.util.get_const_tuple(attrs.padding)
strides = topi.util.get_const_tuple(attrs.strides)
dilation = tuple([int(d) for d in attrs.dilation])
groups = attrs.groups
layout = attrs.data_layout
out_dtype = attrs.out_dtype
if target.device_name == "vta":
assert dilation == (1, 1), "support for dilation limited to (1, 1)"
if is_packed_layout(layout):
if groups == 1:
assert groups == 1
env = get_env()
assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now"
inputs = list(inputs)
assert inputs[1].dtype == "int8"
return [topi.nn.conv2d(inputs[0],
inputs[1],
strides,
padding,
dilation,
layout,
out_dtype)]
return [topi.nn.group_conv2d_nchw(inputs[0],
inputs[1],
strides,
padding,
dilation,
groups,
out_dtype)]
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_conv2d(attrs, inputs, output_type, target)
# If VTA is not the target, default to _nn def
return _nn.compute_conv2d(attrs, inputs, output_type, target)
@reg.register_schedule("nn.conv2d", level=15)
def schedule_conv2d(attrs, outs, target):
""" Schedule definition of conv2d """
groups = attrs.groups
layout = attrs.data_layout
if target.device_name == "vta":
if is_packed_layout(layout):
target = tvm.target.create(target)
assert target.device_name == "vta"
if groups == 1:
return topi.generic.schedule_conv2d_nchw(outs)
return topi.generic.schedule_group_conv2d_nchw(outs)
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())
# If VTA is not the target, default to _nn def
return _nn.schedule_conv2d(attrs, outs, target)
@reg.register_compute("nn.dense", level=15)
def compute_dense(attrs, inputs, out_type, target):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
if target.device_name == "vta":
if len(inputs[0].shape) == 4: # a 4-D input implies the packed layout
target = tvm.target.create(target)
return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_dense(attrs, inputs, out_type, target)
# If VTA is not the target, default to _nn def
return _nn.compute_dense(attrs, inputs, out_type, target)
@reg.register_schedule("nn.dense", level=15)
def schedule_dense(attrs, outs, target):
"""Schedule definition of dense"""
if target.device_name == "vta":
if len(outs[0].shape) == 4: # a 4-D output implies the packed layout
target = tvm.target.create(target)
assert target.device_name == "vta"
return topi.generic.schedule_dense(outs)
# If it's not packed, run on ARM CPU
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.schedule_dense(attrs, outs, tvm.target.current_target())
# If VTA is not the target, default to _nn def
return _nn.schedule_dense(attrs, outs, target)
......@@ -14,182 +14,49 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
"""Conv2D operator declaration and schedule registration for VTA."""
from collections import namedtuple
import logging
import numpy as np
import tvm
from tvm import autotvm
import topi
from nnvm.top import registry as reg, OpPattern
from nnvm.top import nn as _nn
from ..environment import get_env
def is_packed_layout(layout):
"""Check if layout is packed layout"""
if layout == "NCHW":
return False
if "n" in layout and "c" in layout:
return True
return False
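# --- Spot checks (not part of this change): expected behaviour of the helper above.
assert not is_packed_layout("NCHW")        # plain layout, falls back to the ARM CPU path
assert is_packed_layout("NCHW1n16c")       # batch/channel packed layout, handled by VTA
assert not is_packed_layout("NHWC")        # uppercase only, no n/c block suffixes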
Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
def find_schedules(layer, vt_only=False, best_only=False):
""" Returns a schedule for a given a layer.
Parameters
----------
layer : Workload
Convolutional layer description.
vt_only : Boolean
Produce a schedule plan with virtual threading.
best_only : Boolean
Return the "best" schedule plan.
Returns
-------
fil_sched : list
List of valid schedules.
"""
# pylint: disable=too-many-nested-blocks
env = get_env()
# Helper function to get factors
def _find_factors(n):
factors = []
for f in range(1, n + 1):
if n % f == 0:
factors.append(f)
return factors
def _get_data_movement_byte(schedule, layer):
""" Estimate data movement in bytes for the schedule plan
"""
env = get_env()
b_f = schedule.b_factor
h_f = schedule.h_factor
w_f = schedule.w_factor
ci_f = schedule.ic_factor
co_f = schedule.oc_factor
# Derive data movement
inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
input_tile_elems = b_f * \
((h_f - 1) * layer.hstride + layer.hkernel) * \
((w_f - 1) * layer.wstride + layer.wkernel) * ci_f
weight_tile_elems = layer.hkernel * layer.wkernel * ci_f
output_tile_elems = b_f * h_f * w_f * co_f
# Derive tiling factors
b_factor = layer.batch // (b_f * env.BATCH)
h_factor = (layer.height // layer.hstride) // h_f
w_factor = (layer.width // layer.wstride) // w_f
ci_factor = layer.in_filter // (ci_f * env.BLOCK_IN)
co_factor = layer.out_filter // (co_f * env.BLOCK_OUT)
# Compute input transaction count
input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
output_xfers = b_factor * h_factor * w_factor * co_factor
# Compute total transfer sizes
input_xfer_byte = input_tile_elems * input_xfers * inp_elem_sizeb // 8
weight_xfer_byte = weight_tile_elems * weight_xfers * wgt_elem_sizeb // 8
output_xfer_byte = output_tile_elems * output_xfers * out_elem_sizeb // 8
total_xfer_byte = input_xfer_byte + weight_xfer_byte + output_xfer_byte
return total_xfer_byte
# Scheduling exploration
batch_factors = _find_factors(layer.batch // env.BATCH)
height_factors = _find_factors(layer.height // layer.hstride)
width_factors = _find_factors(layer.width // layer.wstride)
cin_factors = _find_factors(layer.in_filter // env.BLOCK_IN)
cout_factors = _find_factors(layer.out_filter // env.BLOCK_OUT)
ht_factors = [1, 2]
cot_factors = [1, 2]
# Explore schedules
schedules = []
for b_f in batch_factors:
for h_f in height_factors:
for w_f in width_factors:
for ci_f in cin_factors:
for co_f in cout_factors:
# FIXME: 2D load pattern matching imposes restrictions on schedule
valid = (w_f == layer.width // layer.wstride) or \
(w_f != layer.width // layer.wstride and co_f == 1) and \
ci_f == 1
if valid:
schedules.append([b_f, h_f, w_f, ci_f, co_f])
# Filter the schedules that wouldn't work in the available BRAM sizes
inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH
wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH
out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH
inp_brams_sizeb = env.INP_BUFF_SIZE * 8
wgt_brams_sizeb = env.WGT_BUFF_SIZE * 8
out_brams_sizeb = env.OUT_BUFF_SIZE * 8
fil_sched = []
xfer_size = []
for sched in schedules:
b_f, h_f, w_f, ci_f, co_f = sched
for h_t in ht_factors:
for co_t in cot_factors:
# Make sure to filter cases where we apply threading on two axes
# or cases where the threading factors for h and co are not
# factors of h and co
if (h_t == 2 and co_t == 2) or (h_f % h_t != 0) or (co_f % co_t != 0):
continue
# Adjust tile sizes if threading is applied
h_f //= h_t
co_f //= co_t
# Derive tile sizes
input_tile_elems = b_f * \
((h_f - 1) * layer.hstride + layer.hkernel) * \
((w_f - 1) * layer.wstride + layer.wkernel) * ci_f
weight_tile_elems = layer.hkernel * layer.wkernel * ci_f * co_f
output_tile_elems = b_f * h_f * w_f * co_f
# Derive valid schedule filter
valid = True
# If in virtual-threaded mode, only allow for threaded plans
valid &= (vt_only and (h_t == 2 or co_t == 2)) or not vt_only
# Check that we don't exceed input/weight/output capacity
valid &= input_tile_elems * inp_elem_sizeb <= inp_brams_sizeb // (co_t * h_t)
valid &= weight_tile_elems * wgt_elem_sizeb <= wgt_brams_sizeb
valid &= output_tile_elems * out_elem_sizeb <= out_brams_sizeb // (co_t * h_t)
# Make sure that we don't write to the same acc location within 2 consecutive cycles
valid &= h_f > 2 and w_f > 2
# TODO: check that we don't exceed instruction or micro-op count
if valid:
schedule = Schedule(b_factor=b_f, oc_factor=co_f, ic_factor=ci_f, h_factor=h_f,
w_factor=w_f, oc_nthread=co_t, h_nthread=h_t)
fil_sched.append(schedule)
xfer_size.append(_get_data_movement_byte(schedule, layer))
if best_only and xfer_size:
return [fil_sched[xfer_size.index(min(xfer_size))]]
return fil_sched
@autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct')
def _declaration_conv2d(cfg,
data,
kernel,
strides,
padding,
dilation,
layout,
out_dtype):
""" Packed conv2d function."""
if not is_packed_layout(layout):
raise topi.InvalidShapeError()
assert dilation == (1, 1)
def packed_conv2d(data,
kernel,
padding,
strides,
out_dtype="int32"):
""" Packed conv2d function.
"""
if padding[0]:
pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data")
else:
pad_data = data
assert len(data.shape) == 6
assert len(kernel.shape) == 6
oheight = topi.util.simplify((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1)
owidth = topi.util.simplify((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1)
oheight = topi.util.get_const_int((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1)
owidth = topi.util.get_const_int((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1)
oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4])
ishape = topi.util.get_const_tuple(data.shape)
kshape = topi.util.get_const_tuple(kernel.shape)
assert data.dtype == "int8", data.dtype
assert kernel.dtype == "int8", kernel.dtype
d_i = tvm.reduce_axis((0, kshape[2]), name='d_i')
d_j = tvm.reduce_axis((0, kshape[3]), name='d_j')
k_o = tvm.reduce_axis((0, ishape[1]), name='k_o')
......@@ -201,167 +68,55 @@ def packed_conv2d(data,
pad_data[b_o, k_o, i*hstride+d_i, j*wstride+d_j, b_i, k_i].astype(out_dtype) *
kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
axis=[k_o, d_i, d_j, k_i]),
name="res", tag="packed_conv2d")
return res
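# --- Illustration (not part of this change) --------------------------------
# NumPy-only check that packing data/kernel, running the packed reduction
# shown above, and unpacking reproduces an ordinary NCHW convolution. Small
# block factors, stride 1 and no padding are assumed for brevity (real VTA
# defaults are BATCH=1, BLOCK_IN=BLOCK_OUT=16).
import numpy as np
BATCH, BLOCK_IN, BLOCK_OUT = 1, 4, 4
N, CI, H, W, CO, KH, KW = 1, 8, 5, 5, 8, 3, 3
x = np.random.randint(-4, 4, size=(N, CI, H, W)).astype("int8")
k = np.random.randint(-4, 4, size=(CO, CI, KH, KW)).astype("int8")
# Pack data to NCHW{n}{c} and kernel to OIHW{o}{i}, as the graph packing pass does.
xp = x.reshape(N // BATCH, BATCH, CI // BLOCK_IN, BLOCK_IN, H, W).transpose(0, 2, 4, 5, 1, 3)
kp = k.reshape(CO // BLOCK_OUT, BLOCK_OUT, CI // BLOCK_IN, BLOCK_IN, KH, KW).transpose(0, 2, 4, 5, 1, 3)
# res[b_o, c_o, i, j, b_i, c_i] = sum over k_o, d_i, d_j, k_i
OH, OW = H - KH + 1, W - KW + 1
res = np.zeros((N // BATCH, CO // BLOCK_OUT, OH, OW, BATCH, BLOCK_OUT), dtype="int32")
for i in range(OH):
    for j in range(OW):
        window = xp[:, :, i:i + KH, j:j + KW, :, :].astype("int32")
        res[:, :, i, j, :, :] = np.einsum("akdebi,ckdefi->acbf", window, kp.astype("int32"))
# Reference NCHW convolution, then unpack and compare.
ref = np.zeros((N, CO, OH, OW), dtype="int32")
for i in range(OH):
    for j in range(OW):
        patch = x[:, :, i:i + KH, j:j + KW].astype("int32")
        ref[:, :, i, j] = np.einsum("nchw,ochw->no", patch, k.astype("int32"))
unpacked = res.transpose(0, 4, 1, 5, 2, 3).reshape(N, CO, OH, OW)
assert np.array_equal(unpacked, ref)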
@tvm.register_func("nnvm.compiler.build_target", override=True)
def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", target_host=target_host)
if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target)
@tvm.register_func("nnvm.compiler.lower", override=True)
def _lower(sch, inputs, func_name, graph):
import traceback
# pylint: disable=broad-except
try:
f = tvm.lower(sch, inputs, name=func_name)
if "quantized_conv2d" in func_name:
logging.info(graph.ir(join_entry_attrs=["shape"]))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile graph\n"
msg += "--------------------------\n"
msg += graph.ir(join_entry_attrs=["shape"])
raise RuntimeError(msg)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]
@reg.register_compute("clip", level=15)
def compute_clip(attrs, inputs, _):
""" Clip operator.
"""
x = inputs[0]
a_min = attrs.get_float("a_min")
a_max = attrs.get_float("a_max")
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
with tvm.tag_scope(topi.tag.ELEMWISE):
x = tvm.compute(
x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(
x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
# override to force partition at copy
reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
def is_packed_layout(layout):
"""Check if layout is packed layout"""
if layout == "NCHW":
return False
if "n" in layout and "c" in layout:
return True
return False
@reg.register_alter_op_layout("conv2d", level=15)
def alter_conv2d_layout(attrs, inputs, out):
layout = attrs['layout']
if is_packed_layout(layout):
return None
return _nn.alter_conv2d_layout(attrs, inputs, out)
@reg.register_compute("conv2d", level=15)
def compute_conv2d(attrs, inputs, out):
""" 2D convolution algorithm.
"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs["layout"]
out_dtype = attrs['out_dtype']
assert dilation == (1, 1), "dilation is not supported for now"
if is_packed_layout(layout):
assert groups == 1
return packed_conv2d(inputs[0], inputs[1],
padding, strides, out_dtype=out_dtype)
return _nn.compute_conv2d(attrs, inputs, out)
@reg.register_schedule("conv2d", level=15)
def schedule_conv2d(attrs, outs, target):
""" 2D convolution schedule.
"""
layout = attrs["layout"]
if is_packed_layout(layout):
target = tvm.target.create(target)
if target.device_name == "vta":
return schedule_packed_conv2d(outs)
if str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
raise RuntimeError("not support target %s" % target)
return _nn.schedule_conv2d(attrs, outs, target)
name="res", tag="conv2d_dense")
cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
kshape[2] * kshape[3] * ishape[1] * ishape[-1])
def _get_workload(data, pad_data, kernel, output):
""" Get the workload structure.
"""
o_shape = topi.util.get_const_tuple(output.shape)
d_shape = topi.util.get_const_tuple(data.shape)
k_shape = topi.util.get_const_tuple(kernel.shape)
o_b, o_c, o_h, o_w, ob_blk, o_blk = o_shape
i_b, i_c, i_h, i_w, ib_blk, i_blk = d_shape
k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape
# For now we need to assume that input channel blocking is the same
# as the output channel blocking
assert o_blk == i_blk
assert ob_blk == ib_blk
# Make sure that dimensions match
assert o_b == i_b
assert o_blk == ko_blk
assert i_blk == ki_blk
assert k_o == o_c
assert k_i == i_c
# Scale the channel size
i_c *= i_blk
o_c *= o_blk
if pad_data is not None:
p_shape = topi.util.get_const_tuple(pad_data.shape)
h_pad = (p_shape[2] - d_shape[2]) // 2
w_pad = (p_shape[3] - d_shape[3]) // 2
else:
h_pad, w_pad = 0, 0
h_str = (i_h + h_pad*2 - k_h) // (o_h - 1)
w_str = (i_w + w_pad*2 - k_w) // (o_w - 1)
return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str)
_WL2PLAN = {}
return res
def schedule_packed_conv2d(outs):
""" Schedule the packed conv2d.
"""
@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_nchw, 'vta', 'direct')
def _schedule_conv2d(cfg, outs):
assert len(outs) == 1
output = outs[0]
const_ops = []
ewise_inputs = []
ewise_ops = []
conv2d_res = []
assert output.dtype == "int8"
assert output.op.input_tensors[0].dtype == "int32"
assert "int" in output.op.input_tensors[0].dtype
def _traverse(op):
if topi.tag.is_broadcast(op.tag):
if not op.same_as(output.op):
ewise_ops.append(op)
if not op.axis:
const_ops.append(op)
else:
ewise_ops.append(op)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
ewise_inputs.append((op, tensor))
else:
_traverse(tensor.op)
else:
assert op.tag == "packed_conv2d"
assert op.tag == "conv2d_dense"
conv2d_res.append(op)
_traverse(output.op)
assert len(conv2d_res) == 1
conv2d_stage = conv2d_res[0].output(0)
s = tvm.create_schedule(output.op)
##### space definition begin #####
b, c_o, x_i, x_j, _, _ = s[conv2d_stage].op.axis
c_i, _, _, _ = s[conv2d_stage].op.reduce_axis
cfg.define_split('tile_b', b, num_outputs=2)
cfg.define_split('tile_h', x_i, num_outputs=2)
cfg.define_split('tile_w', x_j, num_outputs=2)
cfg.define_split('tile_ci', c_i, num_outputs=2)
cfg.define_split('tile_co', c_o, num_outputs=2)
cfg.define_knob('oc_nthread', [1, 2])
cfg.define_knob('h_nthread', [1, 2])
###### space definition end ######
data, kernel = conv2d_stage.op.input_tensors
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
......@@ -370,21 +125,8 @@ def schedule_packed_conv2d(outs):
data = temp
else:
pad_data = None
wrkld = _get_workload(data, pad_data, kernel, output)
if wrkld in _WL2PLAN:
plan = _WL2PLAN[wrkld]
else:
plan = find_schedules(wrkld, vt_only=True, best_only=True)[0]
logging.info("Trying to find plan for %s", wrkld)
env = get_env()
load_inp = load_wgt = load_out = store_out = env.dma_copy
alu = env.alu
gemm = env.gemm
# schedule1
oshape = topi.util.get_const_tuple(output.shape)
s = tvm.create_schedule(output.op)
env = get_env()
# setup pad
if pad_data is not None:
......@@ -394,27 +136,26 @@ def schedule_packed_conv2d(outs):
cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
s[conv2d_stage].set_scope(env.acc_scope)
# cache read input
cache_read_ewise = []
for consumer, tensor in ewise_inputs:
cache_read_ewise.append(
s.cache_read(tensor, env.acc_scope, [consumer]))
# set ewise scope
for op in ewise_ops:
s[op].set_scope(env.acc_scope)
s[op].pragma(s[op].op.axis[0], alu)
s[op].pragma(s[op].op.axis[0], env.alu)
# tile
oc_factor = (plan.oc_factor if plan.oc_factor
else plan.out_filter // env.BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3])
for op in const_ops:
s[op].compute_inline()
# tile
x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
x_co0, x_co1 = s[output].split(x_co, factor=oc_factor)
x_i0, x_i1 = s[output].split(x_i, factor=h_factor)
x_j0, x_j1 = s[output].split(x_j, factor=w_factor)
x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co)
x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i)
x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j)
s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
store_pt = x_j0
......@@ -425,17 +166,17 @@ def schedule_packed_conv2d(outs):
for tensor in cache_read_ewise:
s[tensor].compute_at(s[output], store_pt)
s[tensor].pragma(s[tensor].op.axis[0], load_out)
s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)
# virtual threading along output channel axes
if plan.oc_nthread > 1:
_, v_t = s[output].split(x_co0, factor=plan.oc_nthread)
if cfg['oc_nthread'].val > 1:
_, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val)
s[output].reorder(v_t, x_bo)
s[output].bind(v_t, tvm.thread_axis("cthread"))
# virtual threading along spatial rows
if plan.h_nthread > 1:
_, v_t = s[output].split(x_i0, factor=plan.h_nthread)
if cfg['h_nthread'].val > 1:
_, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val)
s[output].reorder(v_t, x_bo)
s[output].bind(v_t, tvm.thread_axis("cthread"))
......@@ -443,68 +184,14 @@ def schedule_packed_conv2d(outs):
k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i)
if plan.ic_factor:
k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor)
s[cdata].compute_at(s[conv2d_stage], k_o)
s[ckernel].compute_at(s[conv2d_stage], k_o)
k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
s[cdata].compute_at(s[conv2d_stage], k_o)
s[ckernel].compute_at(s[conv2d_stage], k_o)
# Use VTA instructions
s[cdata].pragma(s[cdata].op.axis[0], load_inp)
s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt)
s[conv2d_stage].tensorize(x_bi, gemm)
s[output].pragma(x_co1, store_out)
return s
class Conv2DSchedule(object):
""" 2D convolution schedule object.
"""
def __init__(self,
b_factor=1,
oc_factor=1,
ic_factor=1,
h_factor=1,
w_factor=0,
oc_nthread=0,
h_nthread=0,
debug_sync=False):
self.b_factor = b_factor
self.oc_factor = oc_factor
self.ic_factor = ic_factor
self.h_factor = h_factor
self.w_factor = w_factor
self.oc_nthread = oc_nthread
self.h_nthread = h_nthread
self.debug_sync = debug_sync
def __str__(self):
return "{}.{}.{}.{}.{}.{}.{}".format(
self.b_factor, self.oc_factor, self.ic_factor,
self.h_factor, self.w_factor,
self.oc_nthread, self.h_nthread)
s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy)
s[ckernel].pragma(s[ckernel].op.axis[0], env.dma_copy)
s[conv2d_stage].tensorize(x_bi, env.gemm)
s[output].pragma(x_co1, env.dma_copy)
Schedule = Conv2DSchedule
# Layer description of the ResNet18
RESNET = {
0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
for idx in RESNET:
f_schedules = find_schedules(RESNET[idx], vt_only=True, best_only=True)
if f_schedules:
scheds = f_schedules[0]
_WL2PLAN[RESNET[idx]] = scheds
else:
logging.warning("No valid schedule was found for the workload on current vta configuration")
break
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Dense operator declaration and schedule registration for VTA."""
import numpy as np
import tvm
from tvm import autotvm
import topi
from ..environment import get_env
def is_packed_layout(layout):
"""Check if layout is packed layout"""
if layout == "NCHW":
return False
if "n" in layout and "c" in layout:
return True
return False
@autotvm.register_topi_compute(topi.nn.dense, 'vta', 'direct')
def _declaration_dense(cfg,
data,
weight,
bias=None,
out_dtype=None):
"""Dense function declaration."""
# Make sure that the dense operator is packed
if len(data.shape) != 4 or len(weight.shape) != 4:
raise topi.InvalidShapeError()
# Derive shapes
ishape = topi.util.get_const_tuple(data.shape)
wshape = topi.util.get_const_tuple(weight.shape)
oshape = (data.shape[0], weight.shape[0], data.shape[2], weight.shape[2])
# Reduction axes (input channel)
assert ishape[1] == wshape[1]
assert ishape[3] == wshape[3]
k_o = tvm.reduce_axis((0, ishape[1]), name='k_o')
k_i = tvm.reduce_axis((0, ishape[3]), name='k_i')
res = tvm.compute(
oshape,
lambda b_o, c_o, b_i, c_i: tvm.sum(
data[b_o, k_o, b_i, k_i].astype(out_dtype) *
weight[c_o, k_o, c_i, k_i].astype(out_dtype),
axis=[k_o, k_i]),
name="res", tag="dense_pack")
cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
ishape[1] * ishape[3])
return res
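# --- Illustration (not part of this change) --------------------------------
# NumPy sketch checking that the packed dense reduction above matches an
# ordinary dense layer; default-like packing factors are assumed.
import numpy as np
BATCH, BLOCK_IN, BLOCK_OUT = 1, 16, 16
N, CI, CO = 16, 512, 1024
x = np.random.randint(-8, 8, size=(N, CI)).astype("int8")
w = np.random.randint(-8, 8, size=(CO, CI)).astype("int8")
# data: (N//BATCH, CI//BLOCK_IN, BATCH, BLOCK_IN); weight: (CO//BLOCK_OUT, CI//BLOCK_IN, BLOCK_OUT, BLOCK_IN)
xp = x.reshape(N // BATCH, BATCH, CI // BLOCK_IN, BLOCK_IN).transpose(0, 2, 1, 3)
wp = w.reshape(CO // BLOCK_OUT, BLOCK_OUT, CI // BLOCK_IN, BLOCK_IN).transpose(0, 2, 1, 3)
# res[b_o, c_o, b_i, c_i] = sum_{k_o, k_i} data[b_o, k_o, b_i, k_i] * weight[c_o, k_o, c_i, k_i]
res = np.einsum("akbi,ckdi->acbd", xp.astype("int32"), wp.astype("int32"))
ref = x.astype("int32") @ w.astype("int32").T
assert np.array_equal(res.transpose(0, 2, 1, 3).reshape(N, CO), ref)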
@autotvm.register_topi_schedule(topi.generic.schedule_dense, 'vta', 'direct')
def _schedule_dense(cfg, outs):
"""Packed dense schedule."""
assert len(outs) == 1
output = outs[0]
const_ops = []
ewise_inputs = []
ewise_ops = []
dense_res = []
assert "int" in output.op.input_tensors[0].dtype
def _traverse(op):
if topi.tag.is_broadcast(op.tag):
if not op.same_as(output.op):
if not op.axis:
const_ops.append(op)
else:
ewise_ops.append(op)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
ewise_inputs.append((op, tensor))
else:
_traverse(tensor.op)
else:
assert op.tag == "dense_pack"
dense_res.append(op)
_traverse(output.op)
assert len(dense_res) == 1
dense_stage = dense_res[0].output(0)
s = tvm.create_schedule(output.op)
##### space definition begin #####
b, c_o, _, _ = s[dense_stage].op.axis
c_i, _ = s[dense_stage].op.reduce_axis
cfg.define_split('tile_b', b, num_outputs=2)
cfg.define_split('tile_ci', c_i, num_outputs=2)
cfg.define_split('tile_co', c_o, num_outputs=2)
cfg.define_knob('oc_nthread', [1, 2])
###### space definition end ######
data, weight = dense_stage.op.input_tensors
env = get_env()
cdata = s.cache_read(data, env.inp_scope, [dense_stage])
cweight = s.cache_read(weight, env.wgt_scope, [dense_stage])
s[dense_stage].set_scope(env.acc_scope)
# cache read input
cache_read_ewise = []
for consumer, tensor in ewise_inputs:
cache_read_ewise.append(
s.cache_read(tensor, env.acc_scope, [consumer]))
# set ewise scope
for op in ewise_ops:
s[op].set_scope(env.acc_scope)
s[op].pragma(s[op].op.axis[0], env.alu)
for op in const_ops:
s[op].compute_inline()
# apply tiling for SRAM reuse
x_b, x_c, _, _ = s[output].op.axis
x_bo, x_bi = cfg['tile_b'].apply(s, output, x_b)
x_co, x_ci = cfg['tile_co'].apply(s, output, x_c)
s[output].reorder(x_bo, x_co, x_bi, x_ci)
store_pt = x_co
# set all compute scopes
s[dense_stage].compute_at(s[output], store_pt)
for op in ewise_ops:
s[op].compute_at(s[output], store_pt)
for tensor in cache_read_ewise:
s[tensor].compute_at(s[output], store_pt)
s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)
# virtual threading along output channel axes
if cfg['oc_nthread'].val > 1:
_, v_t = s[output].split(x_co, factor=cfg['oc_nthread'].val)
s[output].reorder(v_t, x_bo)
s[output].bind(v_t, tvm.thread_axis("cthread"))
x_bo, x_co, x_bi, _ = s[dense_stage].op.axis
k_o, _ = s[dense_stage].op.reduce_axis
s[dense_stage].reorder(x_bo, k_o, x_co)
k_o, _ = cfg['tile_ci'].apply(s, dense_stage, k_o)
s[cdata].compute_at(s[dense_stage], k_o)
s[cweight].compute_at(s[dense_stage], k_o)
# Use VTA instructions
s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy)
s[cweight].pragma(s[cweight].op.axis[0], env.dma_copy)
s[dense_stage].tensorize(x_bi, env.gemm)
s[output].pragma(x_ci, env.dma_copy)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuning a single conv2d operator"""
from collections import namedtuple
import logging
import os
import tvm
from tvm import autotvm
from tvm.contrib.util import get_lower_ir
import topi
import vta
import vta.testing
env = vta.get_env()
Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet
('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype):
data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
with tvm.target.vta():
res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides, dilation=dilation,
layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), out_dtype='int32')
res = topi.add(res, bias)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_conv2d_nchw([res])
else:
s = tvm.create_schedule([res.op])
return s, [data, kernel, bias, res]
if __name__ == '__main__':
# Logging config (for printing tuning log to the screen)
logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Get tracker info from env
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()
tracker_port = int(tracker_port)
for wl_name, wl in resnet_wkls:
# Workload parameters
N = wl.batch
CI = wl.in_filter
H = wl.height
W = wl.width
CO = wl.out_filter
KH = wl.hkernel
KW = wl.wkernel
strides = (wl.hstride, wl.wstride)
padding = (wl.hpad, wl.wpad)
dilation = (1, 1)
in_dtype = 'int8'
out_dtype = 'int32'
task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype),
target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space)
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port, number=4, repeat=3, timeout=10000,
check_correctness=True))
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space),
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('conv2d.log')])
print("\nBest tuner config:")
print(tuner.best_config)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuning a single dense operator"""
from collections import namedtuple
import logging
import os
import tvm
from tvm import autotvm
from tvm.contrib.util import get_lower_ir
import topi
import vta
import vta.testing
env = vta.get_env()
Workload = namedtuple("DenseWorkload",
['batch', 'in_filter', 'out_filter'])
resnet_wkls = [
# Workloads of resnet18 on imagenet
('resnet-18.dense', Workload(16, 512, 1024)),
]
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
def dense(N, CI, CO):
data_shape = (N//env.BATCH, CI//env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN)
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
with tvm.target.vta():
res = topi.nn.dense(data, kernel, None, 'int32')
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_dense([res])
else:
s = tvm.create_schedule([res.op])
return s, [data, kernel, res]
if __name__ == '__main__':
# Logging config (for printing tuning log to the screen)
logging.basicConfig()
logging.getLogger('autotvm').setLevel(logging.DEBUG)
# Get tracker info from env
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()
tracker_port = int(tracker_port)
for wl_name, wl in resnet_wkls:
# Workload parameters
N = wl.batch
CI = wl.in_filter
CO = wl.out_filter
task = autotvm.task.create(dense, args=(N, CI, CO),
target=tvm.target.vta(), target_host=env.target_host, template_key='direct')
print(task.config_space)
measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port, number=4, repeat=3, timeout=10000,
check_correctness=True))
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=len(task.config_space),
measure_option=measure_option,
callbacks=[autotvm.callback.log_to_file('dense.log')])
print("\nBest tuner config:")
print(tuner.best_config)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Perform ResNet autoTVM tuning on VTA using Relay."""
import argparse, os, time
from mxnet.gluon.model_zoo import vision
import numpy as np
from PIL import Image
import topi
import tvm
from tvm import rpc, autotvm, relay
from tvm.autotvm.measure.measure_methods import request_remote
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib import graph_runtime, util, download
from tvm.contrib.debugger import debug_runtime
import vta
from vta.testing import simulator
from vta.top import graph_pack
from tvm.autotvm.task import extract_from_program
def parse_arguments():
parser = argparse.ArgumentParser(description='Train a model for image classification.')
parser.add_argument('--model', type=str, default='resnet18_v1', choices=['resnet18_v1'],
help='Input model name.')
parser.add_argument('--start-name', type=str, default='nn.max_pool2d',
help='The name of the node where packing starts')
parser.add_argument('--stop-name', type=str, default='nn.global_avg_pool2d',
help='The name of the node where packing stops')
parser.add_argument('--debug-profile', action='store_true',
help='Show layer-wise time cost profiling results')
parser.add_argument('--device', default='vta', choices=['vta', 'arm_cpu'],
help='Select device target')
parser.add_argument('--measurements', type=int, default=1,
help='Number of measurements during AutoTVM search')
parser.add_argument('--tuner', type=str, default="random",
help='AutoTVM search strategy')
parser.add_argument('--log-filename', type=str, default="resnet-18.log",
help='AutoTVM log file name')
return parser.parse_args()
def register_vta_tuning_tasks():
from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
# init autotvm env to register VTA operator
TaskExtractEnv()
@autotvm.task.register("topi_nn_conv2d", override=True)
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
with tvm.target.vta():
res = topi.nn.conv2d(*args, **kwargs)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_conv2d_nchw([res])
else:
s = tvm.create_schedule([res.op])
return s, [A, W, res]
@autotvm.task.register("topi_nn_dense", override=True)
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
with tvm.target.vta():
res = topi.nn.dense(*args, **kwargs)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_dense([res])
else:
s = tvm.create_schedule([res.op])
return s, [A, W, res]
def compile_network(opt, env, target):
# Populate the shape and data type dictionary
dtype_dict = {"data": 'float32'}
shape_dict = {"data": (env.BATCH, 3, 224, 224)}
# Get off the shelf gluon model, and convert to relay
gluon_model = vision.get_model(opt.model, pretrained=True)
mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
# Update shape and type dictionary
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
# Perform graph packing and constant folding for VTA target
if target.device_name == "vta":
assert env.BLOCK_IN == env.BLOCK_OUT
relay_prog = graph_pack(
relay_prog,
env.BATCH,
env.BLOCK_OUT,
env.WGT_WIDTH,
start_name=opt.start_name,
stop_name=opt.stop_name)
relay_prog = relay.ir_pass.fold_constant(relay_prog)
return relay_prog, params
def tune_tasks(tasks,
measure_option,
tuner='xgb',
n_trial=1000,
early_stopping=None,
log_filename='tuning.log',
use_transfer_learning=True,
try_winograd=True):
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(tsk, loss_type='rank')
elif tuner == 'ga':
tuner_obj = GATuner(tsk, pop_size=50)
elif tuner == 'random':
tuner_obj = RandomTuner(tsk)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
n_trial_ = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial_,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial_, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
if __name__ == '__main__':
opt = parse_arguments()
# Make sure that TVM was compiled with RPC=1
assert tvm.module.enabled("rpc")
# Read in VTA environment
env = vta.get_env()
# Get remote from fleet node
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "0"))
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()
# Get remote
if env.TARGET != "sim":
# Measure build start time
reconfig_start = time.time()
# Get remote from fleet node
remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000)
# Reconfigure the JIT runtime and FPGA.
# You can program the FPGA with your own custom bitstream
# by passing the path to the bitstream file instead of None.
vta.reconfig_runtime(remote)
vta.program_fpga(remote, bitstream=None)
# Report on reconfiguration time
reconfig_time = time.time() - reconfig_start
print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
# In simulation mode, host the RPC server locally.
else:
remote = rpc.LocalSession()
# VTA target and execution context
target = env.target if opt.device == "vta" else env.target_vta_cpu
ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0)
# Compile Relay program
print("Initial compile...")
relay_prog, params = compile_network(opt, env, target)
# Register VTA tuning tasks
register_vta_tuning_tasks()
# Perform task extraction on Relay program
print("Extracting tasks...")
tasks = extract_from_program(func=relay_prog,
params=params,
ops=(tvm.relay.op.nn.conv2d,),
target=target,
target_host=env.target_host)
# Perform Autotuning
print("Tuning...")
tuning_opt = {
'log_filename': opt.log_filename,
'tuner': opt.tuner,
'n_trial': 1e9,
'early_stopping': None,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port,
number=4, min_repeat_ms=150, repeat=opt.measurements, timeout=60,
check_correctness=True))
}
tune_tasks(tasks, **tuning_opt)
# Compile kernels with history best records
with autotvm.tophub.context(target, extra_files=[opt.log_filename]):
# Compile network
print("Compiling network with best tuning parameters...")
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
if target.device_name != "vta":
graph, lib, params = relay.build(
relay_prog, target=target,
params=params, target_host=env.target_host)
else:
with vta.build_config():
graph, lib, params = relay.build(
relay_prog, target=target,
params=params, target_host=env.target_host)
# Export library
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
# If detailed runtime info is needed build with debug runtime
if opt.debug_profile:
m = debug_runtime.create(graph, lib, ctx)
else:
m = graph_runtime.create(graph, lib, ctx)
# Set the network parameters and synthetic input
image = tvm.nd.array(
(np.random.uniform(size=(1, 3, 224, 224))).astype('float32'))
m.set_input(**params)
m.set_input('data', image)
# Perform inference
timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements)
tcost = timer()
prof_res = np.array(tcost.results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
# Display profile information
if opt.debug_profile:
m.run()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Perform ResNet autoTVM tuning on VTA using NNVM."""
import argparse
import os
import time
import numpy as np
import tvm
from tvm import rpc, autotvm
from tvm.autotvm.measure.measure_methods import request_remote
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib import graph_runtime, util
from tvm.contrib.download import download
import topi
import nnvm.compiler
import vta
import vta.testing
env = vta.get_env()
def register_vta_tuning_tasks():
from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
# init autotvm env to register VTA operator
TaskExtractEnv()
@autotvm.task.register("topi_nn_conv2d", override=True)
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
with tvm.target.vta():
res = topi.nn.conv2d(*args, **kwargs)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_conv2d_nchw([res])
else:
s = tvm.create_schedule([res.op])
return s, [A, W, res]
def generate_graph(sym, params, target, target_host):
# Populate the shape and data type dictionary
shape_dict = {"data": (1, 3, 224, 224)}
dtype_dict = {"data": 'float32'}
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Apply NNVM graph optimization passes
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
assert env.BLOCK_IN == env.BLOCK_OUT
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
# Compile NNVM graph
with nnvm.compiler.build_config(opt_level=3):
with vta.build_config():
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
return graph, lib, params
def extract_tasks(sym, params, target, target_host):
# Populate the shape and data type dictionary
shape_dict = {"data": (1, 3, 224, 224)}
dtype_dict = {"data": 'float32'}
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Apply NNVM graph optimization passes
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
assert env.BLOCK_IN == env.BLOCK_OUT
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
with vta.build_config():
tasks = autotvm.task.extract_from_graph(graph=sym, shape=shape_dict, dtype=dtype_dict, target=target,
params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host)
return tasks
def download_model():
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'
data_dir = '_data'
if not os.path.exists(data_dir):
os.makedirs(data_dir)
for file in [categ_fn, graph_fn, params_fn]:
if not os.path.isfile(os.path.join(data_dir, file)):
download(os.path.join(url, file), os.path.join(data_dir, file))
sym = nnvm.graph.load_json(open(os.path.join(data_dir, graph_fn)).read())
params = nnvm.compiler.load_param_dict(open(os.path.join(data_dir, params_fn), 'rb').read())
return sym, params
def tune_tasks(tasks,
measure_option,
tuner='xgb',
n_trial=1000,
early_stopping=None,
log_filename='tuning.log',
use_transfer_learning=True,
try_winograd=True):
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(tsk, loss_type='rank')
elif tuner == 'ga':
tuner_obj = GATuner(tsk, pop_size=50)
elif tuner == 'random':
tuner_obj = RandomTuner(tsk)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
n_trial_ = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial_,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial_, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
if __name__ == '__main__':
# Get tracker info from env
tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "0"))
if not tracker_host or not tracker_port:
print("Set your AutoTVM tracker node host and port variables to run the autotuner")
exit()
# Download model
sym, params = download_model()
# Register VTA tuning tasks
register_vta_tuning_tasks()
# Extract tasks
print("Extracting tasks...")
target = tvm.target.vta()
target_host = env.target_host
tasks = extract_tasks(sym, params, target, target_host)
# Perform Autotuning
print("Tuning...")
tuning_opt = {
'log_filename': 'resnet-18.log',
'tuner': 'random',
'n_trial': 1e9,
'early_stopping': None,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port,
number=4, repeat=3, timeout=60,
check_correctness=True))
}
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.tophub.context(target, extra_files=[tuning_opt['log_filename']]):
# ResNet parameters
input_shape = (1, 3, 224, 224)
dtype = 'float32'
# Compile network
print("Compiling network with best tuning parameters...")
graph, lib, params = generate_graph(sym, params, target, target_host)
# Export library
tmp = util.tempdir()
filename = "net.tar"
lib.export_library(tmp.relpath(filename))
# Upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)
# Upload parameters to device
ctx = remote.context(str(target), 0)
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module = graph_runtime.create(graph, rlib, ctx)
module.set_input('data', data_tvm)
module.set_input(**rparams)
# Evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
......@@ -908,12 +908,10 @@ class CommandQueue {
insn_queue_.InitSpace();
device_ = VTADeviceAlloc();
CHECK(device_ != nullptr);
printf("Initialize VTACommandHandle...\n");
}
~CommandQueue() {
VTADeviceFree(device_);
printf("Close VTACommandhandle...\n");
}
uint32_t GetElemBytes(uint32_t memory_id) {
......
......@@ -35,6 +35,11 @@
namespace vta {
namespace sim {
/*! \brief debug flag for skipping computation */
enum DebugFlagMask {
kSkipExec = 1
};
/*!
* \brief Helper class to pack and unpack bits
* Applies truncation when pack to low level bits.
......@@ -253,8 +258,12 @@ class SRAM {
return &(data_[index]);
}
// Execute the load instruction on this SRAM
void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
void Load(const VTAMemInsn* op,
DRAM* dram,
uint64_t* load_counter,
bool skip_exec) {
load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
if (skip_exec) return;
DType* sram_ptr = data_ + op->sram_base;
uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
op->dram_base * kElemBytes));
......@@ -325,6 +334,8 @@ class Profiler {
uint64_t gemm_counter{0};
/*! \brief instr counter for ALU ops */
uint64_t alu_counter{0};
/*! \brief set debug mode */
int64_t debug_flag{0};
/*! \brief clear the profiler */
void Clear() {
inp_load_nbytes = 0;
......@@ -335,6 +346,10 @@ class Profiler {
gemm_counter = 0;
alu_counter = 0;
}
/*! \return Whether we should skip execution. */
bool SkipExec() const {
return (debug_flag & DebugFlagMask::kSkipExec) != 0;
}
std::string AsJSON() {
std::ostringstream os;
......@@ -398,13 +413,15 @@ class Device {
void RunLoad(const VTAMemInsn* op) {
if (op->x_size == 0) return;
if (op->memory_type == VTA_MEM_ID_INP) {
inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
inp_.Load(op, dram_, &(prof_->inp_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_WGT) {
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_ACC) {
acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
acc_.Load(op, dram_, &(prof_->acc_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_UOP) {
uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
// always load in uop, since uop is stateful
// subsequent non-debug mode exec can depend on it.
uop_.Load(op, dram_, &(prof_->uop_load_nbytes), false);
} else {
LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
}
......@@ -416,7 +433,9 @@ class Device {
op->memory_type == VTA_MEM_ID_UOP) {
prof_->out_store_nbytes += (
op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
if (!prof_->SkipExec()) {
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
}
} else {
LOG(FATAL) << "Store do not support memory_type="
<< op->memory_type;
......@@ -425,7 +444,8 @@ class Device {
void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in;
prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
......@@ -459,6 +479,7 @@ class Device {
}
}
} else {
if (prof_->SkipExec()) return;
// reset
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
......@@ -477,7 +498,6 @@ class Device {
}
void RunALU(const VTAAluInsn* op) {
prof_->alu_counter += op->iter_out * op->iter_in;
if (op->use_imm) {
RunALU_<true>(op);
} else {
......@@ -520,6 +540,8 @@ class Device {
template<bool use_imm, typename F>
void RunALULoop(const VTAAluInsn* op, F func) {
prof_->alu_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (int y = 0; y < op->iter_out; ++y) {
for (int x = 0; x < op->iter_in; ++x) {
for (int k = op->uop_bgn; k < op->uop_end; ++k) {
......@@ -566,6 +588,10 @@ TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Profiler::ThreadLocal()->AsJSON();
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_debug_mode")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->debug_flag = args[0];
});
} // namespace sim
} // namespace vta
......
......@@ -14,7 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Testing if we can generate code in topi style"""
"""Testing topi conv2d operator for VTA"""
import os
import json
from collections import namedtuple
import numpy as np
import tvm
from tvm import autotvm
......@@ -23,11 +30,32 @@ from tvm.contrib.pickle_memoize import memoize
import topi
import topi.testing
import vta
from vta import program_fpga, reconfig_runtime
import vta.testing
import numpy as np
Workload = vta.top.vta_conv2d.Workload
from vta.testing import simulator
Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
# ResNet18 workloads
resnet_wkls = [
# Workloads of resnet18 on imagenet
# ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)),
('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)),
# ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet
('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)),
('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)),
('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)),
('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)),
('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)),
('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)),
('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)),
('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)),
('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)),
]
# FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
......@@ -37,249 +65,168 @@ def my_clip(x, a_min, a_max):
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
def test_cpu_conv2d():
def run_cpu_conv2d(env, remote, key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter, wl.height, wl.width)
kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
res_conv = topi.nn.conv2d(
data, kernel, padding=(wl.hpad, wl.wpad),
strides=(wl.hstride, wl.wstride),
dilation=(1, 1),
out_dtype="int32")
res = topi.right_shift(res_conv, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
# To compute number of ops, use a x2 factor for FMA
num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
stride = (wl.hstride, wl.wstride)
data_dtype = data.dtype
kernel_dtype = kernel.dtype
acc_dtype = env.acc_dtype
assert wl.hpad == wl.wpad
padding = wl.hpad
@memoize("vta.tests.test_benchmark_topi.conv2d.cpu.verify_nhwc")
def get_ref_data():
a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype)
a_np = np.abs(a_np)
w_np = np.abs(w_np)
b_np = topi.testing.conv2d_nchw_python(
a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
return a_np, w_np, b_np
def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, res],
target_host=env.target_host,
name="conv2d")
temp = util.tempdir()
mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
f = remote.load_module("conv2d.o")
# verify
ctx = remote.cpu(0)
# Data in original format
data_orig, kernel_orig, res_ref = get_ref_data()
res_shape = topi.util.get_const_tuple(res.shape)
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_orig, ctx)
kernel_arr = tvm.nd.array(kernel_orig, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=5)
cost = time_f(data_arr, kernel_arr, res_arr)
res_unpack = res_arr.asnumpy()
if check_correctness:
assert wl.hpad == wl.wpad
stride = (wl.hstride, wl.wstride)
padding = wl.hpad
res_ref = res_ref >> 8
res_ref = np.clip(res_ref, 0, 127).astype("int8")
tvm.testing.assert_allclose(res_unpack, res_ref)
return cost
def conv_normal(print_ir):
print("----- CONV2D CPU End-to-End Test-------")
s = topi.generic.schedule_conv2d_nchw([res])
if print_ir:
print(tvm.lower(s, [data, kernel, res], simple_mode=True))
cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
conv_normal(False)
def _run(env, remote):
# ResNet18 workloads
resnet = {
# Workloads of resnet18 on imagenet
0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
batch_size = 1
for i in range(1, len(resnet)):
wl = resnet[i]
key = "resnet-cfg[%d]" % i
print("key=%s" % key)
print(wl)
with tvm.target.create("llvm -device=vtacpu"):
run_cpu_conv2d(env, remote, key, batch_size, wl)
# load pre-tuned operator parameters for ARM CPU
autotvm.tophub.check_backend('vta')
with autotvm.tophub.context('llvm -device=vtacpu'):
vta.testing.run(_run)
def test_vta_conv2d():
def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN,
wl.height, wl.width, env.BATCH, env.BLOCK_IN)
def run_conv2d(env, remote, wl, target,
check_correctness=True, print_ir=False,
samples=4):
# Workload assertions
assert wl.hpad == wl.wpad
# Perform packing only if we are targeting the accelerator
if "arm_cpu" in target.keys:
data_pack = False
layout = "NCHW"
elif "vta" in target.keys:
data_pack = True
layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN)
# Derive shapes depending upon packing
a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
b_shape = (wl.batch, wl.out_filter, 1, 1)
if data_pack:
data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN,
wl.height, wl.width, env.BATCH, env.BLOCK_IN)
kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN,
wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (1, wl.out_filter//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
res_conv = vta.top.packed_conv2d(
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
res = topi.right_shift(res_conv, 8)
bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
1, 1, env.BATCH, env.BLOCK_OUT)
else:
data_shape = a_shape
kernel_shape = w_shape
bias_shape = b_shape
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
# Define base computation schedule
with target:
res = topi.nn.conv2d(
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1),
layout, env.acc_dtype)
res = topi.right_shift(res, 8)
res = topi.add(res, bias)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
# To compute number of ops, use a x2 factor for FMA
num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
stride = (wl.hstride, wl.wstride)
data_dtype = data.dtype
kernel_dtype = kernel.dtype
acc_dtype = env.acc_dtype
assert wl.hpad == wl.wpad
padding = wl.hpad
@memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
def get_ref_data():
a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype)
a_np = np.abs(a_np)
w_np = np.abs(w_np)
b_np = topi.testing.conv2d_nchw_python(
a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
return a_np, w_np, b_np
def verify(s, check_correctness):
mod = vta.build(s, [data, kernel, bias, res], "ext_dev",
env.target_host, name="conv2d")
temp = util.tempdir()
mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
f = remote.load_module("conv2d.o")
# verify
ctx = remote.ext_dev(0)
# Data in original format
data_orig, kernel_orig, res_ref = get_ref_data()
bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
bias_orig = np.abs(bias_orig)
data_packed = data_orig.reshape(
batch_size//env.BATCH, env.BATCH,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
kernel_packed = kernel_orig.reshape(
wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_packed = bias_orig.reshape(
1, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
res_shape = topi.util.get_const_tuple(res.shape)
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
kernel_arr = tvm.nd.array(kernel_packed, ctx)
bias_arr = tvm.nd.array(bias_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=5)
res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
res = topi.cast(res, env.out_dtype)
# Derive base schedule
s = topi.generic.schedule_conv2d_nchw([res])
if print_ir:
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
# Derive number of ops
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
# @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
def get_ref_data():
# derive min max for act, wgt, and bias types (max non inclusive)
a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2)
a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
b_np = np.random.randint(b_min, b_max, size=b_shape).astype(env.acc_dtype)
r_np = topi.testing.conv2d_nchw_python(
a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype), (wl.hstride, wl.wstride), wl.hpad).astype(env.acc_dtype)
return a_np, w_np, b_np, r_np
# Data in original format
data_np, kernel_np, bias_np, res_ref = get_ref_data()
if data_pack:
data_np = data_np.reshape(
wl.batch//env.BATCH, env.BATCH,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
kernel_np = kernel_np.reshape(
wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_np = bias_np.reshape(
wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT,
1, 1, env.BATCH, env.BLOCK_OUT)
# Build
if "vta" in target.keys:
mod = vta.build(s, [data, kernel, bias, res],
target=target,
target_host=env.target_host,
name="conv2d")
else:
mod = tvm.build(s, [data, kernel, bias, res],
target=target,
target_host=env.target_host,
name="conv2d")
temp = util.tempdir()
mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
f = remote.load_module("conv2d.o")
ctx = remote.context(str(target))
res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype)
data_arr = tvm.nd.array(data_np, ctx)
kernel_arr = tvm.nd.array(kernel_np, ctx)
bias_arr = tvm.nd.array(bias_np, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=samples)
# In vta sim mode, collect simulator runtime statistics
stats = {}
cost = None
if env.TARGET == "sim":
# Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
remote.get_function("vta.simulator.profiler_clear")()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
res_unpack = res_arr.asnumpy().transpose(
(0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness:
assert wl.hpad == wl.wpad
stride = (wl.hstride, wl.wstride)
padding = wl.hpad
res_ref = res_ref >> 8
res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
res_ref = np.clip(res_ref, 0, 127).astype("int8")
tvm.testing.assert_allclose(res_unpack, res_ref)
return cost
def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------")
with vta.build_config():
s = vta.top.schedule_packed_conv2d([res])
if print_ir:
print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
conv_normal(False)
stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
else:
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
stats = simulator.stats()
else:
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
# Check correctness
correct = False
if check_correctness:
res_orig = res_arr.asnumpy()
if data_pack:
res_orig = res_orig.transpose(
(0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
res_ref = res_ref >> 8
res_ref += bias_np.reshape(wl.out_filter, 1, 1)
res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
res_ref = res_ref.astype(env.out_dtype)
correct = np.allclose(res_orig, res_ref)
gops = (num_ops / cost.mean) / float(10 ** 9)
status = "PASSED" if correct else "FAILED"
if "arm_cpu" in target.keys:
device = "CPU"
elif "vta" in target.keys:
device = "VTA"
print("%s CONV2D TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops))
return correct, cost, stats
def test_conv2d(device="vta"):
def _run(env, remote):
# ResNet18 workloads
resnet = {
# Workloads of resnet18 on imagenet
0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
batch_size = 1
for i in range(0, len(resnet)):
wl = resnet[i]
key = "resnet-cfg[%d]" % i
print("key=%s" % key)
print(wl)
run_vta_conv2d(env, remote, key, batch_size, wl)
if device == "vta":
target = env.target
if env.TARGET != "sim":
assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None)
reconfig_runtime(remote)
elif device == "arm_cpu":
target = env.target_vta_cpu
with autotvm.tophub.context(target): # load pre-tuned schedule parameters
for _, wl in resnet_wkls:
print(wl)
run_conv2d(env, remote, wl, target)
vta.testing.run(_run)
if __name__ == "__main__":
test_cpu_conv2d()
test_vta_conv2d()
test_conv2d(device="arm_cpu")
test_conv2d(device="vta")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Testing topi gemm operator for VTA"""
import os
import json
from collections import namedtuple
import numpy as np
import tvm
from tvm import autotvm
from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize
import topi
import topi.testing
import vta
from vta import program_fpga, reconfig_runtime
import vta.testing
from vta.testing import simulator
# FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
def run_gemm(env, remote, target,
batch_size, in_feat, out_feat,
check_correctness=True, print_ir=True,
samples=4):
# Perform packing only if we are targeting the accelerator
if "arm_cpu" in target.keys:
data_pack = False
elif "vta" in target.keys:
data_pack = True
# Derive shapes depending upon packing
a_shape = (batch_size, in_feat)
w_shape = (out_feat, in_feat)
if data_pack:
data_shape = (batch_size//env.BATCH, in_feat//env.BLOCK_IN,
env.BATCH, env.BLOCK_IN)
kernel_shape = (out_feat//env.BLOCK_OUT, in_feat//env.BLOCK_IN,
env.BLOCK_OUT, env.BLOCK_IN)
else:
data_shape = a_shape
kernel_shape = w_shape
data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
# Define base computation schedule
with target:
res = topi.nn.dense(
data, kernel, out_dtype=env.acc_dtype)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
res = topi.cast(res, env.out_dtype)
# Derive base schedule
s = topi.generic.schedule_dense([res])
if print_ir:
print(vta.lower(s, [data, kernel, res], simple_mode=True))
# Derive number of ops
num_ops = 2 * batch_size * in_feat * out_feat
# @memoize("vta.tests.test_benchmark_topi.dense.verify")
def get_ref_data():
# derive min max for act, wgt types (max non inclusive)
a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
r_np = np.dot(a_np.astype(env.acc_dtype), w_np.T.astype(env.acc_dtype)).astype(env.acc_dtype)
return a_np, w_np, r_np
# Data in original format
data_np, kernel_np, res_ref = get_ref_data()
if data_pack:
data_np = data_np.reshape(
batch_size//env.BATCH, env.BATCH,
in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
kernel_np = kernel_np.reshape(
out_feat//env.BLOCK_OUT, env.BLOCK_OUT,
in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
# Build
if "vta" in target.keys:
mod = vta.build(s, [data, kernel, res],
target=target,
target_host=env.target_host,
name="dense")
else:
mod = tvm.build(s, [data, kernel, res],
target=target,
target_host=env.target_host,
name="dense")
temp = util.tempdir()
mod.save(temp.relpath("dense.o"))
remote.upload(temp.relpath("dense.o"))
f = remote.load_module("dense.o")
ctx = remote.context(str(target))
res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype)
data_arr = tvm.nd.array(data_np, ctx)
kernel_arr = tvm.nd.array(kernel_np, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("dense", ctx, number=samples)
# In vta sim mode, collect simulator runtime statistics
stats = {}
cost = None
if env.TARGET == "sim":
# Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
remote.get_function("vta.simulator.profiler_clear")()
cost = time_f(data_arr, kernel_arr, res_arr)
stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
else:
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, res_arr)
stats = simulator.stats()
else:
cost = time_f(data_arr, kernel_arr, res_arr)
# Check correctness
correct = False
if check_correctness:
res_orig = res_arr.asnumpy()
if data_pack:
res_orig = res_orig.reshape(batch_size, out_feat)
res_ref = res_ref >> 8
res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
res_ref = res_ref.astype(env.out_dtype)
correct = np.allclose(res_orig, res_ref)
gops = (num_ops / cost.mean) / float(10 ** 9)
status = "PASSED" if correct else "FAILED"
if "arm_cpu" in target.keys:
device = "CPU"
elif "vta" in target.keys:
device = "VTA"
print("%s DENSE TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops))
return correct, cost, stats
def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128):
def _run(env, remote):
if device == "vta":
target = env.target
if env.TARGET != "sim":
assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None)
reconfig_runtime(remote)
elif device == "arm_cpu":
target = env.target_vta_cpu
with autotvm.tophub.context(target): # load pre-tuned schedule parameters
run_gemm(env, remote, target, batch, in_feat, out_feat)
vta.testing.run(_run)
if __name__ == "__main__":
test_gemm("vta", 16, 512, 1008)
VTA Tutorials
=============
This page contains tutorials about VTA and how to use TVM/Relay to target VTA.
Auto tuning
-------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Auto-tuning a convolutional network on VTA
==========================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
Auto-tuning for a specific accelerator design is critical for getting the best
performance for any given operator. This tutorial showcases how to tune a
whole convolutional network on VTA.
The operator implementation for VTA in TVM is written in template form.
The template has many tunable knobs (tile factor, virtual threads, etc).
We will tune all convolution operators in the neural network. After tuning,
we produce a log file which stores the best schedule parameters for all tuned
operators. When the TVM compiler compiles these operators, it will query this
log file to get the best knob parameters.
"""
######################################################################
# Install dependencies
# --------------------
# To use the autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user psutil xgboost tornado mxnet requests pillow
#
# To make TVM run faster during tuning, it is recommended to use cython
# as the FFI of TVM. In the root directory of TVM, execute
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user cython
# sudo make cython3
#
# Now return to python code. Import packages.
import os
from mxnet.gluon.model_zoo import vision
import numpy as np
from PIL import Image
import topi
import tvm
from tvm import rpc, autotvm, relay
from tvm.contrib import graph_runtime, util, download
from tvm.autotvm.measure.measure_methods import request_remote
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
import vta
from vta.testing import simulator
from vta.top import graph_pack
#################################################################
# Compile network
# ---------------
# Perform vta-specific compilation with Relay from a Gluon model
def compile_network(env, target, model, start_pack, stop_pack):
# Populate the shape and data type dictionary
dtype_dict = {"data": 'float32'}
shape_dict = {"data": (env.BATCH, 3, 224, 224)}
# Get off the shelf gluon model, and convert to relay
gluon_model = vision.get_model(model, pretrained=True)
mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
# Update shape and type dictionary
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
# Perform graph packing and constant folding for VTA target
if target.device_name == "vta":
assert env.BLOCK_IN == env.BLOCK_OUT
relay_prog = graph_pack(
relay_prog,
env.BATCH,
env.BLOCK_OUT,
env.WGT_WIDTH,
start_name=start_pack,
stop_name=stop_pack)
relay_prog = relay.ir_pass.fold_constant(relay_prog)
return relay_prog, params
#################################################################
# Start RPC Tracker
# -----------------
# TVM uses an RPC session to communicate with Pynq boards.
# During tuning, the tuner will send the generated code to the board and
# measure the speed of code on the board.
#
# To scale up tuning, TVM uses an RPC Tracker to manage multiple devices.
# The RPC Tracker is a centralized master node. We can register all devices to
# the tracker. For example, if we have 10 Pynq boards, we can register all of them
# to the tracker, and run 10 measurements in parallel, accelerating the tuning process.
#
# To start an RPC tracker, run this command on the host machine. The tracker is
# required during the whole tuning process, so we need to open a new terminal for
# this command:
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190
#
# The expected output is:
#
# .. code-block:: bash
#
# INFO:RPCTracker:bind to 0.0.0.0:9190
#################################################################
# Register devices to RPC Tracker
# -----------------------------------
# Now we can register our devices to the tracker. The first step is to
# build the TVM runtime for the Pynq devices.
#
# Follow `this section <https://docs.tvm.ai/vta/install.html#pynq-side-rpc-server-build-deployment>`_
# to build the TVM runtime on the device. Then register the device to the tracker with:
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_server --tracker=[HOST_IP]:9190 --key=pynq
#
# (replace :code:`[HOST_IP]` with the IP address of your host machine)
#
# After registering devices, we can confirm it by querying the rpc_tracker:
#
# .. code-block:: bash
#
# python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190
#
# For example, if we have 6 Pynq boards and 11 Raspberry Pi 3B,
# the output can be
#
# .. code-block:: bash
#
# Queue Status
# ----------------------------------
# key total free pending
# ----------------------------------
# pynq 6 6 0
# rpi3b 11 11 0
# ----------------------------------
#
# You can register multiple devices to the tracker to accelerate tuning.
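######################################################################
# As an optional sanity check (a sketch, not part of the original flow),
# you can request one of the registered boards from Python through the
# tracker, using the same ``pynq`` key and the tracker host/port that are
# configured below:
#
# .. code-block:: python
#
#     from tvm.autotvm.measure.measure_methods import request_remote
#     remote = request_remote("pynq", tracker_host, tracker_port, timeout=10000)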
###########################################
# Set Tuning Options
# ------------------
# Before tuning, we should apply some configurations.
# Here we use a Pynq-Z1 board as an example.
# Tracker host and port can be set by your environment
tracker_host = os.environ.get("TVM_TRACKER_HOST", '0.0.0.0')
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
# Load VTA parameters from the vta/config/vta_config.json file
env = vta.get_env()
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
# Set ``device=arm_cpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu
# Name of Gluon model to compile
# The ``start_pack`` and ``stop_pack`` labels indicate where
# to start and end the graph packing relay pass: in other words
# where to start and finish offloading to VTA.
network = "resnet18_v1"
start_pack="nn.max_pool2d"
stop_pack="nn.global_avg_pool2d"
# Tuning option
log_file = "%s.%s.log" % (device, network)
tuning_option = {
'log_filename': log_file,
'tuner': 'random',
'n_trial': 1000,
'early_stopping': None,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.RPCRunner(
env.TARGET, host=tracker_host, port=tracker_port,
number=5,
timeout=60,
check_correctness=True
),
),
}
####################################################################
#
# .. note:: How to set tuning options
#
# In general, the default values provided here work well.
# If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping`
# to larger values, which makes the tuning run longer.
# If your device is under-powered or your conv2d operators are large, consider
# setting a longer timeout.
#
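####################################################################
# For instance (an illustrative sketch, not part of the tutorial defaults),
# a longer tuning run could override the options above as follows:
#
# .. code-block:: python
#
#     tuning_option.update({
#         'n_trial': 2000,         # explore more schedule candidates per task
#         'early_stopping': 400,   # stop a task early if no recent improvement
#     })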
###################################################################
# Begin Tuning
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
# Here, we provide a simple utility function to tune a list of tasks.
# This function is just an initial implementation which tunes them in sequential order.
# We will introduce a more sophisticated tuning scheduler in the future.
#
# Given that the tuning will be done on Pynq FPGA boards, make sure that
# the ``TARGET`` entry in the ``vta_config.json`` file is set to ``pynq``.
# You can skip the implementation of this function for this tutorial.
def tune_tasks(tasks,
measure_option,
tuner='xgb',
n_trial=1000,
early_stopping=None,
log_filename='tuning.log',
use_transfer_learning=True):
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(tsk, loss_type='rank')
elif tuner == 'xgb_knob':
tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
elif tuner == 'ga':
tuner_obj = GATuner(tsk, pop_size=50)
elif tuner == 'random':
tuner_obj = RandomTuner(tsk)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
########################################################################
# Register VTA-specific tuning tasks
def register_vta_tuning_tasks():
from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
# init autotvm env to register VTA operator
TaskExtractEnv()
@autotvm.task.register("topi_nn_conv2d", override=True)
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
with tvm.target.vta():
res = topi.nn.conv2d(*args, **kwargs)
res = topi.right_shift(res, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
if tvm.target.current_target().device_name == 'vta':
s = topi.generic.schedule_conv2d_nchw([res])
else:
s = tvm.create_schedule([res.op])
return s, [A, W, res]
########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
def tune_and_evaluate(tuning_opt):
if env.TARGET != "sim":
# Get remote from fleet node
remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000)
# Reconfigure the JIT runtime and FPGA.
vta.reconfig_runtime(remote)
vta.program_fpga(remote, bitstream=None)
else:
# In simulation mode, host the RPC server locally.
remote = rpc.LocalSession()
# Register VTA tuning tasks
register_vta_tuning_tasks()
# Perform task extraction on Relay program
print("Extract tasks...")
relay_prog, params = compile_network(env, target, network, start_pack, stop_pack)
tasks = autotvm.task.extract_from_program(func=relay_prog,
params=params,
ops=(tvm.relay.op.nn.conv2d,),
target=target,
target_host=env.target_host)
# We should have extracted 10 convolution tasks
assert len(tasks) == 10
print("Extracted {} conv2d tasks:".format(len(tasks)))
for tsk in tasks:
print("\t{}".format(tsk))
# We do not run the tuning on our web server since it takes too long.
# Comment out the following line to run it yourself.
return
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.tophub.context(target, extra_files=[log_file]):
# Compile network
print("Compile...")
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
if target.device_name != "vta":
graph, lib, params = relay.build(
relay_prog, target=target,
params=params, target_host=env.target_host)
else:
with vta.build_config():
graph, lib, params = relay.build(
relay_prog, target=target,
params=params, target_host=env.target_host)
# Export library
print("Upload...")
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
# Generate the graph runtime
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
m = graph_runtime.create(graph, lib, ctx)
# upload parameters to device
image = tvm.nd.array(
(np.random.uniform(size=(1, 3, 224, 224))).astype('float32'))
m.set_input(**params)
m.set_input('data', image)
# evaluate
print("Evaluate inference time cost...")
timer = m.module.time_evaluator("run", ctx, number=1, repeat=10)
tcost = timer()
prof_res = np.array(tcost.results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
# Run the tuning and evaluate the results
tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
# The tuning needs to compile many programs and extract features from them,
# so a high-performance CPU is recommended.
# One sample output is listed below.
# It takes about 2 hours on a 16-thread CPU and 6 Pynq boards.
#
# .. code-block:: bash
#
# Extract tasks...
# [Warning] Invalid shape during AutoTVM task creation
# Extracted 10 conv2d tasks:
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 16, 14, 14, 1, 16), 'int8'), ('TENSOR', (32, 16, 1, 1, 16, 16), 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 16, 14, 14, 1, 16, 'int8'), (32, 16, 1, 1, 16, 16, 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 8, 28, 28, 1, 16), 'int8'), ('TENSOR', (16, 8, 1, 1, 16, 16), 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 8, 28, 28, 1, 16, 'int8'), (16, 8, 1, 1, 16, 16, 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 4, 56, 56, 1, 16), 'int8'), ('TENSOR', (8, 4, 1, 1, 16, 16), 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 4, 56, 56, 1, 16, 'int8'), (8, 4, 1, 1, 16, 16, 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 4, 56, 56, 1, 16), 'int8'), ('TENSOR', (4, 4, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 4, 56, 56, 1, 16, 'int8'), (4, 4, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 8, 28, 28, 1, 16), 'int8'), ('TENSOR', (8, 8, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 8, 28, 28, 1, 16, 'int8'), (8, 8, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 4, 56, 56, 1, 16), 'int8'), ('TENSOR', (8, 4, 3, 3, 16, 16), 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 4, 56, 56, 1, 16, 'int8'), (8, 4, 3, 3, 16, 16, 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 16, 14, 14, 1, 16), 'int8'), ('TENSOR', (16, 16, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 16, 14, 14, 1, 16, 'int8'), (16, 16, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 8, 28, 28, 1, 16), 'int8'), ('TENSOR', (16, 8, 3, 3, 16, 16), 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 8, 28, 28, 1, 16, 'int8'), (16, 8, 3, 3, 16, 16, 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 32, 7, 7, 1, 16), 'int8'), ('TENSOR', (32, 32, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 32, 7, 7, 1, 16, 'int8'), (32, 32, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'))
# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 16, 14, 14, 1, 16), 'int8'), ('TENSOR', (32, 16, 3, 3, 16, 16), 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 16, 14, 14, 1, 16, 'int8'), (32, 16, 3, 3, 16, 16, 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'))
# Tuning...
# [Task 1/10] Current/Best: 0.72/ 23.24 GFLOPS | Progress: (480/1000) | 640.31 s Done.
# [Task 2/10] Current/Best: 0.00/ 27.69 GFLOPS | Progress: (576/1000) | 810.09 s Done.
# [Task 3/10] Current/Best: 0.00/ 22.97 GFLOPS | Progress: (1000/1000) | 1125.37 s Done.
# [Task 4/10] Current/Best: 0.00/ 31.26 GFLOPS | Progress: (1000/1000) | 1025.52 s Done.
# [Task 5/10] Current/Best: 0.00/ 15.15 GFLOPS | Progress: (1000/1000) | 1236.58 s Done.
# [Task 6/10] Current/Best: 0.00/ 22.74 GFLOPS | Progress: (1000/1000) | 906.60 s Done.
# [Task 7/10] Current/Best: 0.00/ 15.27 GFLOPS | Progress: (1000/1000) | 1056.25 s Done.
# [Task 8/10] Current/Best: 0.00/ 2.18 GFLOPS | Progress: (1000/1000) | 2275.29 s Done.
# [Task 9/10] Current/Best: 2.23/ 3.99 GFLOPS | Progress: (1000/1000) | 2527.25 s Done.
# [Task 10/10] Current/Best: 1.56/ 6.32 GFLOPS | Progress: (480/1000) | 1304.84 s Done.
# Compile...
# Upload...
# Evaluate inference time cost...
# Mean inference time (std dev): 621.79 ms (0.14 ms)
######################################################################
#
# .. note:: **Experiencing Difficulties?**
#
# The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines at the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
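######################################################################
# Once tuning completes, you will typically want to reuse the best schedules
# when compiling the network. The snippet below is only a minimal sketch, not
# part of this script: it assumes your tuning records were written to a file
# named ``vta_tuning.log`` and that the variables defined earlier in this
# tutorial (``relay_prog``, ``target``, ``params``, ``env``) are in scope. It
# keeps only the best configuration per workload with
# ``autotvm.record.pick_best`` and applies it during the Relay build.
#
# .. code-block:: python
#
#   from tvm import autotvm
#
#   # Keep only the best measured configuration for each tuning task.
#   autotvm.record.pick_best("vta_tuning.log", "vta_tuning_best.log")
#
#   # Apply the picked configurations while building the network.
#   with autotvm.apply_history_best("vta_tuning_best.log"):
#       graph, lib, params = relay.build(
#           relay_prog, target=target, params=params, target_host=env.target_host)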
.. _tutorial-frontend:
Compile Deep Learning Models
----------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Deploy Pretrained ResNet Model from MxNet on VTA
================================================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
This tutorial provides an end-to-end demo on running ResNet-18 inference
on the VTA accelerator design to perform ImageNet classification tasks.
It showcases Relay as a front-end compiler that can perform quantization (VTA
only supports int8/32 inference) as well as graph packing (to enable
tensorization in the core), massaging the compute graph for the hardware target.
"""
######################################################################
# Install dependencies
# --------------------
# To run this tutorial, we need to install a few extra Python dependencies
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user mxnet requests pillow
#
# Now return to the python code. Import packages.
from __future__ import absolute_import, print_function
import argparse
import json
import os
import requests
import time
from io import BytesIO
from os.path import join, isfile
from PIL import Image
from mxnet.gluon.model_zoo import vision
import numpy as np
from matplotlib import pyplot as plt
import tvm
from tvm import rpc, autotvm, relay
from tvm.contrib import graph_runtime, util, download
from tvm.contrib.debugger import debug_runtime
import vta
from vta.testing import simulator
from vta.top import graph_pack
# Make sure that TVM was compiled with RPC=1
assert tvm.module.enabled("rpc")
######################################################################
# Define the platform and model targets
# -------------------------------------
# Execute on CPU vs. VTA, and define the model.
# Load VTA parameters from the vta/config/vta_config.json file
env = vta.get_env()
# Set ``device=arm_cpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu
# Name of the Gluon model to compile.
# The ``start_pack`` and ``stop_pack`` labels indicate where
# to start and end the graph packing Relay pass: in other words,
# where to start and finish offloading to VTA (see the note
# below for a way to pick these labels for other models).
model = "resnet18_v1"
start_pack = "nn.max_pool2d"
stop_pack = "nn.global_avg_pool2d"
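######################################################################
# .. note::
#
#   The operator names above are specific to ResNet-18 as imported from the
#   Gluon model zoo. If you adapt this tutorial to a different model, a simple
#   way to pick suitable ``start_pack``/``stop_pack`` boundaries (a sketch,
#   not part of the original flow) is to print the Relay program right after
#   the frontend import and inspect the operator sequence:
#
#   .. code-block:: python
#
#     mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
#     print(mod.astext())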
######################################################################
# Obtain an execution remote
# ---------------------------------
# When the target is a hardware board (e.g. 'pynq'), obtain a remote,
# then reconfigure the FPGA and runtime.
# Otherwise, if the target is 'sim', execute locally.
if env.TARGET != "sim":
# Get remote from tracker node if environment variable is set.
# To set up the tracker, you'll need to follow the "Auto-tuning
# a convolutional network for VTA" tutorial.
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    # Otherwise if you have a device you want to program directly from
    # the host, make sure you've set the variables below to the IP of
    # your board.
    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
    device_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
    if not tracker_host or not tracker_port:
        remote = rpc.connect(device_host, device_port)
    else:
        remote = autotvm.measure.request_remote(env.TARGET, tracker_host,
                                                int(tracker_port), timeout=10000)
# Reconfigure the JIT runtime and FPGA.
# You can program the FPGA with your own custom bitstream
# by passing the path to the bitstream file instead of None.
reconfig_start = time.time()
vta.reconfig_runtime(remote)
vta.program_fpga(remote, bitstream=None)
reconfig_time = time.time() - reconfig_start
print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
# In simulation mode, host the RPC server locally.
else:
remote = rpc.LocalSession()
# Get execution context from remote
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
######################################################################
# Build the inference graph runtime
# ---------------------------------
# Grab ResNet-18 model from Gluon model zoo and compile with Relay.
# The compilation steps are:
#
# 1) Front-end translation from MxNet into a Relay module.
# 2) Apply 8-bit quantization: here we skip the first conv layer
#    and the dense layer, which will both be executed in fp32 on the CPU.
# 3) Perform graph packing to alter the data layout for tensorization.
# 4) Perform constant folding to reduce the number of operators
#    (e.g. eliminate the batch norm multiply).
# 5) Perform a Relay build to generate the object file.
# 6) Load the object file onto the remote (FPGA device).
# 7) Generate the graph runtime, `m`.
# Load pre-configured AutoTVM schedules
with autotvm.tophub.context(target):
# Populate the shape and data type dictionary for ResNet input
dtype_dict = {"data": 'float32'}
shape_dict = {"data": (env.BATCH, 3, 224, 224)}
# Get off the shelf gluon model, and convert to relay
gluon_model = vision.get_model(model, pretrained=True)
# Measure build start time
build_start = time.time()
# Start front end compilation
mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
# Update shape and type dictionary
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params)
# Perform graph packing and constant folding for VTA target
if target.device_name == "vta":
assert env.BLOCK_IN == env.BLOCK_OUT
relay_prog = graph_pack(
relay_prog,
env.BATCH,
env.BLOCK_OUT,
env.WGT_WIDTH,
start_name=start_pack,
stop_name=stop_pack)
relay_prog = relay.ir_pass.fold_constant(relay_prog)
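    # (Optional sanity check, not in the original tutorial): printing the text
    # form of the packed program lets you confirm that the conv2d operators now
    # carry the blocked layout expected by VTA before compiling, e.g.:
    # print(relay_prog.astext())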
# Compile Relay program with AlterOpLayout disabled
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
if target.device_name != "vta":
graph, lib, params = relay.build(
relay_prog, target=target,
params=params, target_host=env.target_host)
else:
with vta.build_config():
graph, lib, params = relay.build(
relay_prog, target=target,
params=params, target_host=env.target_host)
# Measure Relay build time
build_time = time.time() - build_start
print(model + " inference graph built in {0:.2f}s!".format(build_time))
# Send the inference library over to the remote RPC server
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
# Graph runtime
m = graph_runtime.create(graph, lib, ctx)
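    # (Optional, a sketch using the debug_runtime imported above): swapping in
    # the debug graph runtime reports per-operator execution times, which helps
    # attribute the overall latency between CPU and VTA layers, e.g.:
    # m = debug_runtime.create(graph, lib, ctx)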
######################################################################
# Perform ResNet-18 inference
# ---------------------------
# We run classification on an image sample from ImageNet.
# We just need to download the categories file, `synset.txt`,
# and an input test image.
# Download ImageNet categories
categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = "synset.txt"
download.download(join(categ_url, categ_fn), categ_fn)
synset = eval(open(categ_fn).read())
# Download test image
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
response = requests.get(image_url)
# Prepare test image for inference
image = Image.open(BytesIO(response.content)).resize((224, 224))
plt.imshow(image)
plt.show()
image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
image = np.repeat(image, env.BATCH, axis=0)
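# The preprocessed image now has shape (env.BATCH, 3, 224, 224),
# matching the "data" entry of shape_dict defined above.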
# Set the network parameters and inputs
m.set_input(**params)
m.set_input('data', image)
# Perform inference: we run the module 4 times per measurement,
# and repeat the measurement 3 times to get error bounds
timer = m.module.time_evaluator("run", ctx, number=4, repeat=3)
tcost = timer()
# Get classification results
tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
top_categories = np.argsort(tvm_output.asnumpy()[0])
# Report top-5 classification results
std = np.std(tcost.results) * 1000 / env.BATCH
mean = tcost.mean * 1000 / env.BATCH
print("%s prediction" % model)
print(" #1:", synset[top_categories[-1]])
print(" #2:", synset[top_categories[-2]])
print(" #3:", synset[top_categories[-3]])
print(" #4:", synset[top_categories[-4]])
print(" #5:", synset[top_categories[-5]])
print("Performed inference in %.2fms/sample (std = %.2f)" % (mean, std))
# This just checks that one of the 5 top categories
# is one variety of cat; this is by no means an accurate
# assessment of how quantization affects classification
# accuracy, but is meant to catch changes to the
# quantization pass that would degrade accuracy in the CI.
cat_detected = False
for k in top_categories[-5:]:
if "cat" in synset[k]:
cat_detected = True
assert cat_detected
Optimize Tensor Operators
-------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ResNet Inference Example
========================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
This tutorial provides an end-to-end demo on running ResNet-18 inference
on the VTA accelerator design to perform ImageNet classification tasks.
"""
######################################################################
# Import Libraries
# ----------------
# We start by importing the tvm, vta, and nnvm libraries needed to run this example.
from __future__ import absolute_import, print_function
import os
import time
from io import BytesIO
import numpy as np
import requests
from matplotlib import pyplot as plt
from PIL import Image
import tvm
from tvm import rpc, autotvm
from tvm.contrib import graph_runtime, util
from tvm.contrib.download import download
import nnvm.compiler
import vta
import vta.testing
# Load VTA parameters from the vta/config/vta_config.json file
env = vta.get_env()
# Helper to crop an image to a square (224, 224)
# Takes in an Image object, returns an Image object
def thumbnailify(image, pad=15):
w, h = image.size
crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
image = image.crop(crop)
image = image.resize((224, 224))
return image
# Helper function to read in image
# Takes in Image object, returns an ND array
def process_image(image):
# Convert to neural network input format
image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return tvm.nd.array(image.astype("float32"))
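# For example, a 224x224 RGB PIL image becomes a (1, 3, 224, 224) float32
# NDArray, ready to be fed as the 'data' input of the graph runtime.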
# Classification helper function
# Takes in the graph runtime, and an image, and returns top result and time
def classify(m, image):
m.set_input('data', image)
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()
tvm_output = m.get_output(0)
top = np.argmax(tvm_output.asnumpy()[0])
tcost = "t={0:.2f}s".format(tcost.mean)
return tcost + " {}".format(synset[top])
# Helper function to compile the NNVM graph
# Takes in a path to a graph file, params file, and device target
# Returns the NNVM graph object, a compiled library object, and the params dict
def generate_graph(graph_fn, params_fn, device="vta"):
# Measure build start time
build_start = time.time()
# Derive the TVM target
target = tvm.target.create("llvm -device={}".format(device))
    # Derive the LLVM compiler flags
    # When targeting the Pynq, cross-compile to the ARMv7 ISA
    if env.TARGET == "sim":
        target_host = "llvm"
    elif env.TARGET == "pynq":
        target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
    else:
        raise ValueError("Unsupported VTA target: {}".format(env.TARGET))
# Load the ResNet-18 graph and parameters
sym = nnvm.graph.load_json(open(graph_fn).read())
params = nnvm.compiler.load_param_dict(open(params_fn, 'rb').read())
# Populate the shape and data type dictionary
shape_dict = {"data": (1, 3, 224, 224)}
dtype_dict = {"data": 'float32'}
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Apply NNVM graph optimization passes
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
if target.device_name == "vta":
assert env.BLOCK_IN == env.BLOCK_OUT
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
# Compile NNVM graph
with nnvm.compiler.build_config(opt_level=3):
if target.device_name != "vta":
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
else:
with vta.build_config():
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)
# Save the compiled inference graph library
assert tvm.module.enabled("rpc")
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
# Send the inference library over to the remote RPC server
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
# Measure build time
build_time = time.time() - build_start
print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time))
return graph, lib, params
######################################################################
# Download ResNet Model
# --------------------------------------------
# Download the necessary files to run ResNet-18.
#
# Obtain the ResNet model files and download them into the _data dir
url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
categ_fn = 'synset.txt'
graph_fn = 'resnet18_qt8.json'
params_fn = 'resnet18_qt8.params'
# Create data dir
data_dir = "_data/"
if not os.path.exists(data_dir):
os.makedirs(data_dir)
# Download files
for file in [categ_fn, graph_fn, params_fn]:
download(os.path.join(url, file), os.path.join(data_dir, file))
# Read in ImageNet Categories
synset = eval(open(os.path.join(data_dir, categ_fn)).read())
# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
autotvm.tophub.check_backend('vta')
######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
# Build the RPC server's VTA runtime and program the Pynq FPGA.
# Measure build start time
reconfig_start = time.time()
# We read the Pynq RPC host IP address and port number from the OS environment
host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":
# Make sure that TVM was compiled with RPC=1
assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port)
# Reconfigure the JIT runtime
vta.reconfig_runtime(remote)
# Program the FPGA with a pre-compiled VTA bitstream.
# You can program the FPGA with your own custom bitstream
# by passing the path to the bitstream file instead of None.
vta.program_fpga(remote, bitstream=None)
# Report on reconfiguration time
reconfig_time = time.time() - reconfig_start
print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
remote = rpc.LocalSession()
######################################################################
# Build the ResNet Runtime
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.
# Set ``device=vtacpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
# Device context
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
# Build the graph runtime
graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn),
os.path.join(data_dir, params_fn),
device)
m = graph_runtime.create(graph, lib, ctx)
# Set the parameters
m.set_input(**params)
######################################################################
# Run ResNet-18 inference on a sample image
# -----------------------------------------
# Perform image classification on test image.
# You can change the test image URL to any image of your choosing.
# Read in test image
image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).resize((224, 224))
# Show Image
plt.imshow(image)
plt.show()
# Set the input
image = process_image(image)
m.set_input('data', image)
# Perform inference
timer = m.module.time_evaluator("run", ctx, number=1)
tcost = timer()
# Get classification results
tvm_output = m.get_output(0)
top_categories = np.argsort(tvm_output.asnumpy()[0])
# Report top-5 classification results
print("ResNet-18 Prediction #1:", synset[top_categories[-1]])
print(" #2:", synset[top_categories[-2]])
print(" #3:", synset[top_categories[-3]])
print(" #4:", synset[top_categories[-4]])
print(" #5:", synset[top_categories[-5]])
print("Performed inference in {0:.2f}s".format(tcost.mean))
######################################################################
# Run a Youtube Video Image Classifier
# ------------------------------------
# Perform image classification on a test video stream, classifying one frame
# out of every 48.
# Change ``if False:`` below to ``if True:`` to run the demo.
if False:
import cv2
import pafy
from IPython.display import clear_output
# Helper to crop an image to a square (224, 224)
# Takes in an Image object, returns an Image object
def thumbnailify(image, pad=15):
w, h = image.size
crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad)
image = image.crop(crop)
image = image.resize((224, 224))
return image
# 16:16 inches
plt.rcParams['figure.figsize'] = [16, 16]
# Stream the video in
url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s"
video = pafy.new(url)
best = video.getbest(preftype="mp4")
cap = cv2.VideoCapture(best.url)
# Process one frame out of every 48 for variety
count = 0
guess = ""
while(count<2400):
# Capture frame-by-frame
ret, frame = cap.read()
# Process one every 48 frames
if count % 48 == 1:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
# Crop and resize
thumb = np.array(thumbnailify(frame))
image = process_image(thumb)
guess = classify(m, image)
# Insert guess in frame
frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50)
cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA)
plt.imshow(thumb)
plt.axis('off')
plt.show()
if cv2.waitKey(1) & 0xFF == ord('q'):
break
clear_output(wait=True)
count += 1
# When everything done, release the capture
cap.release()
cv2.destroyAllWindows()