Commit 1e4aea81 authored by Animesh Jain, committed by Yizhi Liu

[Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to read from out_types. (#3782)

parent 17f8f96b
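As a quick illustration of the new contract introduced by this commit, the sketch below (not part of the diff) shows a Relay-level legalize callback: `types` now holds the checked types of the call arguments followed by the inferred output type. The import path and the `level` value are assumptions for illustration only.

from tvm import relay
from tvm.relay.op import register_legalize  # import path assumed, not shown in this diff

@register_legalize("nn.conv2d", level=111)
def _legalize_conv2d_sketch(attrs, inputs, types):
    data, weight = inputs
    # Input types come first, the inferred output type is last.
    data_type, weight_type, out_type = types
    # out_type.dtype is now available to drive dtype-dependent rewrites.
    # Returning None keeps the original conv2d call unchanged.
    return None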
...@@ -206,10 +206,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

 @reg.register_legalize("nn.conv2d")
-def legalize_conv2d(attrs, inputs, arg_dtypes):
-    """Legalize conv2d"""
-    from ... import op
-    return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
+def legalize_conv2d(attrs, inputs, types):
+    """Legalize conv2d op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.conv2d_legalize(attrs, inputs, types)

 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
...
...@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
   Expr new_e;
   bool modified = false;
   if (fop_legalize.count(op)) {
-    tvm::Array<tvm::relay::Type> arg_types;
+    // Collect input and output dtypes to pass on to Legalize API.
+    tvm::Array<tvm::relay::Type> types;
     for (auto& expr : ref_call->args) {
-      arg_types.push_back(expr->checked_type());
+      types.push_back(expr->checked_type());
     }
-    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
+    types.push_back(ref_call->checked_type());
+
+    // Transform the op by calling the registered legalize function.
+    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
+
+    // Check if the transformation succeeded. If not, revert back to the original ref_call->op.
     if (legalized_value.defined()) {
       new_e = legalized_value;
       modified = true;
...
...@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,
 Expr DequantizeLegalize(const Attrs& attrs,
                         const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& arg_types) {
+                        const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
   CHECK(dequantize_attrs != nullptr);
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return DequantizeLower(data, dequantize_attrs);
 }
...
...@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,
 Expr QuantizeLegalize(const Attrs& attrs,
                       const Array<Expr>& new_args,
-                      const Array<tvm::relay::Type>& arg_types) {
+                      const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   CHECK(quantize_attrs != nullptr);
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return QuantizeLower(data, quantize_attrs);
 }
...
...@@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
  * 7) Cast to the out_dtype.
  */
 Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
-                     const Array<IndexExpr>& input_shape) {
+                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
   double double_multiplier = param->input_scale / param->output_scale;
   // Choose high precision datatype to be int64. This is for avoiding overflow
...@@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
   auto shifted_int64_t = Add(output_zp, scaled_int64_t);

   // 7) Clip to the out_dtype min/max.
-  auto q_min = GetQmin(param->out_dtype);
-  auto q_max = GetQmax(param->out_dtype);
+  auto q_min = GetQmin(out_dtype);
+  auto q_max = GetQmax(out_dtype);
   auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
-  return Cast(clipped_t, param->out_dtype);
+  return Cast(clipped_t, out_dtype);
 }

 /*
...@@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
  * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
  */
 Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& arg_types) {
+                        const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& quantized_data = new_args[0];
   const auto* param = attrs.as<RequantizeAttrs>();
   CHECK(param != nullptr);

   // Find input shape.
-  CHECK_EQ(arg_types.size(), 1);
-  auto input_dtype = arg_types[0];
-  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
-  CHECK(input_tensor_type != nullptr) << "Type information missing."
-                                      << " Please run infer_type pass.";
-  Array<IndexExpr> input_shape = input_tensor_type->shape;
+  CHECK_EQ(types.size(), 2);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  CHECK(in_tensor_type != nullptr) << "Type information missing."
+                                   << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
+  // Find the output dtype.
+  auto out_type = types[1];
+  auto out_tensor_type = out_type.as<TensorTypeNode>();
+  CHECK(out_tensor_type != nullptr) << "Type information missing."
+                                    << " Please run infer_type pass.";
+  auto out_dtype = out_tensor_type->dtype;

   // Check rounding validity.
   CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
       << "QNN requantize supports two rounding modes - UPWARD and "
       << "TONEAREST";
-  return RequantizeLower(quantized_data, param, input_shape);
+  return RequantizeLower(quantized_data, param, input_shape, out_dtype);
 }

 /*
...@@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized
 tensor. For the output tensor, we are provided with output scale and zero
 point. The computation looks like this

-Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
+Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)

 )code" TVM_ADD_FILELINE)
 .set_attrs_type_key("relay.attrs.RequantizeAttrs")
...
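A small worked instance of the requantize relation above may help; the quantization parameters here are made up for illustration and do not come from the diff.

# Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)
scale_input, zp_input = 0.5, 0       # hypothetical input quantization params
scale_output, zp_output = 0.25, 10   # hypothetical output quantization params

q_input = 7
real_value = scale_input * (q_input - zp_input)           # 3.5 in real terms
q_output = zp_output + round(real_value / scale_output)   # 10 + 14 = 24
assert q_output == 24
# RequantizeLower does the same scaling in int64 fixed point, then clips the
# result to [GetQmin(out_dtype), GetQmax(out_dtype)] (e.g. [-128, 127] for
# int8) before casting to out_dtype.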
...@@ -47,7 +47,7 @@ def test_legalize():
         return y

     @register_legalize("nn.conv2d", level=100)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         data, weight = inputs
         weight = relay.multiply(weight, relay.const(2.0, "float32"))
         return relay.nn.conv2d(data, weight, **attrs)
...@@ -80,7 +80,7 @@ def test_legalize_none():
     called = [False]

     @register_legalize("nn.global_max_pool2d", level=101)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         called[0] = True
         return None
...@@ -103,12 +103,13 @@ def test_legalize_multi_input():
         return func

     @register_legalize("concatenate", level=100)
-    def legalize_concatenate(attrs, inputs, arg_types):
+    def legalize_concatenate(attrs, inputs, types):
         # Check that the correct multi-input case is handled.
         assert len(inputs) == 1
         assert isinstance(inputs[0], tvm.relay.expr.Tuple)
-        assert len(arg_types) == 1
-        assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
+        assert len(types) == 2
+        assert isinstance(types[0], tvm.relay.ty.TupleType)
+        assert isinstance(types[1], tvm.relay.ty.TensorType)
         return None

     def expected():
...@@ -153,9 +154,9 @@ def test_legalize_arm_layout_functional():
         return func

     @register_legalize("nn.conv2d", level=101)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         from topi.arm_cpu.conv2d import _conv2d_legalize
-        return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
+        return _conv2d_legalize(attrs, inputs, types)

     a = before()
     b = run_opt_pass(a, transform.Legalize())
...
...@@ -18,10 +18,11 @@
 """Conv2D schedule for ARM CPU"""
 from __future__ import absolute_import as _abs

-import warnings
+import logging

 import tvm
 from tvm import autotvm
+from tvm import relay
 import tvm.contrib.nnpack

 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
...@@ -35,6 +36,8 @@ from ..nn import conv2d_legalize
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices

+logger = logging.getLogger('topi')
+
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
 def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """TOPI compute callback for conv2d
...@@ -671,7 +674,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     if layout != 'NCHW':
         return None
     if dilation != (1, 1):
-        warnings.warn("Does not support weight pre-transform for dilated convolution.")
+        logger.warning("Does not support weight pre-transform for dilated convolution.")
         return None

     data, kernel = tinfos[0:2]
...@@ -786,31 +789,46 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     return None

 @conv2d_legalize.register("arm_cpu")
-def _conv2d_legalize(attrs, inputs, arg_types, F):
-    if F.__name__ != 'tvm.relay.op':
-        return None
+def _conv2d_legalize(attrs, inputs, arg_types):
+    """Legalizes Conv2D op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
     if attrs['data_layout'] == 'NHWC':
         data, kernel = inputs
         if attrs['kernel_layout'] == 'HWIO':
             # Handle HWIO layout. This is common in TF graph.
-            kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
+            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
         elif attrs['kernel_layout'] == 'HWOI':
             # Handle HWOI layout. This is common in TF depthwise conv2d graph.
-            kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
+            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
         elif attrs['kernel_layout'] != 'OIHW':
             return None

-        warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
-                      + "fallback to NCHW. This can result in performance degradation.")
+        logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+                       + "fallback to NCHW. This can result in performance degradation.")

         # Set new attrs for the tranposed conv.
         new_attrs = {k: attrs[k] for k in attrs.keys()}
         new_attrs['data_layout'] = 'NCHW'
         new_attrs['kernel_layout'] = 'OIHW'

         # Convert from NHWC to NCHW.
-        data = F.transpose(data, axes=(0, 3, 1, 2))
-        conv = F.nn.conv2d(data, kernel, **new_attrs)
+        data = relay.transpose(data, axes=(0, 3, 1, 2))
+        conv = relay.nn.conv2d(data, kernel, **new_attrs)

         # Convert back to original NHWC layout.
-        out = F.transpose(conv, axes=(0, 2, 3, 1))
+        out = relay.transpose(conv, axes=(0, 2, 3, 1))

         return out

     return None
...@@ -72,22 +72,22 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
 @tvm.target.generic_func
-def conv2d_legalize(attrs, inputs, arg_dtypes, F):
+def conv2d_legalize(attrs, inputs, types):
     """Legalizes Conv2D op.

     Parameters
     ----------
-    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+    attrs : tvm.attrs.Attrs
         Attributes of current convolution
     inputs : list of tvm.relay.Expr
-        The args of the Relay expr to be legalized.
-    arg_dtypes : list of types
-        List of types of input arguments
-    F: symbol
-        The context, can be either nnvm.sym or relay.op
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types

-    Note
-    ----
-    Unlike other TOPI functions, this function operates on both graph level and operator level,
-    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
     """
     # not to change by default
     return None
...
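For the TOPI side, a minimal sketch of overriding this generic function with the new signature is shown below; it is not part of the commit, and the target key "my_cpu" is hypothetical. The import mirrors the `from ..nn import conv2d_legalize` line visible in the arm_cpu hunk above.

from topi.nn import conv2d_legalize  # import path assumed outside a TOPI package

@conv2d_legalize.register("my_cpu")
def _conv2d_legalize_my_cpu(attrs, inputs, types):
    # `types` carries the input types followed by the output type, and the old
    # `F` argument is gone: the callback works on Relay expressions directly.
    data, kernel = inputs
    out_dtype = types[2].dtype  # types[0] and types[1] are the input types
    # Returning None keeps the default lowering for this target.
    return None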