Unverified Commit fc7f0783 by Animesh Jain Committed by GitHub

[Torch, QNN] Add support for quantized models via QNN (#4977)

* qnn support initial import

* fix upsampling num input

* imagenet tests added

* add quantized module tests

* quantized module tests working

* imagenet test working

* fix lint

* remove top level torch import to fix ci error

* disable lint warning on outside toplevel import

* revert parse -> convert change

* add comments to qnn translation

* address comments, add sample outputs

* add more comments

* refactor bias add and requantize step
parent 585f9ce6
......@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
import logging
import numpy as np
......@@ -32,6 +33,8 @@ from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from . import qnn_torch
__all__ = ["from_pytorch"]
# operator implementation
......@@ -146,6 +149,10 @@ def _zeros():
def _relu():
    """Create the converter for ReLU.

    The returned ``_impl(inputs, input_types)`` emits a regular
    ``nn.relu`` for non-quantized inputs. When the first input type is
    ``"quint8"`` it instead dispatches to ``qnn_torch.quantized_relu``,
    passing ``inputs[2]`` as an int32 constant zero point.
    """
    def _impl(inputs, input_types):
        tensor = inputs[0]

        # Non-quantized path: plain Relay relu.
        if input_types[0] != "quint8":
            return _op.nn.relu(tensor)

        # Quantized path: exactly three inputs are required, with the zero
        # point at index 2 (index 1 is presumably the scale -- TODO confirm
        # against qnn_torch's input-augmentation step).
        assert len(inputs) == 3, "Input quant param not found in op inputs"
        zero_point = _expr.const(inputs[2], dtype="int32")
        return qnn_torch.quantized_relu(tensor, zero_point)

    return _impl
......@@ -154,9 +161,14 @@ def _adaptive_avg_2d():
data = inputs[0]
output_size = _infer_shape(inputs[1])
return _op.nn.adaptive_avg_pool2d(
data,
output_size=output_size)
def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
if input_types[0] == "quint8":
return qnn_torch.quantized_adaptive_avg_2d(data, func)
return func(data)
return _impl
def _adaptive_max_2d():
......@@ -503,7 +515,18 @@ def _mean():
else:
exclude = False
return _op.mean(data, axis, keepdims, exclude)
def func(x):
return _op.mean(x, axis, keepdims, exclude)
if input_types[0] == "quint8":
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
return qnn_torch.quantized_mean(data, input_scale,
input_zero_point, func)
return func(data)
return _impl
def _chunk():
......@@ -665,10 +688,40 @@ def _upsample(method):
else:
coord_trans = "half_pixel"
return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
if input_types[0] == "quint8":
import torch
from packaging import version
# Torch version > 1.4 changed upsampling API
if version.parse(torch.__version__) > version.parse("1.4.0"):
num_inputs = 7
else:
num_inputs = 5
assert len(inputs) == num_inputs, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
return qnn_torch.quantized_upsample(data, input_scale,
input_zero_point, func)
return func(data)
return _impl
def _expand_as():
    """Create the converter for ``aten::expand_as``.

    The returned ``_impl(inputs, input_types)`` logs a warning and returns
    its first input unchanged; no expand op is emitted.
    """
    def _impl(inputs, input_types):
        # TODO: maybe fix this
        # expand_as is dropped on the assumption that TVM's broadcasting
        # makes the explicit expansion unnecessary.
        logging.warning(
            "aten::expand_as(...) found, assume it is part of broadcast op")
        return inputs[0]

    return _impl
# Helper functions for operator implementation
def _convert_data_type(input_type):
......@@ -789,6 +842,7 @@ _convert_map = {
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
"aten::expand_as" : _expand_as()
}
......@@ -839,6 +893,7 @@ def _report_missing_conversion(op_names):
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack"]
known_ops += list(_convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())
missing = [op_name for op_name in op_names
if op_name not in known_ops]
......@@ -991,6 +1046,7 @@ def parse_params(graph, state_dict):
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {}
param_tensors = {}
packed_param_map = {}
seen = set()
for node in getattr_nodes:
......@@ -1003,14 +1059,18 @@ def parse_params(graph, state_dict):
full_attr = _getattr_full_name(getattrs)
full_attr_node_name = _get_output_name(getattrs[-1])
if full_attr in state_dict:
if full_attr.endswith("_packed_params"): # for quantized models
err_msg = "parameter %s not found in state dict" % full_attr
assert full_attr in state_dict, err_msg
packed_param_map[full_attr_node_name] = full_attr
elif full_attr in state_dict:
torch_tensor = state_dict[full_attr]
tensor, var = _get_tensor_and_var(torch_tensor,
full_attr_node_name)
param_tensors[full_attr_node_name] = tensor
params[full_attr_node_name] = var
return params, param_tensors
return params, param_tensors, packed_param_map
def parse_operators(operators, outputs, output_index_map, ret_name):
......@@ -1090,16 +1150,26 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
param_vars, tensors = parse_params(graph, params)
param_vars, tensors, packed_param_map = parse_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node())[0]
# For quantized models
if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)
body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
func = tvm.relay.Function(_analysis.free_vars(body), body)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
return _module.IRModule.from_expr(func), tvm_params
......@@ -849,3 +849,9 @@ if __name__ == "__main__":
test_custom_conversion_map()
test_segmentaton_models()
# Quantization test
from qnn_test import test_quantized_imagenet, test_quantized_modules
test_quantized_modules()
test_quantized_imagenet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment