Commit 975ab75c by Wuwei Lin Committed by ziheng

[Relay][Quantization] Fix out-of-date realize (#3790)

parent d3eb9cb8
...@@ -110,7 +110,6 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { ...@@ -110,7 +110,6 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
} else if (static_cast<int>(factor) == factor) { } else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor)); return Multiply(data, MakeConstantScalar(dtype, factor));
} else { } else {
LOG(FATAL) << "fall back to float computation";
data = Cast(data, Float(32)); data = Cast(data, Float(32));
data = Multiply(data, MakeConstantScalar(Float(32), factor)); data = Multiply(data, MakeConstantScalar(Float(32), factor));
return Cast(Round(data), dtype); return Cast(Round(data), dtype);
...@@ -147,15 +146,21 @@ Expr QuantizeRealize(const Call& ref_call, ...@@ -147,15 +146,21 @@ Expr QuantizeRealize(const Call& ref_call,
} }
float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
CHECK_GT(shift_nbit, 0); CHECK_NE(shift_nbit, 0);
if (static_cast<int>(shift_nbit) == shift_nbit) { if (static_cast<int>(shift_nbit) == shift_nbit) {
// use right shift if (shift_nbit > 0) {
if (cfg->round_for_shift) { // use right shift
float round_bias = std::pow(2.0, shift_nbit - 1); if (cfg->round_for_shift) {
data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias))); float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(round_bias)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
} else {
data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
} }
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
data = Clip(data, clip_min_imm, clip_max_imm); data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype); return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else { } else {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment