Commit e8899285 by Wuwei Lin Committed by Tianqi Chen

[Relay][Quantize] Use fixed point multiplications (#4160)

parent 8b1fb4d5
......@@ -83,6 +83,7 @@ class QConfig(NodeBase):
"do_simulation": False,
"round_for_shift": True,
"debug_enabled_ops": None,
"rounding": "UPWARD"
}
# pylint: disable=no-member
......@@ -160,6 +161,9 @@ def qconfig(**kwargs):
is None, which means it will try to call all operators' annotate rewrite
functions.
rounding: "UPWARD" or "TONEAREST"
Rounding direction for fixed point multiplications.
Returns
-------
config: QConfig
......
......@@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "do_simulation==" << op->do_simulation << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", ";
p->stream << "rounding==" << op->rounding;
p->stream << ")";
});
......
......@@ -75,6 +75,7 @@ class QConfigNode : public Node {
bool do_simulation = false;
bool round_for_shift = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
std::string rounding = "UPWARD";
void VisitAttrs(AttrVisitor* v) {
v->Visit("nbit_input", &nbit_input);
......@@ -88,6 +89,7 @@ class QConfigNode : public Node {
v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
v->Visit("rounding", &rounding);
}
static constexpr const char* _type_key = "relay.quantize.QConfig";
......
......@@ -31,6 +31,7 @@
#include <tvm/relay/attrs/annotation.h>
#include "./quantize.h"
#include "../pattern_util.h"
#include "../../qnn/util.h"
namespace tvm {
namespace relay {
......@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
/* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
const Array<IndexExpr> &data_shape) {
const QConfig& cfg = QConfig::Current();
// here we assume the dtype of data is dtype activation
if (s1 == s2) return data;
......@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
data = Cast(data, Float(32));
data = Multiply(data, MakeConstantScalar(Float(32), factor));
return Cast(Round(data), dtype);
data = qnn::FixedPointMultiply(Cast(data, Int(64)), factor, data_shape, cfg->rounding);
return Cast(data, dtype);
}
}
......@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else {
// float computation
data = Cast(data, Float(32));
Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
data = Cast(data, Int(64));
data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
ref_call->type_as<TensorTypeNode>()->shape,
cfg->rounding);
data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
}
}
......@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
Expr dom_scale = MakeConstantScalar(Float(32), s);
for (size_t i = 0; i < ret.size(); ++i) {
float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
}
*dtype_ptr = dtype;
......
......@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
// Lowering of qnn.requantize op
/*
* \brief Lower requantize to a sequence of ops.
* \param input_tensor The input tensor to requantize op.
......@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
// 2) If the input and output scales are same, we can skip the fixed point multiplication.
auto scaled_int64_t = tensor;
if (param->input_scale != param->output_scale) {
scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
param->rounding);
scaled_int64_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
}
// 3) Add the output zero point.
......
......@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
return std::make_pair(significand, exponent);
}
Expr FixedPointMuliply(Expr tensor, double multiplier,
Expr FixedPointMultiply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape, const std::string& rounding) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
......@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar =
Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
......
......@@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
* 2) Round the result.
* 3) Right shift the result
*/
Expr FixedPointMuliply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape,
const std::string& rounding);
Expr FixedPointMultiply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape,
const std::string& rounding);
} // namespace qnn
} // namespace relay
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment