Unverified commit 93dff448 by masahi, committed by GitHub

[REDO AFTER GH BUG] Add support for quantized models via QNN (#5016)

This reverts commit f346c602.
parent f346c602
python/tvm/relay/frontend/pytorch.py
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
 """PT: PyTorch frontend."""
 import itertools
+import logging
 import numpy as np
@@ -32,6 +33,8 @@ from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
+from . import qnn_torch
+
 __all__ = ["from_pytorch"]
 
 # operator implementation
@@ -146,6 +149,10 @@ def _zeros():
 def _relu():
     def _impl(inputs, input_types):
         data = inputs[0]
+        if input_types[0] == "quint8":
+            assert len(inputs) == 3, "Input quant param not found in op inputs"
+            input_zero_point = _expr.const(inputs[2], dtype="int32")
+            return qnn_torch.quantized_relu(data, input_zero_point)
         return _op.nn.relu(data)
     return _impl
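Note on the quint8 branch above: a quantized ReLU never needs to dequantize, because relu(x) = max(x, 0) in real space maps to max(q, zero_point) in the quantized domain (the zero point is, by definition, the encoding of 0.0). A minimal sketch of such a lowering in terms of standard Relay ops; it mirrors the qnn_torch.quantized_relu call above but is not the verbatim implementation:

# Sketch only: quantized ReLU as a clamp at the input zero point,
# assuming uint8 activations (PyTorch quint8).
from tvm import relay

def quantized_relu_sketch(data, input_zero_point):
    # dequantize(zp) == 0.0 by definition, so max(q, zp) == quantize(relu(x))
    zp = relay.cast(input_zero_point, dtype="uint8")
    return relay.maximum(data, zp)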
@@ -154,9 +161,14 @@ def _adaptive_avg_2d():
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = _infer_shape(inputs[1])
 
-        return _op.nn.adaptive_avg_pool2d(
-            data,
-            output_size=output_size)
+        def func(x):
+            return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
+
+        if input_types[0] == "quint8":
+            return qnn_torch.quantized_adaptive_avg_2d(data, func)
+
+        return func(data)
     return _impl
 
 def _adaptive_max_2d():
@@ -506,7 +518,18 @@ def _mean():
         else:
             exclude = False
 
-        return _op.mean(data, axis, keepdims, exclude)
+        def func(x):
+            return _op.mean(x, axis, keepdims, exclude)
+
+        if input_types[0] == "quint8":
+            assert len(inputs) == 6, "Input quant param not found in op inputs"
+            input_scale = _expr.const(inputs[4])
+            input_zero_point = _expr.const(inputs[5])
+            return qnn_torch.quantized_mean(data, input_scale,
+                                            input_zero_point, func)
+
+        return func(data)
     return _impl
 
 def _chunk():
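Note on the func callbacks introduced in _adaptive_avg_2d and _mean above: wrapping the fp32 lowering in a closure lets the quantized path reuse it unchanged. A plausible shape for such a helper, sketched with the public QNN ops under the assumption that these ops keep the input's quant params (this is not the verbatim qnn_torch code):

# Sketch: dequantize -> run the stock fp32 Relay op -> requantize.
from tvm import relay

def run_fp32_then_requantize(data, input_scale, input_zero_point, func_fp32):
    dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
    out = func_fp32(dequantized)
    # reuse the input scale/zero point for the output, matching PyTorch
    # ops whose output quant params equal the input's
    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
                                 out_dtype="uint8")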
@@ -668,10 +691,40 @@ def _upsample(method):
         else:
             coord_trans = "half_pixel"
 
-        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
+        def func(x):
+            return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
+
+        if input_types[0] == "quint8":
+            import torch
+            from packaging import version
+
+            # Torch version > 1.4 changed upsampling API
+            if version.parse(torch.__version__) > version.parse("1.4.0"):
+                num_inputs = 7
+            else:
+                num_inputs = 5
+
+            assert len(inputs) == num_inputs, "Input quant param not found in op inputs"
+
+            input_scale = _expr.const(inputs[-2])
+            input_zero_point = _expr.const(inputs[-1])
+            return qnn_torch.quantized_upsample(data, input_scale,
+                                                input_zero_point, func)
+
+        return func(data)
     return _impl
 
 
+def _expand_as():
+    def _impl(inputs, input_types):
+        # TODO: maybe fix this
+        # This assumes expand_as can be removed because TVM has broadcast op
+        msg = "aten::expand_as(...) found, assume it is part of broadcast op"
+        logging.warning(msg)
+        return inputs[0]
+    return _impl
+
+
 # Helper functions for operator implementation
 def _convert_data_type(input_type):
@@ -792,6 +845,7 @@ _convert_map = {
     "aten::detach"              : _identity(),
     "aten::upsample_bilinear2d" : _upsample("bilinear"),
     "aten::upsample_nearest2d"  : _upsample("nearest_neighbor"),
+    "aten::expand_as"           : _expand_as()
 }
@@ -842,6 +896,7 @@ def _report_missing_conversion(op_names):
                  "prim::ListConstruct", "prim::ListUnpack",
                  "prim::TupleConstruct", "prim::TupleUnpack"]
     known_ops += list(_convert_map.keys())
+    known_ops += list(qnn_torch.convert_map.keys())
 
     missing = [op_name for op_name in op_names
                if op_name not in known_ops]
@@ -1008,6 +1063,7 @@ def parse_params(graph, state_dict):
     getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
     params = {}
     param_tensors = {}
+    packed_param_map = {}
     seen = set()
 
     for node in getattr_nodes:
@@ -1020,14 +1076,18 @@ def parse_params(graph, state_dict):
         full_attr = _getattr_full_name(getattrs)
         full_attr_node_name = _get_output_name(getattrs[-1])
 
-        if full_attr in state_dict:
+        if full_attr.endswith("_packed_params"):  # for quantized models
+            err_msg = "parameter %s not found in state dict" % full_attr
+            assert full_attr in state_dict, err_msg
+            packed_param_map[full_attr_node_name] = full_attr
+        elif full_attr in state_dict:
             torch_tensor = state_dict[full_attr]
             tensor, var = _get_tensor_and_var(torch_tensor,
                                               full_attr_node_name)
             param_tensors[full_attr_node_name] = tensor
             params[full_attr_node_name] = var
 
-    return params, param_tensors
+    return params, param_tensors, packed_param_map
 
 
 def parse_operators(operators, outputs, output_index_map, ret_name):
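Note on the _packed_params branch above: quantized PyTorch modules store their weights inside opaque packed objects (the output of ops like quantized::conv2d_prepack), not as plain tensors, so parse_params can only record the node-name-to-state-dict-key mapping here; the weights and their quant params are recovered later by qnn_torch.get_weight_quant_params (used in from_pytorch below). For reference, once an unpacked quantized weight tensor is in hand, PyTorch exposes its affine params directly; a hedged illustration:

# Illustration only: reading per-tensor quant params from a quantized
# torch.Tensor (per-channel weights use q_per_channel_scales() and
# q_per_channel_zero_points() instead).
import torch

def per_tensor_quant_params(qweight):
    assert qweight.qscheme() == torch.per_tensor_affine
    return float(qweight.q_scale()), int(qweight.q_zero_point())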
@@ -1108,16 +1168,26 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     params = script_module.state_dict()
     input_vars = parse_inputs(graph.inputs(), input_shapes)
-    param_vars, tensors = parse_params(graph, params)
+    param_vars, tensors, packed_param_map = parse_params(graph, params)
+    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
     ret_name = _get_input_names(graph.return_node())[0]
 
+    # For quantized models
+    if "aten::quantize_per_tensor" in op_names:
+        weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
+        qnn_torch.add_input_quant_params_to_op_inputs(graph)
+        qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
+                                              packed_param_map,
+                                              weight_quant_params)
+        qnn_torch.add_quant_params(tvm_params, weight_quant_params)
+        _convert_map.update(qnn_torch.convert_map)
+
     body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
                            output_index_map, ret_name)
     func = tvm.relay.Function(_analysis.free_vars(body), body)
-    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     return _module.IRModule.from_expr(func), tvm_params
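With these pieces in place, a pre-quantized model converts through the same one-call flow as an fp32 one. A hypothetical usage sketch; the input name "X" is a placeholder and must match the traced graph's actual input name:

# Hypothetical usage sketch: convert a quantized torchvision ResNet.
import torch
import torchvision
from tvm import relay

model = torchvision.models.quantization.resnet18(
    pretrained=True, quantize=True).eval()
inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()

# "X" is a placeholder; use the traced graph's real input name here
mod, params = relay.frontend.from_pytorch(script_module,
                                          {"X": (1, 3, 224, 224)})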
tests/python/frontend/pytorch/test_forward.py
@@ -854,3 +854,9 @@ if __name__ == "__main__":
     test_custom_conversion_map()
     test_segmentaton_models()
+
+    # Quantization test
+    from qnn_test import test_quantized_imagenet, test_quantized_modules
+
+    test_quantized_modules()
+    test_quantized_imagenet()