Commit 66cd036e by ziheng Committed by eqy

[Quantize] Skip for same input-output domain scale. (#2611)

parent 2ae3124f
...@@ -198,14 +198,20 @@ Expr QuantizeRealize(const Call& ref_call, ...@@ -198,14 +198,20 @@ Expr QuantizeRealize(const Call& ref_call,
// x * idom_scale = y * odom_scale // x * idom_scale = y * odom_scale
// => y = x * idom_scale / odom_scale // => y = x * idom_scale / odom_scale
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) { if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
// int32->int8
Expr data = n->data; Expr data = n->data;
float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale); float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
float odom_scale_imm = GetScalarFromConstant<float>(dom_scale); float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
if (idom_scale_imm == odom_scale_imm) {
// same domain scale, only clip
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
}
float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
// int32->int8
CHECK_GT(shift_nbit, 0); CHECK_GT(shift_nbit, 0);
if (static_cast<int>(shift_nbit) == shift_nbit) { if (static_cast<int>(shift_nbit) == shift_nbit) {
// use shift // use right shift
if (cfg->round_for_shift) { if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1); float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias))); data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment