Commit e8899285 by Wuwei Lin Committed by Tianqi Chen

[Relay][Quantize] Use fixed point multiplications (#4160)

parent 8b1fb4d5
...@@ -83,6 +83,7 @@ class QConfig(NodeBase): ...@@ -83,6 +83,7 @@ class QConfig(NodeBase):
"do_simulation": False, "do_simulation": False,
"round_for_shift": True, "round_for_shift": True,
"debug_enabled_ops": None, "debug_enabled_ops": None,
"rounding": "UPWARD"
} }
# pylint: disable=no-member # pylint: disable=no-member
...@@ -160,6 +161,9 @@ def qconfig(**kwargs): ...@@ -160,6 +161,9 @@ def qconfig(**kwargs):
is None, which means will try to call all operators' annotate rewrite is None, which means will try to call all operators' annotate rewrite
function. function.
rounding: "UPWARD" or "TONEAREST"
Rounding direction for fixed point multiplications.
Returns Returns
------- -------
config: QConfig config: QConfig
......
...@@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "do_simulation==" << op->do_simulation << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", ";
p->stream << "rounding==" << op->rounding;
p->stream << ")"; p->stream << ")";
}); });
......
...@@ -75,6 +75,7 @@ class QConfigNode : public Node { ...@@ -75,6 +75,7 @@ class QConfigNode : public Node {
bool do_simulation = false; bool do_simulation = false;
bool round_for_shift = true; bool round_for_shift = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr)); Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
std::string rounding = "UPWARD";
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("nbit_input", &nbit_input); v->Visit("nbit_input", &nbit_input);
...@@ -88,6 +89,7 @@ class QConfigNode : public Node { ...@@ -88,6 +89,7 @@ class QConfigNode : public Node {
v->Visit("do_simulation", &do_simulation); v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift); v->Visit("round_for_shift", &round_for_shift);
v->Visit("debug_enabled_ops", &debug_enabled_ops); v->Visit("debug_enabled_ops", &debug_enabled_ops);
v->Visit("rounding", &rounding);
} }
static constexpr const char* _type_key = "relay.quantize.QConfig"; static constexpr const char* _type_key = "relay.quantize.QConfig";
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <tvm/relay/attrs/annotation.h> #include <tvm/relay/attrs/annotation.h>
#include "./quantize.h" #include "./quantize.h"
#include "../pattern_util.h" #include "../pattern_util.h"
#include "../../qnn/util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) { ...@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
/* calculate `data * s1 / s2`, use shift if possible */ /* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
const Array<IndexExpr> &data_shape) {
const QConfig& cfg = QConfig::Current();
// here we assume the dtype of data is dtype activation // here we assume the dtype of data is dtype activation
if (s1 == s2) return data; if (s1 == s2) return data;
...@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { ...@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
} else if (static_cast<int>(factor) == factor) { } else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor)); return Multiply(data, MakeConstantScalar(dtype, factor));
} else { } else {
data = Cast(data, Float(32)); data = qnn::FixedPointMultiply(Cast(data, Int(64)), factor, data_shape, cfg->rounding);
data = Multiply(data, MakeConstantScalar(Float(32), factor)); return Cast(data, dtype);
return Cast(Round(data), dtype);
} }
} }
...@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call, ...@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
data = Clip(data, clip_min_imm, clip_max_imm); data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype); return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else { } else {
// float computation data = Cast(data, Int(64));
data = Cast(data, Float(32)); data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale)); ref_call->type_as<TensorTypeNode>()->shape,
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); cfg->rounding);
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} }
} }
...@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args ...@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
Expr dom_scale = MakeConstantScalar(Float(32), s); Expr dom_scale = MakeConstantScalar(Float(32), s);
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale); float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype)); ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
} }
*dtype_ptr = dtype; *dtype_ptr = dtype;
......
...@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs); ...@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
// Lowering of qnn.requantize op // Lowering of qnn.requantize op
/* /*
* \brief Lower requantize to a sequence of ops. * \brief Lower requantize to a sequence of ops.
* \param input_tensor The input tensor to requantize op. * \param input_tensor The input tensor to requantize op.
...@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, ...@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
// 2) If the input and output scales are same, we can skip the fixed point multiplication. // 2) If the input and output scales are same, we can skip the fixed point multiplication.
auto scaled_int64_t = tensor; auto scaled_int64_t = tensor;
if (param->input_scale != param->output_scale) { if (param->input_scale != param->output_scale) {
scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape, scaled_int64_t =
param->rounding); FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
} }
// 3) Add the output zero point. // 3) Add the output zero point.
......
...@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift( ...@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
return std::make_pair(significand, exponent); return std::make_pair(significand, exponent);
} }
Expr FixedPointMuliply(Expr tensor, double multiplier, Expr FixedPointMultiply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape, const std::string& rounding) { const Array<IndexExpr>& input_shape, const std::string& rounding) {
// Choose high precision datatype to be int64. This is for avoiding overflow // Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values. // in multiplication of two int32 values.
...@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier, ...@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
auto zero_t = Zeros(input_shape, hp_dtype); auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar = round_scalar =
Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
} }
// Add the rounding scalar. // Add the rounding scalar.
tensor = Add(tensor, round_scalar); tensor = Add(tensor, round_scalar);
......
...@@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) { ...@@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
* 2) Round the result. * 2) Round the result.
* 3) Right shift the result * 3) Right shift the result
*/ */
Expr FixedPointMuliply(Expr tensor, double multiplier, Expr FixedPointMultiply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape, const Array<IndexExpr>& input_shape,
const std::string& rounding); const std::string& rounding);
} // namespace qnn } // namespace qnn
} // namespace relay } // namespace relay
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment