Commit 975ab75c by Wuwei Lin Committed by ziheng

[Relay][Quantization] Fix out-of-date realize (#3790)

parent d3eb9cb8
......@@ -110,7 +110,6 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
LOG(FATAL) << "fall back to float computation";
data = Cast(data, Float(32));
data = Multiply(data, MakeConstantScalar(Float(32), factor));
return Cast(Round(data), dtype);
......@@ -147,15 +146,21 @@ Expr QuantizeRealize(const Call& ref_call,
}
float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
CHECK_GT(shift_nbit, 0);
CHECK_NE(shift_nbit, 0);
if (static_cast<int>(shift_nbit) == shift_nbit) {
// use right shift
if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
if (shift_nbit > 0) {
// use right shift
if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(round_bias)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
} else {
data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment