Commit 028f47ce by Thierry Moreau Committed by Jared Roesch

[VTA][Relay] Extending Vision model coverage compilation for VTA (#3740)

* adding support for graphpack over multiply op

* increasing resnet model coverage

* fix indentation

* lint

* moving recursion limit fix into graphpack pass

* moving recursionlimit to relay init

* pooling on NCHWnc format

* adding more models

* deploy_resnet_on_vta.py

* trailing line

* generalizing to vision models

* merge conflicts

* fix, apply quantization to VTA only

* improving comments

* trimming models that have runtime issues for the moment

* lint

* lint

* lint
parent dee11b41
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name # pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler.""" """The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import from __future__ import absolute_import
from sys import setrecursionlimit
from ..api import register_func from ..api import register_func
from . import base from . import base
from . import ty from . import ty
...@@ -59,6 +60,9 @@ from . import qnn ...@@ -59,6 +60,9 @@ from . import qnn
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
# Required to traverse large programs
setrecursionlimit(10000)
# Span # Span
Span = base.Span Span = base.Span
......
...@@ -161,9 +161,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -161,9 +161,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool2d does not support input split on width"; << "max_pool2d does not support input split on width";
CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) CHECK(inputs[0].ndim() == 4U ||
inputs[0].ndim() == 5U ||
inputs[0].ndim() == 6U)
<< "Pool2D only support 4-D input (e.g., NCHW)" << "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)"; << " or 5-D input (e.g. NCHWc on for vector instructions)"
<< " or 6-D input (e.g. NCHWnc for tensor accelerators)";
if (param->padding.size() == 1) { if (param->padding.size() == 1) {
padding.push_back(padding[0]); padding.push_back(padding[0]);
......
...@@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor): ...@@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor):
return data return data
def _pack_bias(data, dshape, dtype, bfactor, cfactor): def _pack_const(data, dshape, dtype, bfactor, cfactor):
"""Pack the bias parameter. """Pack a constant parameter.
""" """
dshape = _to_shape(dshape) dshape = _to_shape(dshape)
assert len(dshape) == 3 assert len(dshape) == 3
...@@ -124,6 +124,7 @@ class ExprPack(ExprMutator): ...@@ -124,6 +124,7 @@ class ExprPack(ExprMutator):
self.conv2d = op.op.get("nn.conv2d") self.conv2d = op.op.get("nn.conv2d")
self.conv2d_transpose = op.op.get("nn.conv2d_transpose") self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
self.add = op.op.get("add") self.add = op.op.get("add")
self.multiply = op.op.get("multiply")
self.bias_add = op.op.get("nn.bias_add") self.bias_add = op.op.get("nn.bias_add")
self.number_of_conv2d = 0 self.number_of_conv2d = 0
super().__init__() super().__init__()
...@@ -203,23 +204,35 @@ class ExprPack(ExprMutator): ...@@ -203,23 +204,35 @@ class ExprPack(ExprMutator):
output_padding=call.attrs.output_padding, output_padding=call.attrs.output_padding,
out_dtype=call.attrs.out_dtype) out_dtype=call.attrs.out_dtype)
return conv2d return conv2d
elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape): elif call.op == self.add and \
tuple(input_types[0].shape) == tuple(input_types[1].shape):
pass pass
elif call.op == self.add and len(input_types[1].shape) == 3: elif call.op == self.add and len(input_types[1].shape) == 3:
data, bias = args data, const = args
bias = _pack_bias(bias, const = _pack_const(const,
_to_shape(input_types[1].shape), _to_shape(input_types[1].shape),
input_types[1].dtype, input_types[1].dtype,
self.bfactor, self.bfactor,
self.cfactor) self.cfactor)
return relay.Call(self.add, [data, bias]) return relay.Call(self.add, [data, const])
elif call.op == self.multiply and \
tuple(input_types[0].shape) == tuple(input_types[1].shape):
pass
elif call.op == self.multiply and len(input_types[1].shape) == 3:
data, const = args
const = _pack_const(const,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
return relay.Call(self.multiply, [data, const])
elif self.start_pack and call.op == self.bias_add: elif self.start_pack and call.op == self.bias_add:
data, bias = args data, bias = args
bias = _pack_bias(bias, bias = _pack_const(bias,
_to_shape(input_types[1].shape), _to_shape(input_types[1].shape),
input_types[1].dtype, input_types[1].dtype,
self.bfactor, self.bfactor,
self.cfactor) self.cfactor)
return relay.Call(self.add, [data, bias]) return relay.Call(self.add, [data, bias])
elif self.start_pack and call.op == op.op.get('cast') and \ elif self.start_pack and call.op == op.op.get('cast') and \
input_types[0].dtype == 'int32': input_types[0].dtype == 'int32':
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" """
Deploy Pretrained ResNet Model from MxNet on VTA Deploy Pretrained Vision Model from MxNet on VTA
================================================ ================================================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_ **Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
This tutorial provides an end-to-end demo, on how to run ResNet-18 inference This tutorial provides an end-to-end demo on how to run ImageNet classification
onto the VTA accelerator design to perform ImageNet classification tasks. inference on the VTA accelerator design.
It showcases Relay as a front end compiler that can perform quantization (VTA It showcases Relay as a front end compiler that can perform quantization (VTA
only supports int8/32 inference) as well as graph packing (in order to enable only supports int8/32 inference) as well as graph packing (in order to enable
tensorization in the core) to massage the compute graph for the hardware target. tensorization in the core) to massage the compute graph for the hardware target.
...@@ -40,7 +40,7 @@ tensorization in the core) to massage the compute graph for the hardware target. ...@@ -40,7 +40,7 @@ tensorization in the core) to massage the compute graph for the hardware target.
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
import argparse, json, os, requests, time import argparse, json, os, requests, sys, time
from io import BytesIO from io import BytesIO
from os.path import join, isfile from os.path import join, isfile
from PIL import Image from PIL import Image
...@@ -53,6 +53,7 @@ import tvm ...@@ -53,6 +53,7 @@ import tvm
from tvm import rpc, autotvm, relay from tvm import rpc, autotvm, relay
from tvm.contrib import graph_runtime, util, download from tvm.contrib import graph_runtime, util, download
from tvm.contrib.debugger import debug_runtime from tvm.contrib.debugger import debug_runtime
from tvm.relay import transform
import vta import vta
from vta.testing import simulator from vta.testing import simulator
...@@ -61,7 +62,6 @@ from vta.top import graph_pack ...@@ -61,7 +62,6 @@ from vta.top import graph_pack
# Make sure that TVM was compiled with RPC=1 # Make sure that TVM was compiled with RPC=1
assert tvm.module.enabled("rpc") assert tvm.module.enabled("rpc")
###################################################################### ######################################################################
# Define the platform and model targets # Define the platform and model targets
# ------------------------------------- # -------------------------------------
...@@ -75,13 +75,22 @@ env = vta.get_env() ...@@ -75,13 +75,22 @@ env = vta.get_env()
device = "vta" device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu target = env.target if device == "vta" else env.target_vta_cpu
# Dictionary lookup for where to start/end graph packing (per model: [start op, stop op])
pack_dict = {
"resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
}
# Name of Gluon model to compile # Name of Gluon model to compile
# The ``start_pack`` and ``stop_pack`` labels indicate where # The ``start_pack`` and ``stop_pack`` labels indicate where
# to start and end the graph packing relay pass: in other words # to start and end the graph packing relay pass: in other words
# where to start and finish offloading to VTA. # where to start and finish offloading to VTA.
model = "resnet18_v1" model = "resnet18_v1"
start_pack="nn.max_pool2d" assert model in pack_dict
stop_pack="nn.global_avg_pool2d"
###################################################################### ######################################################################
# Obtain an execution remote # Obtain an execution remote
...@@ -125,7 +134,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) ...@@ -125,7 +134,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
###################################################################### ######################################################################
# Build the inference graph runtime # Build the inference graph runtime
# --------------------------------- # ---------------------------------
# Grab ResNet-18 model from Gluon model zoo and compile with Relay. # Grab vision model from Gluon model zoo and compile with Relay.
# The compilation steps are: # The compilation steps are:
# 1) Front end translation from MxNet into Relay module. # 1) Front end translation from MxNet into Relay module.
# 2) Apply 8-bit quantization: here we skip the first conv layer, # 2) Apply 8-bit quantization: here we skip the first conv layer,
...@@ -140,7 +149,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) ...@@ -140,7 +149,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
# Load pre-configured AutoTVM schedules # Load pre-configured AutoTVM schedules
with autotvm.tophub.context(target): with autotvm.tophub.context(target):
# Populate the shape and data type dictionary for ResNet input # Populate the shape and data type dictionary for ImageNet classifier input
dtype_dict = {"data": 'float32'} dtype_dict = {"data": 'float32'}
shape_dict = {"data": (env.BATCH, 3, 224, 224)} shape_dict = {"data": (env.BATCH, 3, 224, 224)}
...@@ -157,21 +166,22 @@ with autotvm.tophub.context(target): ...@@ -157,21 +166,22 @@ with autotvm.tophub.context(target):
shape_dict.update({k: v.shape for k, v in params.items()}) shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod["main"], params=params)
# Perform graph packing and constant folding for VTA target
if target.device_name == "vta": if target.device_name == "vta":
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod["main"], params=params)
# Perform graph packing and constant folding for VTA target
assert env.BLOCK_IN == env.BLOCK_OUT assert env.BLOCK_IN == env.BLOCK_OUT
relay_prog = graph_pack( relay_prog = graph_pack(
relay_prog, relay_prog,
env.BATCH, env.BATCH,
env.BLOCK_OUT, env.BLOCK_OUT,
env.WGT_WIDTH, env.WGT_WIDTH,
start_name=start_pack, start_name=pack_dict[model][0],
stop_name=stop_pack) stop_name=pack_dict[model][1])
else:
relay_prog = mod["main"]
# Compile Relay program with AlterOpLayout disabled # Compile Relay program with AlterOpLayout disabled
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
...@@ -199,8 +209,8 @@ with autotvm.tophub.context(target): ...@@ -199,8 +209,8 @@ with autotvm.tophub.context(target):
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
###################################################################### ######################################################################
# Perform ResNet-18 inference # Perform image classification inference
# --------------------------- # --------------------------------------
# We run classification on an image sample from ImageNet # We run classification on an image sample from ImageNet
# We just need to download the categories files, `synset.txt` # We just need to download the categories files, `synset.txt`
# and an input test image. # and an input test image.
...@@ -256,7 +266,6 @@ else: ...@@ -256,7 +266,6 @@ else:
tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0))) tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
for b in range(env.BATCH): for b in range(env.BATCH):
top_categories = np.argsort(tvm_output.asnumpy()[b]) top_categories = np.argsort(tvm_output.asnumpy()[b])
# Report top-5 classification results # Report top-5 classification results
print("\n{} prediction for sample {}".format(model, b)) print("\n{} prediction for sample {}".format(model, b))
print("\t#1:", synset[top_categories[-1]]) print("\t#1:", synset[top_categories[-1]])
...@@ -264,7 +273,6 @@ for b in range(env.BATCH): ...@@ -264,7 +273,6 @@ for b in range(env.BATCH):
print("\t#3:", synset[top_categories[-3]]) print("\t#3:", synset[top_categories[-3]])
print("\t#4:", synset[top_categories[-4]]) print("\t#4:", synset[top_categories[-4]])
print("\t#5:", synset[top_categories[-5]]) print("\t#5:", synset[top_categories[-5]])
# This just checks that one of the 5 top categories # This just checks that one of the 5 top categories
# is one variety of cat; this is by no means an accurate # is one variety of cat; this is by no means an accurate
# assessment of how quantization affects classification # assessment of how quantization affects classification
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment