Unverified commit 4fbc2fbe by masahi, committed by GitHub

[Torch, QNN] Remove FP32 piggy back and use QNN add/mul/concatenate (#5061)

* use qnn add/mul/concatenate

* remove logging
parent d7a74838
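
The "FP32 piggy back" removed here is the pattern PyTorch's quantized CPU kernels (qadd, qmul, qconcat) use: dequantize the uint8 inputs, do the arithmetic in float32, then requantize with the output scale and zero point. As a reference for the semantics only (not code from this commit), here is a minimal NumPy sketch of that pattern for a quantized add, with made-up scales and zero points:

import numpy as np

def dequantize(q, scale, zero_point):
    # uint8 -> float32 using the tensor's affine quantization params
    return scale * (q.astype(np.float32) - zero_point)

def quantize(x, scale, zero_point):
    # float32 -> uint8, rounding and saturating to the uint8 range
    q = np.round(x / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

lhs = np.array([10, 130, 250], dtype=np.uint8)
rhs = np.array([20, 40, 60], dtype=np.uint8)

# dequantize -> fp32 add -> quantize (the "piggy back" path)
fp32_sum = dequantize(lhs, 0.05, 3) + dequantize(rhs, 0.1, 7)
print(quantize(fp32_sum, 0.2, 5))

The QNN ops used after this commit (qnn.add, qnn.mul, qnn.concatenate) express the same computation directly on the quantized operands, without the fp32 round trip.
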
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, import-outside-toplevel
 """ Functions to convert quantized torch models to QNN """
+import logging
 import numpy as np

@@ -536,21 +537,23 @@ def _linear(with_relu=False):
     return _impl


-def _binop(relay_op, with_relu=False):
+def _binop(relay_op, with_relu=False, fp32_piggy_back=False):
+    def qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                 input_scale_rhs, input_zero_point_rhs,
+                 output_scale, output_zero_point):
+        qnn_out = relay_op(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                           input_scale_rhs, input_zero_point_rhs,
+                           output_scale, output_zero_point)
+        if with_relu:
+            clip_min = _get_scalar(output_zero_point)
+            return _op.tensor.clip(qnn_out, clip_min, 255)
+        return qnn_out
+
     # refer to aten/src/ATen/native/quantized/cpu/{qadd, qmul}.cpp
     # they piggy backs to fp32 math by dequantize -> fp32 math -> quantize
-    def _impl(inputs, _):
-        output_scale = _expr.const(inputs[2])
-        output_zero_point = _expr.const(inputs[3])
-        assert len(inputs) == 8, "Input quant params not found in op inputs"
-        # Manually added by add_input_quant_params_to_op_inputs above
-        input_scale_lhs = _expr.const(inputs[4])
-        input_zero_point_lhs = _expr.const(inputs[5])
-        input_scale_rhs = _expr.const(inputs[6])
-        input_zero_point_rhs = _expr.const(inputs[7])
-        lhs = inputs[0]
-        rhs = inputs[1]
+    def torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                   input_scale_rhs, input_zero_point_rhs,
+                   output_scale, output_zero_point):
         if isinstance(lhs, _expr.Call) and lhs.op.name == 'qnn.quantize':
             lhs = lhs.args[0]
         else:

@@ -574,30 +577,68 @@ def _binop(relay_op, with_relu=False):
                                          output_zero_point,
                                          axis=-1,
                                          out_dtype="uint8")

+    def _impl(inputs, _):
+        lhs = inputs[0]
+        rhs = inputs[1]
+        output_scale = _expr.const(inputs[2])
+        output_zero_point = _expr.const(inputs[3])
+        assert len(inputs) == 8, "Input quant params not found in op inputs"
+        # Manually added by add_input_quant_params_to_op_inputs above
+        input_scale_lhs = _expr.const(inputs[4])
+        input_zero_point_lhs = _expr.const(inputs[5])
+        input_scale_rhs = _expr.const(inputs[6])
+        input_zero_point_rhs = _expr.const(inputs[7])
+
+        if fp32_piggy_back:
+            logging.info("Piggy backing to FP32 op (PyTorch way)")
+            return torch_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                              input_scale_rhs, input_zero_point_rhs,
+                              output_scale, output_zero_point)
+
+        return qnn_impl(lhs, rhs, input_scale_lhs, input_zero_point_lhs,
+                        input_scale_rhs, input_zero_point_rhs,
+                        output_scale, output_zero_point)
+
     return _impl
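
For reference, the QNN path above lowers quantized::add_relu to relay.qnn.op.add on the quantized operands, followed by a clip at the output zero point, as qnn_impl does. A hedged sketch of the expression this produces, with made-up shapes and quantization parameters, assuming a TVM build with the Relay QNN dialect:

from tvm import relay

lhs = relay.var("lhs", shape=(1, 8), dtype="uint8")
rhs = relay.var("rhs", shape=(1, 8), dtype="uint8")

# quantized::add lowered directly to qnn.add, no dequantize/quantize round trip
out = relay.qnn.op.add(lhs, rhs,
                       relay.const(0.05), relay.const(3),  # lhs scale, zero point
                       relay.const(0.1), relay.const(7),   # rhs scale, zero point
                       relay.const(0.2), relay.const(5))   # output scale, zero point

# fused ReLU: clamp everything below the output zero point, like qnn_impl
out = relay.clip(out, a_min=5, a_max=255)

print(relay.Function(relay.analysis.free_vars(out), out))

The clip to [output_zero_point, 255] is the uint8 counterpart of ReLU, since any quantized value at or below the zero point dequantizes to a non-positive number.
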
-def _cat():
+def _cat(fp32_piggy_back=False):
     # refer to aten/src/ATen/native/quantized/cpu/qconcat.cpp
     # for concat they also piggy backs to fp32(!)
     # dequantize -> fp32 math -> quantize
-    # we can also use QNN concat op. we observed no change in accuracy
+    def torch_impl(inputs, input_scales, input_zero_points,
+                   output_scale, output_zero_point, axis):
+        dequantized = []
+        for inp, inp_scale, inp_zp in zip(inputs, input_scales,
+                                          input_zero_points):
+            dequantized.append(relay.qnn.op.dequantize(inp, inp_scale, inp_zp))
+
+        concat = _op.tensor.concatenate(dequantized, axis=axis)
+        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
+                                     axis=axis, out_dtype="uint8")
+
     def _impl(inputs, _):
         axis = inputs[1]
         output_scale = _expr.const(inputs[2])
         output_zero_point = _expr.const(inputs[3])
         num_inputs = (len(inputs) - 4) // 2
-        dequantized = []
+        input_scales = []
+        input_zero_points = []

         for i in range(0, num_inputs):
-            inp_scale = _expr.const(inputs[4+i*2])
-            inp_zp = _expr.const(inputs[4+i*2+1])
-            dequantized.append(relay.qnn.op.dequantize(inputs[0][i],
-                                                       inp_scale, inp_zp))
+            input_scales.append(_expr.const(inputs[4+i*2]))
+            input_zero_points.append(_expr.const(inputs[4+i*2+1]))

-        concat = _op.tensor.concatenate(dequantized, axis=axis)
-        return relay.qnn.op.quantize(concat, output_scale, output_zero_point,
-                                     axis=1, out_dtype="uint8")
+        if fp32_piggy_back:
+            return torch_impl(inputs[0], input_scales, input_zero_points,
+                              output_scale, output_zero_point, axis)
+
+        return relay.qnn.op.concatenate(inputs[0],
+                                        input_scales, input_zero_points,
+                                        output_scale, output_zero_point,
+                                        axis)

     return _impl
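
Similarly, the new path for quantized::cat builds a single relay.qnn.op.concatenate, passing per-input scales and zero points together with the output quantization parameters, instead of dequantizing each input. A hedged usage sketch with illustrative parameters, following the argument order of the call in _impl above:

from tvm import relay

a = relay.var("a", shape=(1, 4), dtype="uint8")
b = relay.var("b", shape=(1, 4), dtype="uint8")

out = relay.qnn.op.concatenate(
    [a, b],                                  # quantized input tensors
    [relay.const(0.05), relay.const(0.1)],   # per-input scales
    [relay.const(3), relay.const(7)],        # per-input zero points
    relay.const(0.1),                        # output scale
    relay.const(7),                          # output zero point
    axis=1)

print(relay.Function(relay.analysis.free_vars(out), out))
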
@@ -676,15 +717,15 @@ def _mul_scalar():
 convert_map = {
     'aten::quantize_per_tensor': _quantize_per_tensor(),
-    'quantized::conv2d_relu': _quantized_conv2d(True),
+    'quantized::conv2d_relu': _quantized_conv2d(with_relu=True),
     'aten::dequantize': _dequantize(),
     'quantized::conv2d': _quantized_conv2d(),
-    'quantized::add_relu': _binop(relay.add, True),
-    'quantized::add': _binop(relay.add),
-    'quantized::mul_relu': _binop(relay.multiply, True),
-    'quantized::mul': _binop(relay.multiply),
+    'quantized::add_relu': _binop(relay.qnn.op.add, with_relu=True),
+    'quantized::add': _binop(relay.qnn.op.add),
+    'quantized::mul_relu': _binop(relay.qnn.op.mul, with_relu=True),
+    'quantized::mul': _binop(relay.qnn.op.mul),
     'quantized::linear': _linear(),
-    'quantized::linear_relu': _linear(True),
+    'quantized::linear_relu': _linear(with_relu=True),
     'quantized::cat': _cat(),
     'quantized::add_scalar': _add_scalar(),
     'quantized::mul_scalar': _mul_scalar(),
...