Unverified commit 38118bef by Animesh Jain, committed by GitHub

[ConvertLayout] Support QNN ops. (#5066)

* [ConvertLayout] Support QNN ops.

* Changing layouts to C.

* Fixing dilation.

* Empty commit.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-53-55.us-west-2.compute.internal>
parent e54a9a97
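
For context, a minimal sketch of what this commit enables: a QNN graph built in NHWC can be rewritten to NCHW by the ConvertLayout pass. The qnn.conv2d argument names and the string form of ConvertLayout below are assumptions based on the Python API of this era, not part of the diff.

import tvm
from tvm import relay

# Build a quantized conv2d in NHWC/HWIO (argument names assumed).
x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
w = relay.var("w", shape=(3, 3, 64, 64), dtype="int8")
y = relay.qnn.op.conv2d(
    x, w,
    input_zero_point=relay.const(1, "int32"),
    kernel_zero_point=relay.const(1, "int32"),
    input_scale=relay.const(1.0, "float32"),
    kernel_scale=relay.const(1.0, "float32"),
    channels=64, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NHWC", kernel_layout="HWIO")

mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
# Rewrite registered ops (now including QNN ops) to NCHW.
mod = relay.transform.ConvertLayout("NCHW")(mod)
print(mod)
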
@@ -138,8 +138,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
data_layout = attrs['data_layout']
kernel_layout = attrs['kernel_layout']
data, weight = inputs
assert desired_layout == 'NCHW', \
"Currently only transformation to NCHW layout is supported."
@@ -147,13 +145,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'OIHW'
if data_layout == 'NHWC' and kernel_layout == 'HWIO':
# Convert (NHWC, HWIO) to (NCHW, OIHW)
return relay.nn.conv2d(data, weight, **new_attrs)
if data_layout == 'NHWC' and kernel_layout == 'HWOI':
# Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
return relay.nn.conv2d(data, weight, **new_attrs)
return relay.nn.conv2d(data, weight, **new_attrs)
return None
......
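
The simplified convert_conv2d above always emits an NCHW/OIHW conv2d and leaves the surrounding layout handling to the pass. A hedged sketch of the float case (an illustrative example, not from the diff):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 56, 56, 64))
w = relay.var("w", shape=(3, 3, 64, 64))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                    data_layout="NHWC", kernel_layout="HWIO")
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
mod = relay.transform.ConvertLayout("NCHW")(mod)
# Roughly expected result: layout_transform(x, "NHWC" -> "NCHW"),
# layout_transform(w, "HWIO" -> "OIHW"), an NCHW conv2d, then a
# layout_transform back to NHWC at the output boundary.
print(mod)
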
@@ -19,4 +19,4 @@
from __future__ import absolute_import as _abs
from .qnn import *
from .op import register_qnn_legalize
from . import legalizations
from . import legalizations, layout_conversions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Convert layout related registration"""
from __future__ import absolute_import
from tvm.relay.op import op as reg
@reg.register_convert_op_layout("qnn.conv2d")
def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout):
"""Convert Layout pass registration for QNN conv2d op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layout : str
The desired layout
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay
assert desired_layout == 'NCHW', \
"Currently only transformation to NCHW layout is supported."
if desired_layout == 'NCHW':
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'OIHW'
return relay.qnn.op.conv2d(*inputs, **new_attrs)
return None
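
The decorator used above attaches the converter to the op, and the ConvertLayout pass looks it up by attribute name. A small sketch; the attribute key "FTVMConvertOpLayout" is a reading of register_convert_op_layout and should be treated as an assumption:

import tvm
from tvm import relay
from tvm.relay.qnn.op import layout_conversions  # ensure registration runs

# The converter registered above should now be attached to the op.
conv_op = relay.op.get("qnn.conv2d")
print(conv_op.get_attr("FTVMConvertOpLayout") is not None)
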
@@ -39,7 +39,7 @@ template <typename T>
Array<Array<Layout>> BinaryConv2DInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
const Array<tvm::relay::Type>& old_in_types) {
const T* params = attrs.as<T>();
// We always make other operators to fit the layouts of convolution layers
......
@@ -34,23 +34,6 @@
namespace tvm {
namespace relay {
template<typename T>
Array<Array<Layout> > ConvInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const T* params = attrs.as<T>();
// We always make other operators to fit the layouts of convolution layers
// So this inference ignores all inputs
return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
{params->out_layout == "" ?
params->data_layout : params->out_layout}};
}
template <typename T>
Expr MakeConv(Expr data,
Expr weight,
@@ -1048,7 +1031,7 @@ Array<Array<Layout> > Dilation2DInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const Array<tvm::relay::Type> &old_in_types) {
const T* params = attrs.as<T>();
// We always make other operators to fit the layouts of convolution layers
......
@@ -431,6 +431,21 @@ bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}
template<typename T>
Array<Array<Layout> > ConvInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type> &old_in_types) {
const T* params = attrs.as<T>();
// We always make other operators to fit the layouts of convolution layers
// So this inference ignores all inputs
return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
{params->out_layout == "" ?
params->data_layout : params->out_layout}};
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_
@@ -272,10 +272,10 @@ Array<Array<Layout> > PReluInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const Array<tvm::relay::Type> &old_in_types) {
CHECK_EQ(old_in_layouts.size(), 2U);
CHECK_EQ(old_in_shapes.size(), 2U);
CHECK_EQ(old_in_types.size(), 2U);
Layout data_layout = old_in_layouts[0];
if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 2U);
@@ -615,9 +615,15 @@ TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
const Array<tvm::relay::Type>& old_in_types) {
BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());
Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
CHECK(old_in_t.as<TensorTypeNode>());
old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
}
size_t axis =
param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
......
@@ -39,7 +39,7 @@ Array<Array<Layout> > PadInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const Array<tvm::relay::Type> &old_in_types) {
// NOTE: Discard "const" qualifier here.
PadAttrs *params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
......
@@ -41,7 +41,7 @@ Array<Array<Layout> > PoolInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const Array<tvm::relay::Type> &old_in_types) {
// NOTE: Discard "const" qualifier here.
T *params = const_cast<T*>(attrs.as<T>());
......
@@ -39,7 +39,7 @@ Array<Array<Layout> > UpsamplingInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const Array<tvm::relay::Type> &old_in_types) {
// NOTE: Discard "const" qualifier here.
T *params = const_cast<T*>(attrs.as<T>());
......
@@ -122,11 +122,16 @@ Array<Integer> GetExcludeAxes(size_t indim,
Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());
// Get the reduce axes.
Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
CHECK(old_in_t.as<TensorTypeNode>());
old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
}
uint32_t indim = old_in_shapes[0].size();
auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
......
@@ -275,54 +275,6 @@ Array<te::Tensor> ConcatenateCompute(const Attrs& attrs,
return { topi::concatenate(inputs, param->axis) };
}
Array<Array<Layout>> ConcatenateLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
static_cast<size_t>(param->axis);
Layout ret;
bool is_new_layout_selected = false;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
// If all the new input layouts are same, the new in layout gets selected. For axis, the new
// axis in the new layout is identified. The param->axis is then modified on the fly to conform
// to the new input layout.
const auto& concate_dim = old_in_layouts[0][axis];
bool all_input_layouts_same = true;
for (auto new_layout : new_in_layouts) {
if (!new_layout.Equals(new_in_layouts[0])) {
all_input_layouts_same = false;
}
}
if (all_input_layouts_same) {
auto new_index = new_in_layouts[0].IndexOf(concate_dim);
ret = new_in_layouts[0];
param->axis = new_index;
is_new_layout_selected = true;
}
}
if (!is_new_layout_selected) {
// this function is called on the original correct relay ir
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i];
break;
}
}
if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
}
return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
Expr MakeConcatenate(Expr data,
int axis) {
auto attrs = make_object<ConcatenateAttrs>();
@@ -1933,7 +1885,14 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
const Array<tvm::relay::Type>& old_in_types) {
Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
CHECK(old_in_t.as<TensorTypeNode>());
old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
}
CHECK(old_in_layouts.defined());
CHECK_EQ(old_in_layouts.size(), 1);
CHECK(old_in_shapes.defined());
......
@@ -25,6 +25,7 @@
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#include <tvm/ir/error.h>
#include <tvm/relay/attrs/transform.h>
#include <vector>
#include <algorithm>
#include <limits>
@@ -124,6 +125,63 @@ bool ConcatenateRel(const Array<Type>& types,
return true;
}
static inline Array<Array<Layout>> ConcatenateLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type> &old_in_types) {
ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
Array<Array<IndexExpr>> old_in_shapes;
CHECK_EQ(old_in_types.size(), 1);
for (auto old_in_tuple_t : old_in_types) {
CHECK(old_in_tuple_t.as<TupleTypeNode>());
for (auto old_in_t : old_in_tuple_t.as<TupleTypeNode>()->fields) {
old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
}
}
size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
static_cast<size_t>(param->axis);
Layout ret;
bool is_new_layout_selected = false;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
// If all the new input layouts are same, the new in layout gets selected. For axis, the new
// axis in the new layout is identified. The param->axis is then modified on the fly to conform
// to the new input layout.
const auto& concate_dim = old_in_layouts[0][axis];
bool all_input_layouts_same = true;
for (auto new_layout : new_in_layouts) {
if (!new_layout.Equals(new_in_layouts[0])) {
all_input_layouts_same = false;
}
}
if (all_input_layouts_same) {
auto new_index = new_in_layouts[0].IndexOf(concate_dim);
ret = new_in_layouts[0];
param->axis = new_index;
is_new_layout_selected = true;
}
}
if (!is_new_layout_selected) {
// this function is called on the original correct relay ir
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i];
break;
}
}
if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
}
}
return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_TENSOR_TRANSFORM_H_
@@ -25,6 +25,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../transforms/pattern_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../util.h"
#include "op_common.h"
@@ -32,6 +33,23 @@ namespace tvm {
namespace relay {
namespace qnn {
/*! \brief Infer layout for QNN binary broadcast operators */
Array<Array<Layout> > QnnBinaryBroadcastLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// Use Relay Binary Broadcast Infer correct layout.
auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types);
// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
// tensors can be treated as C.
Layout channel_layout = Layout("C");
Array<Layout> input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout,
channel_layout, channel_layout, channel_layout, channel_layout};
Array<Layout> output_layouts = layouts[1];
return {input_layouts, output_layouts};
}
/*
* \brief Canonicalizes the QNN add op.
* \param attrs The QNN concatenate attrs.
@@ -118,7 +136,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
QNN_REGISTER_BINARY_OP("add")
.describe("Elementwise add with with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize);
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnAddCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout);
} // namespace qnn
} // namespace relay
......
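
QnnBinaryBroadcastLayout above gives the two data tensors the broadcast-inferred layout and tags the six quantization scalars as "C". A hedged construction of such a call (argument names assumed from the qnn.add Python wrapper of this era):

from tvm import relay

lhs = relay.var("lhs", shape=(1, 56, 56, 64), dtype="int8")
rhs = relay.var("rhs", shape=(1, 56, 56, 64), dtype="int8")
# Eight inputs in total: 2 data tensors + 6 scale/zero-point scalars,
# matching the eight input layouts returned above.
out = relay.qnn.op.add(
    lhs, rhs,
    lhs_scale=relay.const(0.5, "float32"),
    lhs_zero_point=relay.const(0, "int32"),
    rhs_scale=relay.const(0.5, "float32"),
    rhs_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.25, "float32"),
    output_zero_point=relay.const(0, "int32"))
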
@@ -28,6 +28,7 @@
#include <tvm/relay/qnn/attrs.h>
#include "../../op/tensor/transform.h"
#include "../../transforms/pattern_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../util.h"
namespace tvm {
@@ -70,6 +71,43 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at
return ConcatenateRel<ConcatenateAttrs>(tensor_types, 2, attrs, reporter);
}
Array<Array<Layout>> QnnConcatenateLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// Collect the layouts and types to reuse Relay Concatenate Infer Correct Layout.
CHECK_EQ(old_in_types.size(), 5);
auto input_tuple_type = old_in_types[0].as<TupleTypeNode>();
CHECK(input_tuple_type);
auto num_input_tensors = input_tuple_type->fields.size();
Array<Layout> relay_new_in_layouts(nullptr);
if (new_in_layouts.defined()) {
relay_new_in_layouts =
Array<Layout>(new_in_layouts.begin(), new_in_layouts.begin() + num_input_tensors);
}
Array<Layout> relay_old_in_layouts(nullptr);
if (old_in_layouts.defined()) {
relay_old_in_layouts =
Array<Layout>(old_in_layouts.begin(), old_in_layouts.begin() + num_input_tensors);
}
// Use Relay Concatenate Infer Correct layout to infer the layouts for data tensors.
auto layouts =
ConcatenateLayout(attrs, relay_new_in_layouts, relay_old_in_layouts, {old_in_types[0]});
// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
// tensors can be treated as channel layout. Total number of these tensors are 2 * num of data
// tensors (scale and zero point for each input data tensor) + 2 for the output data tensor.
Layout channel_layout = Layout("C");
Array<Layout> input_layouts = layouts[0];
for (size_t i = 0; i < 2 * num_input_tensors + 2; i++) {
input_layouts.push_back(channel_layout);
}
Array<Layout> output_layouts = layouts[1];
return {input_layouts, output_layouts};
}
Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Expr output_scale,
Expr output_zero_point, int axis) {
auto attrs = make_object<ConcatenateAttrs>();
@@ -161,7 +199,8 @@ RELAY_REGISTER_OP("qnn.concatenate")
.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QnnConcatenate", QnnConcatenateRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize);
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConcatenateLayout);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate")
.set_body_typed(MakeQnnConcatenate);
......
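
QnnConcatenateLayout above strips the scale and zero-point entries before delegating to the Relay ConcatenateLayout, then re-appends "C" layouts for them. A hedged sketch of the op it handles (the Python wrapper's signature is an assumption of this era):

from tvm import relay

a = relay.var("a", shape=(1, 28, 28, 32), dtype="int8")
b = relay.var("b", shape=(1, 28, 28, 32), dtype="int8")
# Per-input quantization params are per-tensor scalars, so the pass
# tags them "C" and only the data tensors change layout.
scales = [relay.const(0.5, "float32"), relay.const(0.5, "float32")]
zps = [relay.const(0, "int32"), relay.const(0, "int32")]
out = relay.qnn.op.concatenate(
    (a, b), scales, zps,
    output_scale=relay.const(0.5, "float32"),
    output_zero_point=relay.const(0, "int32"),
    axis=3)
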
@@ -68,6 +68,23 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return Conv2DRel<Conv2DAttrs>(tensor_types, 3, attrs, reporter);
}
Array<Array<Layout>> QnnConvInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// Use Relay Conv2D Infer correct layout.
auto layouts =
ConvInferCorrectLayout<Conv2DAttrs>(attrs, new_in_layouts, old_in_layouts, old_in_types);
// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
// tensors can be treated as channel layout.
Layout channel_layout = Layout("C");
Array<Layout> input_layouts = {layouts[0][0], layouts[0][1], channel_layout,
channel_layout, channel_layout, channel_layout};
Array<Layout> output_layouts = layouts[1];
return {input_layouts, output_layouts};
}
bool is_depthwise(const Conv2DAttrs* param) {
return param->channels.defined() && tvm::tir::Equal(param->channels, param->groups) &&
param->groups != 1;
@@ -684,7 +701,8 @@ operator to understand how to scale back the int32 output to (u)int8.
.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.")
.set_support_level(11)
.add_type_rel("QnnConv2D", QnnConv2DRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize);
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnConvInferCorrectLayout);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D);
......
@@ -26,6 +26,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../transforms/pattern_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../util.h"
namespace tvm {
@@ -34,6 +35,79 @@ namespace qnn {
TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
Array<Array<Layout>> RequantizeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
RequantizeAttrs* param = const_cast<RequantizeAttrs*>(attrs.as<RequantizeAttrs>());
Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
CHECK(old_in_t.as<TensorTypeNode>());
old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
}
Array<Layout> input_layouts, output_layouts;
if (new_in_layouts.defined()) {
// Adapt to new layout. The axis has to change.
// Record original reduce axis. Convert to the modified layout axis.
CHECK_EQ(new_in_layouts.size(), 5);
CHECK_EQ(old_in_layouts.size(), 5);
// 1) Get the axis.
int axis = param->axis;
axis = (axis == -1) ? old_in_shapes[0].size() - 1 : axis;
// 2) Collect the original axis
std::string old_dim = old_in_layouts[0][axis].name();
// 3) Collect the new axes by walking new_layout.
tvm::Integer new_axis;
std::string new_layout_string = "";
int axis_index = 0;
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& layout_axis = LayoutAxis::Get(iter_var);
const std::string& layout_dim = layout_axis.name();
if (old_dim == layout_dim) {
new_axis = tvm::Integer(axis_index);
}
// Collect only the primal axis.
if (layout_axis.IsPrimal()) {
new_layout_string += layout_dim;
axis_index++;
}
}
// 4) Set the new axis and layout.
Layout new_layout = Layout(new_layout_string);
// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
// tensors can be treated as channel layout.
Layout channel_layout = Layout("C");
input_layouts = {new_layout, channel_layout, channel_layout, channel_layout, channel_layout};
output_layouts = {new_layout};
param->axis = new_axis;
} else if (old_in_layouts.defined()) {
// If the new layout is undefined, set the old layout as the inferred layout.
CHECK_EQ(old_in_layouts.size(), 5);
Layout old_layout = old_in_layouts[0];
// Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
// tensors can be treated as channel layout.
Layout channel_layout = Layout("C");
input_layouts = {old_layout, channel_layout, channel_layout, channel_layout, channel_layout};
output_layouts = {old_layout};
} else {
// Set the layouts to undef.
Layout undef = Layout::Undef();
input_layouts = Array<Layout>(5, undef);
output_layouts = {undef};
}
return Array<Array<Layout>>{input_layouts, output_layouts};
}
// Lowering of qnn.requantize op
/*
@@ -247,7 +321,8 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("Requantize", RequantizeRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize);
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize")
.set_body_typed(MakeRequantize);
......
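
RequantizeInferCorrectLayout above also remaps attrs->axis when a new layout is chosen, so per-channel requantize survives the conversion. A hedged sketch of such a call (argument names assumed from the requantize Python wrapper; per-channel input scales assumed supported here):

import numpy as np
from tvm import relay

x = relay.var("x", shape=(1, 56, 56, 64), dtype="int32")
# Per-channel scales along NHWC axis 3; after ConvertLayout("NCHW")
# the pass should rewrite the axis attribute to 1.
out = relay.qnn.op.requantize(
    x,
    input_scale=relay.const(np.full(64, 0.5, dtype="float32")),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.25, "float32"),
    output_zero_point=relay.const(0, "int32"),
    axis=3,
    out_dtype="int8")
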
@@ -90,7 +90,7 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o
* This can be undefined, which means we call this function before alternating
* any operators.
* \param old_in_layouts The layouts of input arguments before alter_op_layout.
* \param old_in_shapes The shapes of old input arguments.
* \param old_in_types The types of old input arguments.
* \return infered_layout An array of two elements that are inferred input layouts and
* inferred output layouts.
*/
@@ -98,13 +98,13 @@ using FInferCorrectLayout = runtime::TypedPackedFunc<
Array<Array<Layout>>(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes)>;
const Array<tvm::relay::Type> &old_in_types)>;
/*! \brief take arbitrary input layout and copy to output */
inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const Array<tvm::relay::Type> &old_in_types) {
Layout ret;
if (new_in_layouts.defined()) {
@@ -126,8 +126,13 @@ inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const Array<tvm::relay::Type> &old_in_types) {
Array<Layout> layouts;
Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
CHECK(old_in_t.as<TensorTypeNode>());
old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
}
if (new_in_layouts.defined()) {
layouts.assign(new_in_layouts.begin(), new_in_layouts.end());
@@ -203,7 +208,7 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
*/
static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts(
const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
const Array<tvm::relay::Type>& old_in_types) {
static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
if (!call->op.as<OpNode>()) {
return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
@@ -213,7 +218,7 @@ static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts
if (finfer_layout.count(op)) {
Array<Array<Layout>> inferred_layouts;
inferred_layouts =
finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_shapes);
finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types);
CHECK_EQ(inferred_layouts.size(), 2)
<< "FInferCorrectLayout should return an array with size of 2";
for (auto x : inferred_layouts) {
......
@@ -225,7 +225,6 @@ template <class TransformMemorizerT>
Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
std::vector<LayoutAlternatedExpr<TransformMemorizerT>> inputs;
std::vector<Expr> normal_new_args;
Array<Array<IndexExpr>> input_shapes;
// NOTE: discard the "const" qualifier
// TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);
@@ -273,21 +272,16 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
new_in.push_back(inp->new_layout);
}
// Collect input types to pass on to Infer Correct Layout.
tvm::Array<tvm::relay::Type> types;
for (auto arg : ref_call->args) {
if (arg->IsInstance<TupleNode>()) { // flatten tuple
Tuple tuple_arg = Downcast<Tuple>(arg);
for (auto x : tuple_arg->fields) {
input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
}
} else {
input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
}
types.push_back(arg->checked_type());
}
// old_in, old_out = op.infer(old_in)
bool success = false;
std::tie(old_in, old_out, success) =
InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, input_shapes);
InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, types);
if (!success) {
return Expr(nullptr);
}
@@ -307,7 +301,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
if (new_call->op->IsInstance<OpNode>()) {
success = false;
std::tie(new_in2, new_out, success) =
InferCorrectLayouts(new_call, new_in, old_in, input_shapes);
InferCorrectLayouts(new_call, new_in, old_in, types);
if (!success) {
return Expr(nullptr);
}
......