Unverified Commit 93dff448 by masahi Committed by GitHub

[REDO AFTER GH BUG] Add support for quantized models via QNN (#5016)

This reverts commit f346c602.
parent f346c602
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend.""" """PT: PyTorch frontend."""
import itertools import itertools
import logging
import numpy as np import numpy as np
...@@ -32,6 +33,8 @@ from .common import get_relay_op ...@@ -32,6 +33,8 @@ from .common import get_relay_op
from .common import infer_shape as _infer_shape from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value from .common import infer_value as _infer_value
from . import qnn_torch
__all__ = ["from_pytorch"] __all__ = ["from_pytorch"]
# operator implementation # operator implementation
...@@ -146,6 +149,10 @@ def _zeros(): ...@@ -146,6 +149,10 @@ def _zeros():
def _relu(): def _relu():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
if input_types[0] == "quint8":
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_zero_point = _expr.const(inputs[2], dtype="int32")
return qnn_torch.quantized_relu(data, input_zero_point)
return _op.nn.relu(data) return _op.nn.relu(data)
return _impl return _impl
...@@ -154,9 +161,14 @@ def _adaptive_avg_2d(): ...@@ -154,9 +161,14 @@ def _adaptive_avg_2d():
data = inputs[0] data = inputs[0]
output_size = _infer_shape(inputs[1]) output_size = _infer_shape(inputs[1])
return _op.nn.adaptive_avg_pool2d( def func(x):
data, return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
output_size=output_size)
if input_types[0] == "quint8":
return qnn_torch.quantized_adaptive_avg_2d(data, func)
return func(data)
return _impl return _impl
def _adaptive_max_2d(): def _adaptive_max_2d():
...@@ -506,7 +518,18 @@ def _mean(): ...@@ -506,7 +518,18 @@ def _mean():
else: else:
exclude = False exclude = False
return _op.mean(data, axis, keepdims, exclude) def func(x):
return _op.mean(x, axis, keepdims, exclude)
if input_types[0] == "quint8":
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
return qnn_torch.quantized_mean(data, input_scale,
input_zero_point, func)
return func(data)
return _impl return _impl
def _chunk(): def _chunk():
...@@ -668,10 +691,40 @@ def _upsample(method): ...@@ -668,10 +691,40 @@ def _upsample(method):
else: else:
coord_trans = "half_pixel" coord_trans = "half_pixel"
return _op.image.resize(data, out_size, "NCHW", method, coord_trans) def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
if input_types[0] == "quint8":
import torch
from packaging import version
# Torch version > 1.4 changed upsampling API
if version.parse(torch.__version__) > version.parse("1.4.0"):
num_inputs = 7
else:
num_inputs = 5
assert len(inputs) == num_inputs, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
return qnn_torch.quantized_upsample(data, input_scale,
input_zero_point, func)
return func(data)
return _impl
def _expand_as():
def _impl(inputs, input_types):
# TODO: maybe fix this
# This assumes expand_as can be removed because TVM has broadcast op
msg = "aten::expand_as(...) found, assume it is part of broadcast op"
logging.warning(msg)
return inputs[0]
return _impl return _impl
# Helper functions for operator implementation # Helper functions for operator implementation
def _convert_data_type(input_type): def _convert_data_type(input_type):
...@@ -792,6 +845,7 @@ _convert_map = { ...@@ -792,6 +845,7 @@ _convert_map = {
"aten::detach" : _identity(), "aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"), "aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"), "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
"aten::expand_as" : _expand_as()
} }
...@@ -842,6 +896,7 @@ def _report_missing_conversion(op_names): ...@@ -842,6 +896,7 @@ def _report_missing_conversion(op_names):
"prim::ListConstruct", "prim::ListUnpack", "prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack"] "prim::TupleConstruct", "prim::TupleUnpack"]
known_ops += list(_convert_map.keys()) known_ops += list(_convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())
missing = [op_name for op_name in op_names missing = [op_name for op_name in op_names
if op_name not in known_ops] if op_name not in known_ops]
...@@ -1008,6 +1063,7 @@ def parse_params(graph, state_dict): ...@@ -1008,6 +1063,7 @@ def parse_params(graph, state_dict):
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True) getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {} params = {}
param_tensors = {} param_tensors = {}
packed_param_map = {}
seen = set() seen = set()
for node in getattr_nodes: for node in getattr_nodes:
...@@ -1020,14 +1076,18 @@ def parse_params(graph, state_dict): ...@@ -1020,14 +1076,18 @@ def parse_params(graph, state_dict):
full_attr = _getattr_full_name(getattrs) full_attr = _getattr_full_name(getattrs)
full_attr_node_name = _get_output_name(getattrs[-1]) full_attr_node_name = _get_output_name(getattrs[-1])
if full_attr in state_dict: if full_attr.endswith("_packed_params"): # for quantized models
err_msg = "parameter %s not found in state dict" % full_attr
assert full_attr in state_dict, err_msg
packed_param_map[full_attr_node_name] = full_attr
elif full_attr in state_dict:
torch_tensor = state_dict[full_attr] torch_tensor = state_dict[full_attr]
tensor, var = _get_tensor_and_var(torch_tensor, tensor, var = _get_tensor_and_var(torch_tensor,
full_attr_node_name) full_attr_node_name)
param_tensors[full_attr_node_name] = tensor param_tensors[full_attr_node_name] = tensor
params[full_attr_node_name] = var params[full_attr_node_name] = var
return params, param_tensors return params, param_tensors, packed_param_map
def parse_operators(operators, outputs, output_index_map, ret_name): def parse_operators(operators, outputs, output_index_map, ret_name):
...@@ -1108,16 +1168,26 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): ...@@ -1108,16 +1168,26 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
params = script_module.state_dict() params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes) input_vars = parse_inputs(graph.inputs(), input_shapes)
param_vars, tensors = parse_params(graph, params) param_vars, tensors, packed_param_map = parse_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
input_vars.update(param_vars) input_vars.update(param_vars)
outputs = list(input_vars.values()) outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node())[0] ret_name = _get_input_names(graph.return_node())[0]
# For quantized models
if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)
body = parse_operators(_get_operator_nodes(graph.nodes()), outputs, body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name) output_index_map, ret_name)
func = tvm.relay.Function(_analysis.free_vars(body), body) func = tvm.relay.Function(_analysis.free_vars(body), body)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
return _module.IRModule.from_expr(func), tvm_params return _module.IRModule.from_expr(func), tvm_params
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, import-outside-toplevel
""" Functions to convert quantized torch models to QNN """
import numpy as np
import tvm
from tvm import relay
from tvm.relay import expr as _expr
from tvm.relay import op as _op
from tvm.relay.frontend.common import infer_shape
class QNNParam:
    """ A placeholder for weight quantization parameters """

    def __init__(self, weight, bias, scale, zero_point, param_key):
        # Drop the last len("._packed_params") characters of the key and use
        # the remainder as the prefix of the relay variable names
        suffix_len = len("._packed_params")
        var_prefix = param_key[:-suffix_len]

        self.weight_var = _expr.var(var_prefix + "_weight",
                                    shape=weight.shape)
        self.weight = weight

        if bias is None:
            # no bias: neither a relay var nor a numpy array is kept
            self.bias_var = None
            self.bias = None
        else:
            self.bias_var = _expr.var(var_prefix + "_bias",
                                      shape=bias.shape)
            self.bias = bias.detach().numpy()

        # scale and zero point are stored as relay constants
        self.scale = _expr.const(scale)
        self.zero_point = _expr.const(zero_point, dtype="int32")
def _unpack_quant_params(param_name, packed_params, unpack_func):
    """
    Unpack one packed parameter with unpack_func and wrap the result
    in a QNNParam. param_name is forwarded to QNNParam as its param_key.
    """
    # Torch stores quantized params in a custom packed format,
    # need to unpack and retrieve them as numpy arrays
    qweight, bias = unpack_func(packed_params)
    weight_np = qweight.dequantize().numpy()

    import torch
    if qweight.qscheme() != torch.per_tensor_affine:
        # any other scheme is handled as per channel quantization
        ch_scales = qweight.q_per_channel_scales().numpy()
        ch_zero_points = qweight.q_per_channel_zero_points().numpy()
        # This is an assumption posed by QNN
        msg = "The values of zero points should be all zero for per channel"
        assert np.all(ch_zero_points == 0), msg
        return QNNParam(weight_np, bias, ch_scales, 0, param_name)

    # per tensor: a single scale and a single zero point
    return QNNParam(weight_np, bias, qweight.q_scale(),
                    int(qweight.q_zero_point()), param_name)
def get_weight_quant_params(script_module):
    """ Retrieve and unpack weight parameters from quantized modules """
    import torch

    # conv and linear require different unpacking functions, so their
    # state dicts are collected into separate lists
    conv_entries = []
    linear_entries = []
    for mod_name, mod in script_module.named_modules():
        if not isinstance(mod, torch.jit.RecursiveScriptModule):
            continue
        if "Conv" in mod.original_name:
            conv_entries.append((mod_name, mod.state_dict()))
        elif mod.original_name == "LinearPackedParams":
            linear_entries.append((mod_name, mod.state_dict()))

    unpackers = [(torch.ops.quantized.conv2d_unpack, conv_entries),
                 (torch.ops.quantized.linear_unpack, linear_entries)]

    packed_key = "_packed_params"
    quant_params = {}
    for unpack_func, entries in unpackers:
        for mod_name, state_dict in entries:
            # each collected state dict must hold exactly the packed param
            assert len(state_dict) == 1
            assert packed_key in state_dict
            full_key = mod_name + "." + packed_key
            quant_params[full_key] = _unpack_quant_params(
                full_key, state_dict[packed_key], unpack_func)

    return quant_params
def add_quant_params_to_outputs(outputs, output_index_map,
                                packed_param_map, quant_params):
    """
    Append one (quantized weight, scale, zero point, bias var) tuple to
    outputs for every entry of packed_param_map, and record its position
    in output_index_map under the node name, so that other ops can
    reference it later. Weights are quantized here.
    """
    for node_name, packed_name in packed_param_map.items():
        entry = quant_params[packed_name]
        # the new tuple goes to the end of outputs
        output_index_map[node_name] = len(outputs)
        quantized_weight = relay.qnn.op.quantize(entry.weight_var,
                                                 entry.scale,
                                                 entry.zero_point,
                                                 out_dtype="int8",
                                                 axis=0)
        outputs.append((quantized_weight, entry.scale,
                        entry.zero_point, entry.bias_var))
def _get_quant_param_for_input(input_value):
    """
    We want to know the input scale and zp of this input_value, since
    input quant params are not explicitly passed around in torch (they
    are embedded in a QTensor data structure, not visible statically).
    We know that it is quantized using output scale and zp
    of some previous quantized op. The purpose of this function
    is to find that pair of parameters.

    Returns the (scale, zp) pair of graph values taken from the inputs
    of the first producer found whose kind is listed in the table below.
    """
    # Indices for output scale and zp
    # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
    # 6th and 7th arg are output scale and zp respectively.
    output_quant_param_indices = {
        "aten::quantize_per_tensor": (1, 2),
        "quantized::conv2d": (6, 7),
        "quantized::conv2d_relu": (6, 7),
        "quantized::linear": (2, 3),
        "quantized::linear_relu": (2, 3),
        "quantized::add_relu": (2, 3),
        "quantized::add": (2, 3),
        "quantized::mul_relu": (2, 3),
        "quantized::mul": (2, 3),
        "quantized::cat": (2, 3),
        "quantized::mul_scalar": (2, 3),
        "quantized::add_scalar": (2, 3)
    }

    # Walk back through producers, always following the first argument
    # (assume the quantized tensor comes earlier in the args)
    node = input_value.node()
    while True:
        op_name = node.kind()
        if op_name in output_quant_param_indices:
            scale_idx, zp_idx = output_quant_param_indices[op_name]
            scale = node.inputsAt(scale_idx)
            zp = node.inputsAt(zp_idx)
            return scale, zp

        first_arg = next(iter(node.inputs()), None)
        if first_arg is None:
            break
        node = first_arg.node()

    # shouldn't happen
    assert False, "No producer for %s" % (str(node))
def _get_add_scalar_output_quant_param(input_scale, input_zero_point,
                                       scalar):
    """
    Determine the output scale and zp of quantized::add_scalar op
    This is used for mobilenet v3
    Refer to aten/src/ATen/native/quantized/cpu/qadd.cpp

    Returns a (scale, zero point) pair.
    """
    q_min, q_max = 0, 255
    q_range = float(q_max) - q_min

    # zero point after moving it by the scalar expressed in quantized units
    shifted_zp = input_zero_point - round(scalar / input_scale)

    if shifted_zp < q_min:
        # below the range: pin zp to q_min and adjust the scale
        return (float(q_max) - shifted_zp) / q_range * input_scale, q_min

    if shifted_zp > q_max:
        # above the range: pin zp to q_max and adjust the scale
        return (float(shifted_zp) - q_min) / q_range * input_scale, q_max

    # inside the range: scale is unchanged, only the zero point moves
    return input_scale, shifted_zp
def _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
                                       scalar):
    """
    Determine the output scale and zp of quantized::mul_scalar op
    This is used for mobilenet v3
    Refer to aten/src/ATen/native/quantized/cpu/qmul.cpp

    Returns a (scale, zero point) pair.
    """
    q_min, q_max = 0, 255

    if scalar > 0.0:
        # positive scalar: the scale is multiplied, zero point is kept
        return scalar * input_scale, input_zero_point

    if scalar == 0.0:
        # zero scalar: fixed scale 1.0 and zero point 0
        return 1.0, 0

    # remaining case: scale uses the magnitude of the scalar and the
    # zero point is mirrored inside [q_min, q_max]
    mirrored_zp = q_max - (input_zero_point - q_min)
    return abs(scalar) * input_scale, mirrored_zp
def _add_output_quant_params_to_scalar_op(node, graph,
                                          input_scale, input_zero_point,
                                          scalar):
    """
    The output scale and zp of {add,mul}_scalar op are not explicit in the IR
    They are required for _get_quant_param_for_input above to work correctly
    So calculate these params using the same way torch does, and make new
    constant nodes in the input IR. Also add these params to the inputs of
    scalar op.

    For example,
       %6 : float = prim::Constant[value=3.]()
       %input : QUInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6)
    becomes
       %6 : float = prim::Constant[value=3.]()
       %7 : float = prim::Constant[value=0.015686161816120148]()
       %8 : int = prim::Constant[value=0]()
       %input : UInt8(1, 3, 224, 224) = quantized::add_scalar(%x.1, %6, %7, %8)

    %7 and %8 are newly created output scale and zp constant nodes

    Raises NotImplementedError for any node kind other than
    quantized::mul_scalar and quantized::add_scalar.
    """
    import torch

    # pick the routine that computes the output params for this op kind
    calculators = {
        "quantized::mul_scalar": _get_mul_scalar_output_quant_param,
        "quantized::add_scalar": _get_add_scalar_output_quant_param,
    }
    operator = node.kind()
    if operator not in calculators:
        raise NotImplementedError("unsupported scalar op: %s" % operator)

    out_scale, out_zero_point = calculators[operator](input_scale,
                                                      input_zero_point,
                                                      scalar)

    # create new constant nodes and add them to graph, right before node
    scale_const = graph.create("prim::Constant")
    zp_const = graph.create("prim::Constant")
    scale_const.insertBefore(node)
    zp_const.insertBefore(node)

    # attach the values and the output types of the two constants
    scale_const.f_("value", out_scale)
    zp_const.i_("value", out_zero_point)
    scale_const.output().setType(torch._C.FloatType.get())
    zp_const.output().setType(torch._C.IntType.get())

    # finally make them extra inputs of the scalar op
    node.addInput(scale_const.output())
    node.addInput(zp_const.output())
def add_input_quant_params_to_op_inputs(graph):
    """
    In Torch, input quant params are not explicitly passed around
    Instead, they are stored in QTensor data structure, and retrieved
    at runtime by each quantized ops.
    However, they need to be known statically for QNN translation.
    To workaround and simplify the translation of inputs, we manually add
    input quant params to inputs of Torch quantized operators listed below.
    See _quantized_conv2d() below for example of why this is helpful.

    For example,
      %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435)
    becomes
      %395 : float = prim::Constant[value=0.036212071776390076]()
      %396 : int = prim::Constant[value=0]()
      %430 : float = prim::Constant[value=0.16080744564533234]()
      %431 : int = prim::Constant[value=42]()
      %input : QUInt8(1, 512, 7, 7) = quantized::add(%x.8, %x.9, %434, %435,
                                                     %430, %431, %395, %396)

    %434, %435 are output scale and zp of quantized::add op
    %430, %431, %395, %396 are two pairs of input (scale, zp) for two tensors
    added by this function
    """
    # How many quantized tensors each op takes as inputs?
    # A pair of (scale, zp) for each input quantized tensor will be added
    # to the input nodes
    num_quantized_inputs = {"quantized::conv2d": 1,
                            "quantized::conv2d_relu": 1,
                            "quantized::linear": 1,
                            "quantized::linear_relu": 1,
                            "quantized::add_relu": 2,
                            "quantized::add": 2,
                            "quantized::mul_relu": 2,
                            "quantized::mul": 2,
                            "aten::dequantize": 1,
                            "aten::mean": 1,
                            "aten::upsample_bilinear2d": 1,
                            "aten::relu_": 1,
                            "aten::relu": 1,
                            "quantized::add_scalar": 1,
                            "quantized::mul_scalar": 1,
                            'quantized::relu6': 1}
    scalar_ops = ("quantized::add_scalar", "quantized::mul_scalar")

    for node in graph.nodes():
        operator = node.kind()

        if operator == "quantized::cat":
            # the number of inputs to concat is not constant, so take
            # them from the node that produces the first argument
            quantized_values = node.inputsAt(0).node().inputs()
        elif operator in num_quantized_inputs:
            count = num_quantized_inputs[operator]
            quantized_values = (node.inputsAt(i) for i in range(count))
        else:
            # this op does not need input quant params
            continue

        # one (scale, zp) pair per quantized input, in input order
        quant_param_pairs = [_get_quant_param_for_input(value)
                             for value in quantized_values]

        if operator in scalar_ops:
            first_scale, first_zp = quant_param_pairs[0]
            scalar = node.inputsAt(1).node().f("value")
            inp_scale = first_scale.node().f("value")
            inp_zero_point = first_zp.node().i("value")
            # the scalar ops get their output scale and zp appended
            # first, before the input quant params
            _add_output_quant_params_to_scalar_op(node, graph,
                                                  inp_scale, inp_zero_point,
                                                  scalar)

        for scale, zp in quant_param_pairs:
            node.addInput(scale)
            node.addInput(zp)
def add_quant_params(params, quant_params):
    """ Add quant parameters to TVM param map """
    for entry in quant_params.values():
        # weights are always added, keyed by the name of their relay var
        weight_name = entry.weight_var.name_hint
        params[weight_name] = tvm.nd.array(entry.weight)

        if entry.bias is None:
            continue
        bias_name = entry.bias_var.name_hint
        params[bias_name] = tvm.nd.array(entry.bias)
def quantized_adaptive_avg_2d(data, func_fp32):
    """
    Apply func_fp32 to data cast to int32 and cast the result to uint8.
    """
    # this follows tflite impl
    widened = _op.cast(data, dtype="int32")
    pooled = func_fp32(widened)
    return _op.cast(pooled, "uint8")
def quantized_mean(data, input_scale, input_zero_point, func_fp32):
    """
    Dequantize data, apply func_fp32, and quantize the result to uint8
    with the same scale and zero point as the input.
    """
    # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp
    fp32_in = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
    fp32_out = func_fp32(fp32_in)
    return relay.qnn.op.quantize(fp32_out, input_scale, input_zero_point,
                                 out_dtype="uint8", axis=1)
def quantized_upsample(data, input_scale, input_zero_point, func_fp32):
    """
    Dequantize data, apply func_fp32, and quantize the result to uint8
    with the same scale and zero point as the input.
    """
    # currently piggy backs to fp32, it gets identical output as torch
    fp32_in = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
    fp32_out = func_fp32(fp32_in)
    return relay.qnn.op.quantize(fp32_out, input_scale, input_zero_point,
                                 out_dtype="uint8", axis=1)
def quantized_relu(data, input_zero_point):
    """
    Elementwise maximum of data and the input zero point cast to uint8.
    """
    # refer to aten/src/ATen/native/quantized/cpu/qrelu.cpp
    zero_point_u8 = _op.cast(input_zero_point, dtype="uint8")
    return _op.tensor.maximum(data, zero_point_u8)
def _quantize_per_tensor():
    """ Make the converter registered for aten::quantize_per_tensor """
    def _impl(inputs, _):
        # inputs[0]: data, inputs[1]: scale, inputs[2]: zero point
        data = inputs[0]
        scale = _expr.const(inputs[1])
        zero_point = _expr.const(inputs[2])
        return relay.qnn.op.quantize(data, scale, zero_point,
                                     out_dtype="uint8", axis=1)

    return _impl
def _dequantize():
    """ Make the converter registered for aten::dequantize """
    def _impl(inputs, _):
        # inputs[1] and inputs[2] are the input scale and zero point that
        # must have been appended to the op inputs beforehand
        assert len(inputs) == 3, "Input quant params not found in op inputs"
        scale = _expr.const(inputs[1])
        zero_point = _expr.const(inputs[2])
        return relay.qnn.op.dequantize(inputs[0], scale, zero_point)

    return _impl
def _get_numpy(relay_const_scalar):
    """ Return the data held by a relay constant via its asnumpy() """
    payload = relay_const_scalar.data
    return payload.asnumpy()
def _get_scalar(relay_const_scalar):
    """ Return the value held by a relay constant as a Python scalar """
    # ndarray.item() is the documented replacement for np.asscalar, which
    # was deprecated in NumPy 1.16 and removed in NumPy 1.23
    return _get_numpy(relay_const_scalar).item()
def _do_bias_and_requantize(output, bias, input_scale, weight_scale,
                            output_scale, output_zero_point,
                            with_relu):
    """ Output processing for conv and linear """
    # product of input and weight scales
    # this is a vector for per channel case
    accum_scale = _expr.const(_get_numpy(input_scale) *
                              _get_numpy(weight_scale))

    # Torch does bias add and requanize scale in fp32
    # refer to third_party/fbgemm/include/fbgemm/OutputProcessing-inl.h
    # Instead, we do bias add in int32 and use qnn requantize, which needs
    # integer input.
    # We observed no loss in accuracy in doing this way, and it is better
    # for tvm because bias quantization can be done at compile time
    # Instead, the torch way requires rounding of activation at runtime
    requantize_input = output
    if bias is not None:
        qbias = relay.qnn.op.quantize(bias, accum_scale,
                                      _expr.const(0, "int32"),
                                      out_dtype="int32", axis=0)
        requantize_input = _op.nn.bias_add(output, qbias)

    requantized = relay.qnn.op.requantize(requantize_input,
                                          accum_scale,
                                          relay.const(0, 'int32'),
                                          output_scale, output_zero_point,
                                          out_dtype="int32", axis=1)

    # with relu the lower clip bound is the output zero point instead of 0
    clip_min = _get_scalar(output_zero_point) if with_relu else 0
    clipped = _op.tensor.clip(requantized, clip_min, 255.)
    return _op.cast(clipped, dtype="uint8")
def _quantized_conv2d(with_relu=False):
    """
    Make the converter for quantized conv2d; with_relu is forwarded to
    the bias add / requantize step.
    """
    def _impl(inputs, _):
        # refer to src/ATen/native/quantized/cpu/qconv.cpp
        # inputs[0]: input tensor
        # inputs[1]: (weight, scale, zero_point, bias)
        # inputs[2-5]: stride, padding, dilation, groups
        # inputs[6]: output_scale
        # inputs[7]: output_zero_point
        # inputs[8]: input_scale (added manually by frontend)
        # inputs[9]: input_zero_point (added manually by frontend)
        packed = inputs[1]
        weight = packed[0]
        weight_scale = packed[1]
        weight_zero_point = packed[2]

        output_scale = _expr.const(inputs[6])
        output_zero_point = _expr.const(inputs[7])

        assert len(inputs) == 10, "Input quant params not found in op inputs"
        # These are manually added by add_input_quant_params_to_op_inputs above
        # In torch, they are retrieved from QTensor data structure at runtime
        input_scale = _expr.const(inputs[8])
        input_zero_point = _expr.const(inputs[9])

        strides = infer_shape(inputs[2])
        padding = infer_shape(inputs[3])
        dilation = infer_shape(inputs[4])
        groups = inputs[5]

        weight_shape = infer_shape(weight)
        out_channels = weight_shape[0]
        kernel_size = (weight_shape[2], weight_shape[3])

        pad_h, pad_w = padding[0], padding[1]
        if pad_h == 0 and pad_w == 0:
            inp = inputs[0]
        else:
            # pad explicitly, using the input zero point as the pad value
            pad_val = _get_scalar(input_zero_point)
            inp = _op.nn.pad(inputs[0], pad_width=((0, 0),
                                                   (0, 0),
                                                   (pad_h, pad_h),
                                                   (pad_w, pad_w)),
                             pad_value=float(pad_val))

        # padding is (0, 0) because we did explicit pad op with
        # pad value being zero point above
        conv_out = relay.qnn.op.conv2d(inp, weight,
                                       input_zero_point, weight_zero_point,
                                       input_scale, weight_scale,
                                       kernel_size=kernel_size,
                                       dilation=dilation, strides=strides,
                                       padding=(0, 0), groups=groups,
                                       channels=out_channels)

        return _do_bias_and_requantize(conv_out, packed[3], input_scale,
                                       weight_scale, output_scale,
                                       output_zero_point, with_relu)

    return _impl
def _linear(with_relu=False):
    """
    Make the converter for quantized linear; with_relu is forwarded to
    the bias add / requantize step.
    """
    # similar to conv
    def _impl(inputs, _):
        # inputs[1]: (weight, scale, zero_point, bias)
        packed = inputs[1]
        weight = packed[0]
        weight_scale = packed[1]
        weight_zero_point = packed[2]

        output_scale = _expr.const(inputs[2])
        output_zero_point = _expr.const(inputs[3])

        assert len(inputs) == 6, "Input quant params not found in op inputs"
        # Manually added by add_input_quant_params_to_op_inputs above
        input_scale = _expr.const(inputs[4])
        input_zero_point = _expr.const(inputs[5])

        # the first dim of the weight is used as the number of units
        units = infer_shape(weight)[0]
        dense = relay.qnn.op.dense(inputs[0], weight,
                                   input_zero_point, weight_zero_point,
                                   input_scale, weight_scale,
                                   units=units)

        return _do_bias_and_requantize(dense, packed[3], input_scale,
                                       weight_scale, output_scale,
                                       output_zero_point, with_relu)

    return _impl
def _binop(relay_op, with_relu=False):
    """
    Make the converter for a quantized binary op. relay_op is applied to
    the two fp32 operands; with_relu adds a relu before quantizing back.
    """
    # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
    # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize
    def _to_fp32(operand, scale, zero_point):
        # if the operand is itself the result of qnn.quantize, reuse the
        # expression that was quantized instead of dequantizing again
        if isinstance(operand, _expr.Call) and operand.op.name == 'qnn.quantize':
            return operand.args[0]
        return relay.qnn.op.dequantize(operand, scale, zero_point)

    def _impl(inputs, _):
        output_scale = _expr.const(inputs[2])
        output_zero_point = _expr.const(inputs[3])

        assert len(inputs) == 8, "Input quant params not found in op inputs"
        # Manually added by add_input_quant_params_to_op_inputs above
        input_scale_lhs = _expr.const(inputs[4])
        input_zero_point_lhs = _expr.const(inputs[5])
        input_scale_rhs = _expr.const(inputs[6])
        input_zero_point_rhs = _expr.const(inputs[7])

        lhs = _to_fp32(inputs[0], input_scale_lhs, input_zero_point_lhs)
        rhs = _to_fp32(inputs[1], input_scale_rhs, input_zero_point_rhs)

        fp32_out = relay_op(lhs, rhs)
        if with_relu:
            fp32_out = _op.nn.relu(fp32_out)

        return relay.qnn.op.quantize(fp32_out,
                                     output_scale,
                                     output_zero_point,
                                     axis=-1,
                                     out_dtype="uint8")

    return _impl
def _cat():
    """ Make the converter registered for quantized::cat """
    # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
    # for concat they also piggy backs to fp32(!)
    # dequantize -> fp32 math -> quantize
    # we can also use QNN concat op. we observed no change in accuracy
    def _impl(inputs, _):
        tensors = inputs[0]
        axis = inputs[1]
        output_scale = _expr.const(inputs[2])
        output_zero_point = _expr.const(inputs[3])

        # everything after the first four inputs is a sequence of
        # (scale, zp) pairs, one pair per tensor to concatenate
        num_inputs = (len(inputs) - 4) // 2
        dequantized = []
        for idx, offset in enumerate(range(4, 4 + 2 * num_inputs, 2)):
            inp_scale = _expr.const(inputs[offset])
            inp_zp = _expr.const(inputs[offset + 1])
            dequantized.append(relay.qnn.op.dequantize(tensors[idx],
                                                       inp_scale, inp_zp))

        concat = _op.tensor.concatenate(dequantized, axis=axis)
        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
                                     axis=1, out_dtype="uint8")

    return _impl
def _add_scalar():
    """ Make the converter registered for quantized::add_scalar """
    # this is used for mobilenet v3
    def _impl(inputs, _):
        # refer to aten/src/ATen/native/quantized/cpu/qadd.cpp
        assert len(inputs) == 6, "Input quant params not found in op inputs"
        in_scale = inputs[4]
        in_zp = inputs[5]
        scalar = inputs[1]
        scalar_q = round(scalar / in_scale)

        # math for calculating output scale and zp are already done
        # during _add_output_quant_params_to_scalar_op above
        out_scale = _expr.const(inputs[2])
        out_zp = _expr.const(inputs[3])

        shifted_zp = in_zp - scalar_q
        if 0 <= shifted_zp <= 255:
            # the shifted zero point fits in [0, 255]: the input
            # expression is returned as is
            return inputs[0]

        # otherwise go through fp32: dequantize -> add -> quantize
        dequant = relay.qnn.op.dequantize(inputs[0],
                                          _expr.const(in_scale),
                                          _expr.const(in_zp))
        dequantized_add = _op.tensor.add(dequant,
                                         _expr.const(scalar_q * in_scale))
        return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp,
                                     axis=1, out_dtype="uint8")

    return _impl
def quantize_scalar(data, scale, zero_point):
    """
    Quantize a scalar: round(zero_point + data / scale), limited to the
    range [0, 255].
    """
    # used to quantize 6., in mobilenet v3
    quantized = round(zero_point + data / scale)
    if quantized < 0:
        return 0
    if quantized > 255:
        return 255
    return quantized
def _relu6():
    """ Make the converter registered for quantized::relu6 """
    # refer to src/ATen/native/quantized/cpu/qrelu.cpp
    def _impl(inputs, _):
        assert len(inputs) == 4, "Input quant params not found in op inputs"
        in_scale, in_zp = inputs[2], inputs[3]
        # upper bound: 6.0 expressed with the input scale and zero point
        upper = quantize_scalar(6., in_scale, in_zp)
        return _op.tensor.clip(inputs[0], in_zp, upper)

    return _impl
def _mul_scalar():
    """ Make the converter registered for quantized::mul_scalar """
    # this is used for mobilenet v3
    def _impl(inputs, _):
        # refer to aten/src/ATen/native/quantized/cpu/qmul.cpp
        # math for calculating output scale and zp are already done
        # during _add_output_quant_params_to_scalar_op above
        assert len(inputs) == 6, "Input quant params not found in op inputs"
        scalar = inputs[1]

        if scalar > 0.0:
            # only scale change
            return inputs[0]

        if scalar == 0.0:
            # result is a uint8 tensor of zeros with the input's shape
            return _op.full(_expr.const(0), infer_shape(inputs[0]),
                            dtype="uint8")

        # negative scale case
        q_min, q_max = 0, 255
        bias = _expr.const(q_max + q_min, dtype="int8")
        flipped = bias - _op.cast(inputs[0], "int8")
        return _op.cast(flipped, "uint8")

    return _impl
# Dispatch table from Torch operator names (quantized::* plus
# aten::quantize_per_tensor and aten::dequantize) to the converter
# closures built by the factory functions defined in this module.
convert_map = {
'aten::quantize_per_tensor': _quantize_per_tensor(),
'quantized::conv2d_relu': _quantized_conv2d(True),
'aten::dequantize': _dequantize(),
'quantized::conv2d': _quantized_conv2d(),
'quantized::add_relu': _binop(relay.add, True),
'quantized::add': _binop(relay.add),
'quantized::mul_relu': _binop(relay.multiply, True),
'quantized::mul': _binop(relay.multiply),
'quantized::linear': _linear(),
'quantized::linear_relu': _linear(True),
'quantized::cat': _cat(),
'quantized::add_scalar': _add_scalar(),
'quantized::mul_scalar': _mul_scalar(),
'quantized::relu6': _relu6()
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Tests on quantized torch model conversion """
import os
from PIL import Image
import numpy as np
import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub
from torch.quantization import fuse_modules, QuantWrapper
import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import get_graph_input_names
from tvm.contrib.download import download_testdata
def torch_version_check():
    """Return True when the installed torch version is newer than 1.4.0."""
    from packaging import version
    installed = version.parse(torch.__version__)
    baseline = version.parse("1.4.0")
    return installed > baseline
def get_tvm_runtime(script_module, input_name, ishape):
    """
    Convert script_module with the relay PyTorch frontend, build it for
    the llvm target, and return a graph runtime on cpu(0) with the
    parameters already set.
    """
    mod, params = relay.frontend.from_pytorch(script_module,
                                              {input_name: ishape})

    with relay.build_config(opt_level=3):
        # test on only cpu for now, torch cannot run quant models on cuda
        # also not to make CI too slow
        graph_json, lib, params = relay.build(mod, target="llvm",
                                              params=params)

    runtime = tvm.contrib.graph_runtime.create(graph_json, lib, tvm.cpu(0))
    runtime.set_input(**params)
    return runtime
def get_qconfig(per_channel):
    """
    Return the default 'fbgemm' qconfig when per_channel is true;
    otherwise a QConfig built from MovingAverageMinMaxObserver
    (reduce_range=False) for activations and default_weight_observer
    for weights.
    """
    from torch.quantization.observer import MovingAverageMinMaxObserver
    from torch.quantization.observer import default_weight_observer

    if not per_channel:
        activation = MovingAverageMinMaxObserver.with_args(reduce_range=False)
        return torch.quantization.QConfig(activation=activation,
                                          weight=default_weight_observer)

    return torch.quantization.get_default_qconfig('fbgemm')
def quantize_model(model, inp, per_channel=False, dummy=True):
    """
    Quantize model in place: fuse, attach a qconfig, prepare, run inp
    through the model once, then convert.

    The dummy argument is accepted but not used.
    """
    model.fuse_model()

    qconfig = get_qconfig(per_channel)
    model.qconfig = qconfig

    torch.quantization.prepare(model, inplace=True)
    # one forward pass between prepare and convert
    model(inp)
    torch.quantization.convert(model, inplace=True)
class ConvBn(nn.Module):
    """
    Conv2d(3, 32, 3) followed by BatchNorm2d(32) and, when with_relu is
    set, a ReLU; the stack is run through a QuantWrapper.
    """

    def __init__(self, with_relu=False):
        super().__init__()
        self.with_relu = with_relu

        stack = [nn.Conv2d(3, 32, 3, bias=True),
                 nn.BatchNorm2d(32)]
        if with_relu:
            stack += [nn.ReLU()]

        self.conv = nn.Sequential(*stack)
        self.quant_wrap = QuantWrapper(self.conv)

    def forward(self, x):
        return self.quant_wrap(x)

    def fuse_model(self):
        # names of the Sequential children to fuse together
        names = ["0", "1", "2"] if self.with_relu else ["0", "1"]
        fuse_modules(self.conv, names, inplace=True)
class Linear(nn.Module):
    """nn.Linear(16, 32), optionally followed by ReLU.

    The layers live in ``self.fc`` (an ``nn.Sequential``) and are invoked
    through ``self.quant_wrap``, a ``QuantWrapper`` around that sequential.
    """

    def __init__(self, with_relu=False):
        super().__init__()
        stages = [nn.Linear(16, 32)]
        if with_relu:
            stages += [nn.ReLU()]
        self.fc = nn.Sequential(*stages)
        self.quant_wrap = QuantWrapper(self.fc)
        self.with_relu = with_relu

    def forward(self, x):
        """Run ``x`` through the wrapped sequential."""
        wrapped = self.quant_wrap
        return wrapped(x)

    def fuse_model(self):
        """Fuse linear + relu in place; does nothing when there is no ReLU."""
        if not self.with_relu:
            return
        fuse_modules(self.fc, ["0", "1"], inplace=True)
class ReLU(nn.Module):
    """A single ``nn.ReLU`` held inside a ``QuantWrapper``."""

    def __init__(self):
        super().__init__()
        self.relu = QuantWrapper(nn.ReLU())

    def forward(self, x):
        """Apply the wrapped ReLU to ``x``."""
        wrapped_relu = self.relu
        return wrapped_relu(x)

    def fuse_model(self):
        """No-op: this module has nothing to fuse."""
        return None
# Mobilenet V3 related modules
class Hsigmoid(nn.Module):
    """Computes ``relu6(x + 3) * (1 / 6)`` using ``FloatFunctional`` scalar ops.

    When ``add_stub`` is set, the input goes through ``QuantStub`` first and
    the result goes through ``DeQuantStub`` last.
    """

    def __init__(self, inplace=True, add_stub=False):
        super().__init__()
        self.float_op = nn.quantized.FloatFunctional()
        self.relu6 = nn.ReLU6(inplace=inplace)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.add_stub = add_stub

    def forward(self, x):
        """Return the hard-sigmoid of ``x``."""
        inp = self.quant(x) if self.add_stub else x
        shifted = self.float_op.add_scalar(inp, 3.)
        clipped = self.relu6(shifted)
        scaled = self.float_op.mul_scalar(clipped, 1/6.)
        return self.dequant(scaled) if self.add_stub else scaled

    def fuse_model(self):
        """No-op: this module has nothing to fuse."""
        return None
class Hswish(nn.Module):
    """Computes ``x * Hsigmoid(x)`` using ``FloatFunctional.mul``.

    The inner ``Hsigmoid`` is always built with ``add_stub=False``. When this
    module's own ``add_stub`` is set, the input goes through ``QuantStub``
    first and the product goes through ``DeQuantStub`` last.
    """

    def __init__(self, inplace=True, add_stub=False):
        super().__init__()
        self.float_op = nn.quantized.FloatFunctional()
        self.hsigmoid = Hsigmoid(inplace, add_stub=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.add_stub = add_stub

    def forward(self, x):
        """Return the hard-swish of ``x``."""
        inp = self.quant(x) if self.add_stub else x
        gate = self.hsigmoid(inp)
        product = self.float_op.mul(inp, gate)
        return self.dequant(product) if self.add_stub else product

    def fuse_model(self):
        """No-op: this module has nothing to fuse."""
        return None
class SqueezeExcite(nn.Module):
    """Squeeze-and-excite block built from quantization-friendly pieces.

    The input is pooled to 1x1 per channel, passed through
    Linear -> ReLU -> Linear -> Hsigmoid, and the resulting per-channel gate
    is multiplied back onto the input with ``FloatFunctional.mul``.

    Parameters
    ----------
    channel
        Number of input (and output) channels.
    reduction
        The hidden width of the two linear layers is ``channel // reduction``.
    add_stub
        When set, wrap the computation in ``QuantStub`` / ``DeQuantStub``.
    """

    def __init__(self, channel, reduction=4, add_stub=False):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            Hsigmoid(add_stub=False),
        )
        self.fmul = nn.quantized.FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.add_stub = add_stub

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        # unpacking four dims means the input must be 4-D
        batch, channels, _height, _width = x.size()
        feat = self.quant(x) if self.add_stub else x
        pooled = self.avg_pool(feat).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        out = self.fmul.mul(feat, gate.expand_as(feat))
        return self.dequant(out) if self.add_stub else out

    def fuse_model(self):
        """Fuse the first Linear with its ReLU inside ``self.fc``, in place."""
        fuse_modules(self.fc, ["0", "1"], inplace=True)
# test on quantized::mul_scalar with negative scale
class MulScalarNegative(nn.Module):
    """Multiplies its input by the negative scalar -0.3 via ``FloatFunctional.mul_scalar``.

    The input always goes through ``QuantStub`` and the result through
    ``DeQuantStub``.
    """

    def __init__(self):
        super().__init__()
        self.float_op = nn.quantized.FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Return ``x * -0.3``."""
        quantized = self.quant(x)
        scaled = self.float_op.mul_scalar(quantized, -0.3)
        return self.dequant(scaled)

    def fuse_model(self):
        """No-op: this module has nothing to fuse."""
        return None
class UpsamplingBilinear(nn.Module):
    """2x bilinear upsampling (``align_corners=True``) between quant/dequant stubs."""

    def __init__(self):
        super().__init__()
        # registered as a submodule but not called from forward()
        self.relu = QuantWrapper(nn.ReLU())
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Return ``x`` upsampled by a factor of two."""
        quantized = self.quant(x)
        upsampled = nn.functional.interpolate(
            quantized,
            scale_factor=2,
            mode='bilinear',
            align_corners=True,
        )
        return self.dequant(upsampled)

    def fuse_model(self):
        """No-op: this module has nothing to fuse."""
        return None
def test_quantized_modules():
    """Quantize small modules, run them through torch and TVM, and print diff statistics.

    Each module is quantized with ``quantize_model``, traced with
    ``torch.jit.trace``, converted and built by ``get_tvm_runtime``, and both
    results are compared. The comparison is only printed (max / mean absolute
    difference and the fraction of bit-identical outputs); nothing is asserted
    about how close the two outputs are.
    """
    imagenet_ishape = (1, 3, 224, 224)

    # entries are (name, input shape, module, per_channel)
    qmodules = [
        ("relu", imagenet_ishape, ReLU(), False),
        ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
    ]

    for per_channel in [False, True]:
        postfix = ", per_channel" if per_channel else ""
        qmodules.extend([
            ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel),
            ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
            ("linear" + postfix, (16, 16), Linear(), per_channel),
            ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
        ])

    if not torch_version_check():
        print("Skipping tests that require torch > 1.4")
    else:
        qmodules.extend([
            ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
            ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
            ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
            ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
            ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
        ])

    for module_name, ishape, raw_module, per_channel in qmodules:
        raw_module.eval()
        inp = torch.rand(ishape)

        quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
        script_module = torch.jit.trace(raw_module, inp).eval()

        # torch reference output, computed on a copy of the input
        with torch.no_grad():
            pt_result = script_module(inp.clone()).numpy()

        # TVM output for the same input
        input_name = get_graph_input_names(script_module)[0]
        runtime = get_tvm_runtime(script_module, input_name, ishape)
        runtime.set_input(input_name, inp.numpy().copy())
        runtime.run()
        tvm_result = runtime.get_output(0).asnumpy()

        abs_diff = np.abs(tvm_result - pt_result)
        max_abs_diff = np.max(abs_diff)
        mean_abs_diff = np.mean(abs_diff)
        num_identical = np.sum(tvm_result == pt_result)
        match_ratio = num_identical / float(np.prod(tvm_result.shape))

        print(module_name, max_abs_diff, mean_abs_diff, match_ratio)

        # sample outputs
        """
relu 0.0039215684 2.6052087e-08 0.9999933567176871
upsample bilinear 0.0 0.0 1.0
conv_bn 0.22062653 0.011478779 0.6909348115006899
conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
linear 0.15987062 0.009231662 0.794921875
linear_relu 0.14180502 0.0053220326 0.8828125
conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
linear, per_channel 0.0 0.0 1.0
linear_relu, per_channel 0.0 0.0 1.0
hsigmoid 0.002614379 0.00020525524 0.9214896896258503
hswish 0.0052286386 0.00063522335 0.7587359162414966
semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875
mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871
"""

        # we cannot make any guarantee on how close the raw output is to torch
        # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)
def test_quantized_imagenet():
    """Check that TVM and PyTorch agree on the top-3 labels of quantized imagenet models.

    For resnet18, mobilenet_v2, inception_v3 and googlenet from
    ``torchvision.models.quantization`` (each instantiated with
    ``pretrained=True``, once per quantization scheme), the model is
    quantized with ``quantize_model`` on one real image, traced, and run
    through both PyTorch and the TVM graph runtime. Difference statistics
    are printed and the sets of top-3 labels are asserted to be equal.

    The test image is obtained through ``download_testdata``.

    Raises
    ------
    AssertionError
        If the top-3 label sets of PyTorch and TVM differ for any model.
    """
    def get_transform():
        """Build the resize / crop / to-tensor / normalize preprocessing pipeline."""
        import torchvision.transforms as transforms
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

    def get_real_image(im_height, im_width):
        """Fetch the elephant test image and resize it to the requested size."""
        repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
        img_name = 'elephant-299.jpg'
        image_url = os.path.join(repo_base, img_name)
        img_path = download_testdata(image_url, img_name, module='data')
        # PIL's resize expects (width, height); the original passed the two
        # values in the opposite order, which only worked for square sizes.
        return Image.open(img_path).resize((im_width, im_height))

    def get_imagenet_input():
        """Return the preprocessed image as a numpy array with a leading batch axis."""
        im = get_real_image(224, 224)
        preprocess = get_transform()
        pt_tensor = preprocess(im)
        return np.expand_dims(pt_tensor.numpy(), 0)

    from torchvision.models.quantization import resnet as qresnet
    from torchvision.models.quantization import mobilenet as qmobilenet
    from torchvision.models.quantization import inception as qinception
    from torchvision.models.quantization import googlenet as qgooglenet

    # entries are (name, model, per_channel); fresh model instances are built
    # for each scheme because quantize_model modifies the model in place
    qmodels = []

    for per_channel in [False, True]:
        qmodels += [
            ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
            ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
            ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
            ("googlenet", qgooglenet(pretrained=True), per_channel),
        ]

    results = []

    for (model_name, raw_model, per_channel) in qmodels:
        raw_model.eval()

        if per_channel:
            model_name += ", per channel quantization"
        else:
            model_name += ", per tensor quantization"

        inp = get_imagenet_input()
        pt_inp = torch.from_numpy(inp)

        quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
        script_module = torch.jit.trace(raw_model, pt_inp).eval()

        with torch.no_grad():
            pt_result = script_module(pt_inp).numpy()

        input_name = get_graph_input_names(script_module)[0]
        runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
        runtime.set_input(input_name, inp)
        runtime.run()

        tvm_result = runtime.get_output(0).asnumpy()

        # keep only the first (and only) batch entry of each output
        results.append((model_name, pt_result[0], tvm_result[0]))

    for (model_name, pt_result, tvm_result) in results:
        abs_diff = np.abs(tvm_result - pt_result)
        max_abs_diff = np.max(abs_diff)
        mean_abs_diff = np.mean(abs_diff)
        num_identical = np.sum(tvm_result == pt_result)
        pt_top3_labels = np.argsort(pt_result)[::-1][:3]
        # Rank the TVM output itself. Deriving these labels from pt_result
        # (as the original did) made the assertion below compare PyTorch
        # against PyTorch, so it could never fail.
        tvm_top3_labels = np.argsort(tvm_result)[::-1][:3]

        print("\nModel name: %s" % model_name)
        print("PyTorch top3 label:", pt_top3_labels)
        print("TVM top3 label:", tvm_top3_labels)
        print("max abs diff:", max_abs_diff)
        print("mean abs_diff:", mean_abs_diff)
        print("%d in 1000 raw outputs identical." % num_identical)

        assert set(pt_top3_labels) == set(tvm_top3_labels)

        # sample outputs
        # NOTE(review): the "TVM top3 label" lines below were presumably
        # recorded while both label lists were derived from pt_result, so
        # their agreement is not evidence about TVM -- re-record after a run.
        """
Model name: resnet18, per tensor quantization
PyTorch top3 label: [386 101 385]
TVM top3 label: [386 101 385]
max abs diff: 0.65681696
mean abs_diff: 0.14055882
236 in 1000 raw outputs identical.
Model name: mobilenet_v2, per tensor quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 2.1262953
mean abs_diff: 0.41025686
101 in 1000 raw outputs identical.
Model name: inception_v3, per tensor quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.9994669
mean abs_diff: 0.098697364
272 in 1000 raw outputs identical.
Model name: googlenet, per tensor quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.28248847
mean abs_diff: 0.0634469
274 in 1000 raw outputs identical.
Model name: resnet18, per channel quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.65908074
mean abs_diff: 0.1274223
469 in 1000 raw outputs identical.
Model name: mobilenet_v2, per channel quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.71120834
mean abs_diff: 0.15883648
423 in 1000 raw outputs identical.
Model name: inception_v3, per channel quantization
PyTorch top3 label: [386 101 385]
TVM top3 label: [386 101 385]
max abs diff: 1.3372154
mean abs_diff: 0.1225224
401 in 1000 raw outputs identical.
Model name: googlenet, per channel quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.34015465
mean abs_diff: 0.054197952
558 in 1000 raw outputs identical.
"""
...@@ -854,3 +854,9 @@ if __name__ == "__main__": ...@@ -854,3 +854,9 @@ if __name__ == "__main__":
test_custom_conversion_map() test_custom_conversion_map()
test_segmentaton_models() test_segmentaton_models()
# Quantization test
from qnn_test import test_quantized_imagenet, test_quantized_modules
test_quantized_modules()
test_quantized_imagenet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment