Commit 1e4aea81 by Animesh Jain Committed by Yizhi Liu

[Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to read from out_types. (#3782)

parent 17f8f96b
......@@ -206,10 +206,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes):
"""Legalize conv2d"""
from ... import op
return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
def legalize_conv2d(attrs, inputs, types):
"""Legalize conv2d op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.conv2d_legalize(attrs, inputs, types)
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......
......@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
Expr new_e;
bool modified = false;
if (fop_legalize.count(op)) {
tvm::Array<tvm::relay::Type> arg_types;
// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto& expr : ref_call->args) {
arg_types.push_back(expr->checked_type());
types.push_back(expr->checked_type());
}
Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
types.push_back(ref_call->checked_type());
// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
// Check if the transformation succeeded. If not, revert back to the original ref_call->op.
if (legalized_value.defined()) {
new_e = legalized_value;
modified = true;
......
......@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,
Expr DequantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
CHECK(dequantize_attrs != nullptr);
CHECK_EQ(arg_types.size(), 1);
CHECK_EQ(types.size(), 2);
return DequantizeLower(data, dequantize_attrs);
}
......
......@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,
Expr QuantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
CHECK(quantize_attrs != nullptr);
CHECK_EQ(arg_types.size(), 1);
CHECK_EQ(types.size(), 2);
return QuantizeLower(data, quantize_attrs);
}
......
......@@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
* 7) Cast to the out_dtype.
*/
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape) {
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
double double_multiplier = param->input_scale / param->output_scale;
// Choose high precision datatype to be int64. This is for avoiding overflow
......@@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
auto shifted_int64_t = Add(output_zp, scaled_int64_t);
// 7) Clip to the out_dtype min/max.
auto q_min = GetQmin(param->out_dtype);
auto q_max = GetQmax(param->out_dtype);
auto q_min = GetQmin(out_dtype);
auto q_max = GetQmax(out_dtype);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
return Cast(clipped_t, param->out_dtype);
return Cast(clipped_t, out_dtype);
}
/*
......@@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
* Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
*/
Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& quantized_data = new_args[0];
const auto* param = attrs.as<RequantizeAttrs>();
CHECK(param != nullptr);
// Find input shape.
CHECK_EQ(arg_types.size(), 1);
auto input_dtype = arg_types[0];
auto input_tensor_type = input_dtype.as<TensorTypeNode>();
CHECK(input_tensor_type != nullptr) << "Type information missing."
CHECK_EQ(types.size(), 2);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
CHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;
// Find the output dtype.
auto out_type = types[1];
auto out_tensor_type = out_type.as<TensorTypeNode>();
CHECK(out_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = input_tensor_type->shape;
auto out_dtype = out_tensor_type->dtype;
// Check rounding validity.
CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and "
<< "TONEAREST";
return RequantizeLower(quantized_data, param, input_shape);
return RequantizeLower(quantized_data, param, input_shape, out_dtype);
}
/*
......@@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this
Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.RequantizeAttrs")
......
......@@ -47,7 +47,7 @@ def test_legalize():
return y
@register_legalize("nn.conv2d", level=100)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
......@@ -80,7 +80,7 @@ def test_legalize_none():
called = [False]
@register_legalize("nn.global_max_pool2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
called[0] = True
return None
......@@ -103,12 +103,13 @@ def test_legalize_multi_input():
return func
@register_legalize("concatenate", level=100)
def legalize_concatenate(attrs, inputs, arg_types):
def legalize_concatenate(attrs, inputs, types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
assert isinstance(inputs[0], tvm.relay.expr.Tuple)
assert len(arg_types) == 1
assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
assert len(types) == 2
assert isinstance(types[0], tvm.relay.ty.TupleType)
assert isinstance(types[1], tvm.relay.ty.TensorType)
return None
def expected():
......@@ -153,9 +154,9 @@ def test_legalize_arm_layout_functional():
return func
@register_legalize("nn.conv2d", level=101)
def legalize_conv2d(attrs, inputs, arg_types):
def legalize_conv2d(attrs, inputs, types):
from topi.arm_cpu.conv2d import _conv2d_legalize
return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
return _conv2d_legalize(attrs, inputs, types)
a = before()
b = run_opt_pass(a, transform.Legalize())
......
......@@ -18,10 +18,11 @@
"""Conv2D schedule for ARM CPU"""
from __future__ import absolute_import as _abs
import warnings
import logging
import tvm
from tvm import autotvm
from tvm import relay
import tvm.contrib.nnpack
from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
......@@ -35,6 +36,8 @@ from ..nn import conv2d_legalize
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
logger = logging.getLogger('topi')
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
"""TOPI compute callback for conv2d
......@@ -671,7 +674,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
if layout != 'NCHW':
return None
if dilation != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
logger.warning("Does not support weight pre-transform for dilated convolution.")
return None
data, kernel = tinfos[0:2]
......@@ -786,21 +789,36 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
return None
@conv2d_legalize.register("arm_cpu")
def _conv2d_legalize(attrs, inputs, arg_types, F):
if F.__name__ != 'tvm.relay.op':
return None
def _conv2d_legalize(attrs, inputs, arg_types):
"""Legalizes Conv2D op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
if attrs['kernel_layout'] == 'HWIO':
# Handle HWIO layout. This is common in TF graph.
kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
elif attrs['kernel_layout'] == 'HWOI':
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
elif attrs['kernel_layout'] != 'OIHW':
return None
warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+ "fallback to NCHW. This can result in performance degradation.")
# Set new attrs for the tranposed conv.
new_attrs = {k: attrs[k] for k in attrs.keys()}
......@@ -808,9 +826,9 @@ def _conv2d_legalize(attrs, inputs, arg_types, F):
new_attrs['kernel_layout'] = 'OIHW'
# Convert from NHWC to NCHW.
data = F.transpose(data, axes=(0, 3, 1, 2))
conv = F.nn.conv2d(data, kernel, **new_attrs)
data = relay.transpose(data, axes=(0, 3, 1, 2))
conv = relay.nn.conv2d(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
out = F.transpose(conv, axes=(0, 2, 3, 1))
out = relay.transpose(conv, axes=(0, 2, 3, 1))
return out
return None
......@@ -72,22 +72,22 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
@tvm.target.generic_func
def conv2d_legalize(attrs, inputs, arg_dtypes, F):
def conv2d_legalize(attrs, inputs, types):
"""Legalizes Conv2D op.
Parameters
----------
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized.
arg_dtypes : list of types
List of types of input arguments
F: symbol
The context, can be either nnvm.sym or relay.op
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level,
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# not to change by default
return None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment