Unverified commit 4fbc2fbe by masahi, committed by GitHub

[Torch, QNN] Remove FP32 piggy back and use QNN add/mul/concatenate (#5061)

* use qnn add/mul/concatenate

* remove logging
parent d7a74838
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, import-outside-toplevel
 """ Functions to convert quantized torch models to QNN """
+import logging
 import numpy as np
@@ -536,21 +537,23 @@ def _linear(with_relu=False):
     return _impl


-def _binop(relay_op, with_relu=False):
+def _binop(relay_op, with_relu=False, fp32_piggy_back=False):
+    def qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                 input_scale_rhs, input_zero_point_rhs,
+                 output_scale, output_zero_point):
+        qnn_out = relay_op(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                           input_scale_rhs, input_zero_point_rhs,
+                           output_scale, output_zero_point)
+        if with_relu:
+            clip_min = _get_scalar(output_zero_point)
+            return _op.tensor.clip(qnn_out, clip_min, 255)
+        return qnn_out
+
     # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
     # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize
-    def _impl(inputs, _):
-        output_scale = _expr.const(inputs[2])
-        output_zero_point = _expr.const(inputs[3])
-        assert len(inputs) == 8, "Input quant params not found in op inputs"
-        # Manually added by add_input_quant_params_to_op_inputs above
-        input_scale_lhs = _expr.const(inputs[4])
-        input_zero_point_lhs = _expr.const(inputs[5])
-        input_scale_rhs = _expr.const(inputs[6])
-        input_zero_point_rhs = _expr.const(inputs[7])
-        lhs = inputs[0]
-        rhs = inputs[1]
+    def torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                   input_scale_rhs, input_zero_point_rhs,
+                   output_scale, output_zero_point):
         if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize':
             lhs = lhs.args[0]
         else:
@@ -574,30 +577,68 @@ def _binop(relay_op, with_relu=False):
                                      output_zero_point,
                                      axis=-1,
                                      out_dtype="uint8")

+    def _impl(inputs, _):
+        lhs = inputs[0]
+        rhs = inputs[1]
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+        assert len(inputs) == 8, "Input quant params not found in op inputs"
+        # Manually added by add_input_quant_params_to_op_inputs above
+        input_scale_lhs = _expr.const(inputs[4])
+        input_zero_point_lhs = _expr.const(inputs[5])
+        input_scale_rhs = _expr.const(inputs[6])
+        input_zero_point_rhs = _expr.const(inputs[7])
+
+        if fp32_piggy_back:
+            logging.info("Piggy backing to FP32 op (PyTorch way)")
+            return torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                              input_scale_rhs, input_zero_point_rhs,
+                              output_scale, output_zero_point)
+
+        return qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                        input_scale_rhs, input_zero_point_rhs,
+                        output_scale, output_zero_point)

     return _impl


-def _cat():
+def _cat(fp32_piggy_back=False):
     # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
     # for concat they also piggy backs to fp32(!)
     # dequantize -> fp32 math -> quantize
+    # we can also use QNN concat op. we observed no change in accuracy
+    def torch_impl(inputs, input_scales, input_zero_points,
+                   output_scale, output_zero_point, axis):
+        dequantized = []
+        for inp, inp_scale, inp_zp in zip(inputs, input_scales,
+                                          input_zero_points):
+            dequantized.append(relay.qnn.op.dequantize(inp, inp_scale, inp_zp))
+
+        concat = _op.tensor.concatenate(dequantized, axis=axis)
+        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
+                                     axis=axis, out_dtype="uint8")
+
     def _impl(inputs, _):
         axis = inputs[1]
         output_scale = _expr.const(inputs[2])
         output_zero_point = _expr.const(inputs[3])
         num_inputs = (len(inputs) - 4) // 2
-        dequantized = []
+        input_scales = []
+        input_zero_points = []

         for i in range(0, num_inputs):
-            inp_scale = _expr.const(inputs[4+i*2])
-            inp_zp = _expr.const(inputs[4+i*2+1])
-            dequantized.append(relay.qnn.op.dequantize(inputs[0][i],
-                                                       inp_scale, inp_zp))
+            input_scales.append(_expr.const(inputs[4+i*2]))
+            input_zero_points.append(_expr.const(inputs[4+i*2+1]))

-        concat = _op.tensor.concatenate(dequantized, axis=axis)
-        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
-                                     axis=1, out_dtype="uint8")
+        if fp32_piggy_back:
+            return torch_impl(inputs[0], input_scales, input_zero_points,
+                              output_scale, output_zero_point, axis)
+
+        return relay.qnn.op.concatenate(inputs[0],
+                                        input_scales, input_zero_points,
+                                        output_scale, output_zero_point,
+                                        axis)

     return _impl
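For intuition about why the switch is safe: the QNN path and the fp32 piggy-back path in _binop agree up to integer rounding. Below is a minimal numpy sketch with made-up scales and zero points; the quantize/dequantize/requantize helpers are simplified stand-ins for the real QNN ops, which use fixed-point arithmetic internally.

import numpy as np

def quantize(x, scale, zp):
    # affine uint8 quantization, matching out_dtype="uint8" above
    return np.clip(np.round(x / scale) + zp, 0, 255).astype("uint8")

def dequantize(q, scale, zp):
    return scale * (q.astype("float32") - zp)

def requantize(q, s_in, z_in, s_out, z_out):
    # rebase a quantized tensor onto new quant params (simplified fp32 model)
    return np.round(s_in / s_out * (q.astype("int32") - z_in)) + z_out

# hypothetical quant params and inputs
s_lhs, z_lhs = 0.05, 10
s_rhs, z_rhs = 0.03, 7
s_out, z_out = 0.07, 5
lhs = quantize(np.array([0.2, 1.5, 3.0], "float32"), s_lhs, z_lhs)
rhs = quantize(np.array([0.1, 0.4, 2.2], "float32"), s_rhs, z_rhs)

# PyTorch-style piggy back: dequantize -> fp32 add -> quantize
fp32_path = quantize(dequantize(lhs, s_lhs, z_lhs) +
                     dequantize(rhs, s_rhs, z_rhs), s_out, z_out)

# QNN-style: requantize both inputs to the output params, add as integers
acc = (requantize(lhs, s_lhs, z_lhs, s_out, z_out) +
       requantize(rhs, s_rhs, z_rhs, s_out, z_out) - z_out)
qnn_path = np.clip(acc, 0, 255).astype("uint8")

# the two paths can differ by a unit of rounding per element, which is
# consistent with the commit observing no accuracy change
assert np.abs(fp32_path.astype("int32") - qnn_path.astype("int32")).max() <= 1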
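The same argument covers _cat: concatenation is order-preserving, so each input only needs to be rebased onto the output scale and zero point. A self-contained sketch under the same made-up quant params:

import numpy as np

def quantize(x, scale, zp):
    return np.clip(np.round(x / scale) + zp, 0, 255).astype("uint8")

def dequantize(q, scale, zp):
    return scale * (q.astype("float32") - zp)

# two hypothetical quantized inputs with different quant params
scales, zps = [0.04, 0.02], [12, 3]
a = quantize(np.array([0.4, 1.2], "float32"), scales[0], zps[0])
b = quantize(np.array([0.1, 0.9], "float32"), scales[1], zps[1])
s_out, z_out = 0.05, 8

# piggy-back path, as in torch_impl: dequantize -> concat -> quantize
fp32_path = quantize(
    np.concatenate([dequantize(a, scales[0], zps[0]),
                    dequantize(b, scales[1], zps[1])]), s_out, z_out)

# QNN-style path: requantize each input to the output params, then
# concatenate in the integer domain (a simplified model of
# relay.qnn.op.concatenate)
requant = [np.clip(np.round(s / s_out * (q.astype("int32") - z)) + z_out,
                   0, 255).astype("uint8")
           for q, s, z in zip([a, b], scales, zps)]
qnn_path = np.concatenate(requant)

# identical here; in general the paths agree up to a unit of rounding
assert np.abs(fp32_path.astype("int32") - qnn_path.astype("int32")).max() <= 1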
@@ -676,15 +717,15 @@ def _mul_scalar():
 convert_map = {
     'aten::quantize_per_tensor': _quantize_per_tensor(),
-    'quantized::conv2d_relu': _quantized_conv2d(True),
+    'quantized::conv2d_relu': _quantized_conv2d(with_relu=True),
     'aten::dequantize': _dequantize(),
     'quantized::conv2d': _quantized_conv2d(),
-    'quantized::add_relu': _binop(relay.add, True),
-    'quantized::add': _binop(relay.add),
-    'quantized::mul_relu': _binop(relay.multiply, True),
-    'quantized::mul': _binop(relay.multiply),
+    'quantized::add_relu': _binop(relay.qnn.op.add, with_relu=True),
+    'quantized::add': _binop(relay.qnn.op.add),
+    'quantized::mul_relu': _binop(relay.qnn.op.mul, with_relu=True),
+    'quantized::mul': _binop(relay.qnn.op.mul),
     'quantized::linear': _linear(),
-    'quantized::linear_relu': _linear(True),
+    'quantized::linear_relu': _linear(with_relu=True),
     'quantized::cat': _cat(),
     'quantized::add_scalar': _add_scalar(),
     'quantized::mul_scalar': _mul_scalar(),
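For reference, a sketch of how these table entries get exercised when importing a quantized model. The torchvision model choice and the input name "input" are assumptions, not part of this patch; a quantized ResNet's residual adds lower through the 'quantized::add_relu' entry above.

import torch
from torchvision.models.quantization import resnet18

from tvm import relay

# a pre-quantized torchvision model, traced to TorchScript
model = resnet18(pretrained=True, quantize=True).eval()
inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()

# "input" is an arbitrary name chosen for the graph input
mod, params = relay.frontend.from_pytorch(script_module,
                                          [("input", (1, 3, 224, 224))])

# The piggy-back path is still reachable: hypothetically, one could restore
# it for a single op via the custom_convert_map argument, e.g.
# custom_convert_map={'quantized::cat': _cat(fp32_piggy_back=True)}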