Commit 48a5aa06 by masahi Committed by Wuwei Lin

[Quantization] Fix annotation for multiply op (#4458)

* fix mul rewrite

* register Realize Rewrite for global avg pool and add test

* remove unnecessary check

* improve the test case
parent 123a4077
......@@ -214,8 +214,10 @@ def multiply_rewrite(ref_call, new_args, ctx):
# quantize lhs to INPUT field
if lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
if _analysis.check_constant(rhs_expr):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
......
......@@ -278,13 +278,9 @@ Expr MulRealize(const Call& ref_call,
DataType dtype = cfg->dtype_activation;
if (lhs->dtype != dtype) {
ldata = Cast(ldata, dtype);
} else {
CHECK_EQ(lhs->dtype, dtype);
}
if (rhs->dtype != dtype) {
rdata = Cast(rdata, dtype);
} else {
CHECK_EQ(rhs->dtype, dtype);
}
Expr ret = ForwardOp(ref_call, {ldata, rdata});
......@@ -499,6 +495,9 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
RELAY_REGISTER_OP("nn.global_avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
Expr CastHintRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import testing
def quantize_and_build(out):
f = relay.Function(relay.analysis.free_vars(out), out)
mod, params = testing.create_workload(f)
with relay.quantize.qconfig(skip_conv_layers=[]):
qmod = relay.quantize.quantize(mod, params)
relay.build(qmod, "llvm", params=params)
def test_mul_rewrite():
"""a test case where rhs of mul is not constant"""
data = relay.var("data", shape=(1, 16, 64, 64))
multiplier = relay.sigmoid(relay.var("data", shape=(1, 16, 1, 1)))
conv = relay.nn.conv2d(data, relay.var("weight"),
kernel_size=(3, 3),
padding=(1, 1),
channels=16)
act = relay.nn.relu(data=conv)
quantize_and_build(act * multiplier)
pool = relay.nn.global_avg_pool2d(data=act)
quantize_and_build(act * pool)
if __name__ == "__main__":
test_mul_rewrite()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment