Commit 671421a8 by Animesh Jain, committed by Zhi

[Relay][QNN] QNNtoRelay & QNNLegalize Pass utility using Relay Legalize API. (#3838)

parent a5def36f
@@ -522,10 +522,15 @@ TVM_DLL Pass AlterOpLayout();
 /*!
  * \brief Legalizes an expr with another expression.
+ * \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function.
+ * One can collect and isolate similar types of legalize transformations using this param. For
+ * example, transformations that apply only to dialects can be isolated into an FTVMDialectLegalize
+ * string. This pass calls only those transformations that have been registered using the supplied
+ * legalize_map_attr_name.
  *
  * \return The pass.
  */
-TVM_DLL Pass Legalize();
+TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize");
 /*!
  * \brief Canonicalize cast expressions to make operator fusion more efficient.
...
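For orientation, a minimal sketch of how the parameterized pass is driven from Python once the changes below are in place (the relu graph is just a stand-in):

```python
import tvm
from tvm import relay

# Build a trivial module to run the pass on.
x = relay.var("x", shape=(1, 64, 56, 56))
mod = relay.Module.from_expr(relay.Function([x], relay.nn.relu(x)))

# Default: apply rules registered under the "FTVMLegalize" op attribute.
mod = relay.transform.Legalize()(mod)

# Dialect-specific: apply only rules registered under the given attr name.
mod = relay.transform.Legalize("FTVMQnnLegalize")(mod)
```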
@@ -18,3 +18,4 @@
 """QNN dialect operators and IR passes."""
 from __future__ import absolute_import as _abs
 from . import op
+from . import transform
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=wildcard-import
-"""Neural network related operators."""
+"""QNN dialect related operators."""
 from __future__ import absolute_import as _abs
 from .qnn import *
+from .op import register_qnn_legalize
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""The register functions for the QNN dialect."""
from tvm.relay.op.op import register


def register_qnn_legalize(op_name, legal_op=None, level=10):
    """Register a legalization function for a QNN op.

    Parameters
    ----------
    op_name : str
        The name of the operator.

    legal_op : function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
        The function for transforming an expr to another expr.

    level : int
        The priority level.
    """
    return register(op_name, "FTVMQnnLegalize", legal_op, level)
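A brief usage sketch of the new hook (the rule body here is hypothetical; returning None leaves the op unchanged, as the Legalizer implementation further below shows):

```python
from tvm import relay
from tvm.relay.qnn.op import register_qnn_legalize


@register_qnn_legalize("qnn.requantize", level=100)
def legalize_qnn_requantize(attrs, inputs, types):
    # Hypothetical rule: decline the rewrite by returning None,
    # which keeps the original qnn.requantize call untouched.
    return None
```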
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring
"""
QNN pass transformation infrastructure.
"""
from tvm import relay
def CanonicalizeOps():
    """Converts/Lowers an expression containing QNN ops to an expression containing only core
    (non-dialect) Relay ops. Each QNN op is lowered to a sequence of existing Relay ops. This is
    a target-independent pass. One can register the lowering/transformation function for a QNN op
    using the FTVMQnnCanonicalize attr_name for the FTVMLegalize op attribute. An example of this
    transformation is below.

    Examples
    --------
    .. code-block:: python

        # Original expression
        qnn_expr = relay.qnn.op.requantize(y,
                                           input_scale=1,
                                           input_zero_point=0,
                                           output_scale=1,
                                           output_zero_point=0,
                                           out_dtype='int8')

        # We want to utilize all the existing Relay infrastructure. So, instead of supporting
        # this QNN requantize op, we convert it into a sequence of existing Relay operators.
        mod = relay.Module.from_expr(qnn_expr)
        mod = relay.qnn.transform.CanonicalizeOps()(mod)
        relay_expr = mod['main']
        print(relay_expr)

        def @main(%quantized_data: Tensor[(200), int32]) -> Tensor[(200), int8] {
          %0 = cast(%quantized_data, dtype="int64") /* ty=Tensor[(200), int64] */;
          %1 = multiply(%0, 2 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
          %2 = multiply(%1, 1073741824 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
          %3 = add(%2, 1073741824 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
          %4 = right_shift(%3, 31 /* ty=int64 */) /* ty=Tensor[(200), int64] */;
          %5 = add(0 /* ty=int64 */, %4) /* ty=Tensor[(200), int64] */;
          %6 = clip(%5, a_min=-128f, a_max=127f) /* ty=Tensor[(200), int64] */;
          cast(%6, dtype="int8") /* ty=Tensor[(200), int8] */
        }

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that canonicalizes QNN ops to Relay ops.
    """
    return relay.transform.Legalize("FTVMQnnCanonicalize")

def Legalize():
    """Legalizes QNN ops. As opposed to Relay Legalize, this one legalizes only QNN ops. One can
    register a transformation/legalization function for an op by using the FTVMQnnLegalize
    attr_name for the FTVMLegalize op attribute. Keeping QNN and Relay legalization separate
    gives a cleaner separation of concerns. The legalization can be configured per target. An
    example of this type of legalization is shown below.

    Examples
    --------
    Suppose the original graph is as follows

            data(u8)  weight(u8)
               |          |
               |          |
              qnn.conv2d (int32)
                    |
                    |
                nn.relu (int32)

    Now, we know that Intel Cascade Lake has VNNI instructions to speed up convolution, but they
    only work on u8 x i8 inputs. So, here, we can use QNN Legalize to transform the above graph
    as follows

            data(u8)  weight(u8)
               |          |
               |          |
               |     requantize(i8)
               |          |
               |          |
              qnn.conv2d (int32)
                    |
                    |
                nn.relu (int32)

    Because the legalization is isolated to QNN ops, the transformation triggers only for
    qnn.conv2d (and not nn.relu). This pass can be followed by CanonicalizeOps to further lower
    the qnn.requantize and qnn.conv2d into an expr containing only Relay ops.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that legalizes QNN ops.
    """
    return relay.transform.Legalize("FTVMQnnLegalize")
@@ -414,19 +414,24 @@ def AlterOpLayout():
     return _transform.AlterOpLayout()


-def Legalize():
+def Legalize(legalize_map_attr_name="FTVMLegalize"):
     """Legalizes an expression with another expression.

     This pass can be used to replace an expr with another expr for target
     dependent optimizations. For example, one expr, though semantically
     equivalent to the other, can have better performance on a target. This pass
     can be used to legalize the expr in a target-dependent manner.

+    Parameters
+    ----------
+    legalize_map_attr_name : str
+        The Op's attr name which corresponds to the legalize rule function.
+
     Returns
     -------
     ret : tvm.relay.Pass
         The registered pass that rewrites an expr.
     """
-    return _transform.Legalize()
+    return _transform.Legalize(legalize_map_attr_name)


 def RewriteAnnotatedOps(fallback_device):
...
@@ -25,6 +25,7 @@
  */
 #include <tvm/operation.h>
+#include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
@@ -35,48 +36,64 @@ namespace legalize {

 // Call registered FTVMLegalize of an op
 // Returns the legalized expression
-Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
-  static auto fop_legalize = Op::GetAttr<FTVMLegalize>("FTVMLegalize");
-  Op op = Downcast<Op>(ref_call->op);
-
-  Expr new_e;
-  bool modified = false;
-  if (fop_legalize.count(op)) {
-    // Collect input and output dtypes to pass on to Legalize API.
-    tvm::Array<tvm::relay::Type> types;
-    for (auto& expr : ref_call->args) {
-      types.push_back(expr->checked_type());
-    }
-    types.push_back(ref_call->checked_type());
-
-    // Transform the op by calling the registered legalize function.
-    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
-
-    // Check if the transformation succeeded. If not, revert back to the original ref_call->op.
-    if (legalized_value.defined()) {
-      new_e = legalized_value;
-      modified = true;
-    }
-  }
-  if (!modified) {
-    new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
-  }
-
-  const CallNode* new_call = new_e.as<CallNode>();
-  CHECK(new_call) << "Can only replace the original operator with another call node";
-  return GetRef<Call>(new_call);
-}
+class Legalizer : public ExprMutator {
+ public:
+  explicit Legalizer(const std::string& legalize_map_attr_name)
+      : legalize_map_attr_name_{legalize_map_attr_name} {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    // Get the new_call node without any changes to the current call node.
+    Expr new_e = ExprMutator::VisitExpr_(call_node);
+    Call new_call = Downcast<Call>(new_e);
+
+    // Collect the registered legalize function.
+    auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
+    Op op = Downcast<Op>(call_node->op);
+    if (fop_legalize.count(op)) {
+      // Collect the new_args.
+      tvm::Array<Expr> call_args = new_call->args;
+
+      // Collect input and output dtypes to pass on to the Legalize API.
+      tvm::Array<tvm::relay::Type> types;
+      for (auto arg : call_node->args) {
+        types.push_back(arg->checked_type());
+      }
+      types.push_back(call_node->checked_type());
+
+      // Transform the op by calling the registered legalize function.
+      Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
+
+      // Reassign new_e if the transformation succeeded.
+      if (legalized_value.defined()) {
+        // Check that the Expr returned by the legalize function is a CallNode.
+        const CallNode* legalized_call_node = legalized_value.as<CallNode>();
+        CHECK(legalized_call_node)
+            << "Can only replace the original operator with another call node";
+        new_e = legalized_value;
+      }
+    }
+
+    return new_e;
+  }

+ private:
+  std::string legalize_map_attr_name_;
+};
+
+Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
+  return Legalizer(legalize_map_attr_name).Mutate(expr);
+}

 }  // namespace legalize

 namespace transform {

-Pass Legalize() {
+Pass Legalize(const std::string& legalize_map_attr_name) {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
     [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(relay::legalize::Legalize(f));
+      return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
     };
   return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
 }
...
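From the Python side, the contract this mutator enforces is simple: a registered rule receives `(attrs, args, types)` and returns either a new call expression or None (an undefined Expr) to decline, in which case the original call is kept. A minimal sketch, reusing the generic `register` helper that `register_qnn_legalize` above wraps (the rule body and level are hypothetical):

```python
from tvm import relay
from tvm.relay.op.op import register  # the helper register_qnn_legalize wraps


def conv2d_identity_rule(attrs, inputs, types):
    # types holds the input types followed by the call's checked type.
    if types[0].dtype != "float32":
        return None  # undefined Expr: the Legalizer keeps the original call
    data, weight = inputs
    # The replacement must itself be a call node (the pass CHECKs this).
    return relay.nn.conv2d(data, weight, **attrs)


# Hypothetical registration under the default "FTVMLegalize" attribute.
register("nn.conv2d", "FTVMLegalize", conv2d_identity_rule, 111)
```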
@@ -72,9 +72,9 @@ Expr DequantizeLower(const Expr& input_tensor,
   return scaled_output;
 }

-Expr DequantizeLegalize(const Attrs& attrs,
-                        const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& types) {
+Expr DequantizeQnnCanonicalize(const Attrs& attrs,
+                               const Array<Expr>& new_args,
+                               const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
@@ -93,7 +93,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 giv
 .add_argument("data", "Tensor", "The tensor to dequantize.")
 .set_support_level(11)
 .add_type_rel("Dequantize", DequantizeRel)
-.set_attr<FTVMLegalize>("FTVMLegalize", DequantizeLegalize);
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);

 TVM_REGISTER_API("relay.qnn.op._make.dequantize")
 .set_body_typed(MakeDequantize);
...
@@ -83,9 +83,9 @@ Expr QuantizeLower(const Expr& input_tensor,
   return clamp_out_dtype;
 }

-Expr QuantizeLegalize(const Attrs& attrs,
-                      const Array<Expr>& new_args,
-                      const Array<tvm::relay::Type>& types) {
+Expr QuantizeQnnCanonicalize(const Attrs& attrs,
+                             const Array<Expr>& new_args,
+                             const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
@@ -111,7 +111,7 @@ scale and zero point.
 .add_argument("data", "Tensor", "The tensor to quantize.")
 .set_support_level(11)
 .add_type_rel("Quantize", QuantizeRel)
-.set_attr<FTVMLegalize>("FTVMLegalize", QuantizeLegalize);
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);

 TVM_REGISTER_API("relay.qnn.op._make.quantize")
 .set_body_typed(MakeQuantize);
...
@@ -192,8 +192,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
  *
  * Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
  */
-Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& types) {
+Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+                               const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& quantized_data = new_args[0];
   const auto* param = attrs.as<RequantizeAttrs>();
@@ -276,7 +276,7 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
 .add_argument("data", "Tensor", "The quantized input tensor.")
 .set_support_level(11)
 .add_type_rel("Requantize", RequantizeRel)
-.set_attr<FTVMLegalize>("FTVMLegalize", RequantizeLegalize);
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize);

 TVM_REGISTER_API("relay.qnn.op._make.requantize")
 .set_body_typed(MakeRequantize);
...
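Taken together, these registrations are what make the CanonicalizeOps docstring example work. A runnable sketch along the same lines (shapes and quantization params borrowed from that docstring):

```python
import tvm
from tvm import relay

x = relay.var("quantized_data", shape=(200,), dtype="int32")
y = relay.qnn.op.requantize(x,
                            input_scale=1,
                            input_zero_point=0,
                            output_scale=1,
                            output_zero_point=0,
                            out_dtype="int8")
mod = relay.Module.from_expr(relay.Function([x], y))

# Lowers qnn.requantize into cast/multiply/add/right_shift/clip.
mod = relay.qnn.transform.CanonicalizeOps()(mod)
print(mod["main"])
```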
@@ -92,6 +92,51 @@ def test_legalize_none():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
     assert(called[0])

+def test_legalize_multiple_ops():
+    """Test legalizing a graph where multiple ops have registered rules."""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    @register_legalize("nn.conv2d", level=102)
+    def legalize_conv2d(attrs, inputs, types):
+        data, weight = inputs
+        weight = relay.multiply(weight, relay.const(2.0, "float32"))
+        return relay.nn.conv2d(data, weight, **attrs)
+
+    @register_legalize("nn.relu", level=103)
+    def legalize_relu(attrs, inputs, types):
+        data = inputs[0]
+        add = relay.add(tvm.relay.const(0, "float32"), data)
+        return relay.nn.relu(add)
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.add(tvm.relay.const(0, "float32"), y)
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.Legalize())
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

def test_legalize_multi_input():
    """Test directly replacing an operator with a new one"""
    def before():
@@ -102,7 +147,7 @@ def test_legalize_multi_input():
         func = relay.Function([x, y, z], func)
         return func

-    @register_legalize("concatenate", level=100)
+    @register_legalize("concatenate", level=104)
     def legalize_concatenate(attrs, inputs, types):
         # Check that the correct multi-input case is handled.
         assert len(inputs) == 1
@@ -153,7 +198,7 @@ def test_legalize_arm_layout_functional():
         func = relay.Function([data, kernel], y)
         return func

-    @register_legalize("nn.conv2d", level=101)
+    @register_legalize("nn.conv2d", level=105)
     def legalize_conv2d(attrs, inputs, types):
         from topi.arm_cpu.conv2d import _conv2d_legalize
         return _conv2d_legalize(attrs, inputs, types)
@@ -173,5 +218,6 @@
 if __name__ == "__main__":
     test_legalize()
     test_legalize_none()
+    test_legalize_multiple_ops()
     test_legalize_multi_input()
     test_legalize_arm_layout_functional()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = relay.Module.from_expr(expr)
seq = transform.Sequential(passes)
with transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_qnn_legalize():
"""Test directly replacing an operator with a new one"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
y = relay.qnn.op.requantize(x,
input_scale=1,
input_zero_point=0,
output_scale=1,
output_zero_point=0,
out_dtype='int8')
y = relay.Function([x], y)
return y
@register_qnn_legalize("qnn.requantize", level=100)
def legalize_qnn_requantize(attrs, inputs, types):
data = inputs[0]
data = relay.add(relay.const(0, 'int8'), data)
y = relay.qnn.op.requantize(data,
input_scale=1,
input_zero_point=0,
output_scale=1,
output_zero_point=0,
out_dtype='int8')
return y
def expected():
x = relay.var("x", shape=(1, 64, 56, 56), dtype='int8')
y = relay.add(relay.const(0, 'int8'), x)
z = relay.qnn.op.requantize(y,
input_scale=1,
input_zero_point=0,
output_scale=1,
output_zero_point=0,
out_dtype='int8')
z = relay.Function([x], z)
return z
a = before()
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
test_qnn_legalize()
@@ -41,7 +41,7 @@ def test_same_io_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 0
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
     golden_output = np.concatenate((x_data, y_data), axis=axis)
@@ -70,7 +70,7 @@ def test_different_io_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 2
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
     golden_output = np.concatenate((x_data - 2, y_data - 3), axis=axis)
@@ -99,7 +99,7 @@ def test_few_same_io_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 1
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
     golden_output = np.concatenate((x_data + 1, y_data), axis=axis)
@@ -128,7 +128,7 @@ def test_same_i_qnn_params():
     func = relay.Function([x, y], z)
     assert func.astext().count('requantize') == 1
     mod = relay.Module.from_expr(func)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
     golden_output = np.concatenate((x_data + 1, y_data + 1), axis=axis)
@@ -137,7 +137,6 @@ def test_same_i_qnn_params():
     op_res = intrp.evaluate(func)(x_data, y_data)
     np.testing.assert_equal(op_res.asnumpy(), golden_output)

 if __name__ == '__main__':
     test_same_io_qnn_params()
     test_different_io_qnn_params()
...
@@ -31,7 +31,7 @@ def test_dequantize_op():
                                                  input_zero_point=input_zero_point)
     mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
     mod = relay.Module.from_expr(mod)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(mod, "llvm", params=None)
         rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
...
@@ -31,7 +31,7 @@ def test_quantize_op():
                                                output_zero_point=output_zero_point, out_dtype=out_dtype)
     mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
     mod = relay.Module.from_expr(mod)
-    mod = relay.transform.Legalize()(mod)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(mod, "llvm", params=None)
         rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
...
@@ -49,7 +49,7 @@ def test_requantize():
         mod = relay.Function(relay.analysis.free_vars(mod), mod)
         mod = relay.Module.from_expr(mod)
-        mod = relay.transform.Legalize()(mod)
+        mod = relay.qnn.transform.CanonicalizeOps()(mod)
         return mod

     def same_scale_test():
...