/*!
 * Copyright (c) 2018 by Contributors
 *
 * \file quantize.cc
 *
 * \brief transform a graph to a low-bit graph
 *   for compression and acceleration.
 */
#include <dmlc/thread_local.h>
#include <tvm/base.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <cmath>
#include <string>
#include <vector>
#include <stack>
#include <utility>
#include "pattern_util.h"
#include "quantize.h"

namespace tvm {
namespace relay {
namespace quantize {

/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
  int kind;
  bool sign;
  std::string rounding;

  TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
    TVM_ATTR_FIELD(kind)
      .describe("kind of field, hint for nbit/dtype configuration.");
    TVM_ATTR_FIELD(sign).set_default(true)
      .describe("whether to use signed data type.");
    TVM_ATTR_FIELD(rounding).set_default("round")
      .describe("rounding mode. Can be 'floor', 'ceil', 'round'");
  }
};

TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);

bool SimulatedQuantizeRel(const Array<Type>& types,
                          int num_inputs,
                          const Attrs& attrs,
                          const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 5);
  const auto param = attrs.as<SimulatedQuantizeAttrs>();
  CHECK(param != nullptr);

  const auto* data = types[0].as<TensorTypeNode>();
  CHECK(data != nullptr);
  CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";

  reporter->Assign(types[1], TensorTypeNode::make({}, Float(32)));  // dom_scale
  reporter->Assign(types[2], TensorTypeNode::make({}, Float(32)));  // clip_min
  reporter->Assign(types[3], TensorTypeNode::make({}, Float(32)));  // clip_max
  reporter->Assign(types[4], types[0]);                             // output
  return true;
}
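// For intuition: simulated_quantize keeps the tensor in float32 while
// mimicking the rounding/clipping error of the eventual integer kernel.
// A minimal scalar sketch of the intended semantics (illustrative only;
// SimulatedQuantizeRef is a hypothetical helper, and the actual compute
// rule for the op is registered elsewhere):
//
//   float SimulatedQuantizeRef(float x, float dom_scale,
//                              float clip_min, float clip_max) {
//     float q = std::round(x / dom_scale);            // quantize
//     q = std::min(std::max(q, clip_min), clip_max);  // clip to the nbit range
//     return q * dom_scale;                           // dequantize back to float
//   }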
It should be a scalar") .set_attrs_type_key("relay.attrs.SimulatedQuantizeAttrs") .set_support_level(10) .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); TVM_REGISTER_API("relay._quantize.simulated_quantize") .set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string)>( [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, std::string rounding) { auto attrs = make_node<SimulatedQuantizeAttrs>(); attrs->kind = kind; attrs->sign = sign; attrs->rounding = rounding; static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); }); // ============= // annotate pass Expr QAnnotateExprNode::Realize() const { const auto& cfg = QConfig::Current(); if (cfg->store_lowbit_output) { // store low bit output back for VTA const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); return (*f)(this->expr, static_cast<int>(kQInput)); } else { return expr; } } QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { auto rnode = make_node<QAnnotateExprNode>(); rnode->expr = expr; rnode->kind = kind; return QAnnotateExpr(rnode); } TVM_REGISTER_API("relay._quantize.make_annotate_expr") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = QAnnotateExprNode::make(args[0], static_cast<QAnnotateKind>(args[1].operator int())); }); TVM_REGISTER_API("relay._quantize.annotate") .set_body_typed<Expr(Expr)>([] (const Expr& expr) { std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) { if (e->derived_from<TempExprNode>()) { const auto* n = e.as<QAnnotateExprNode>(); CHECK(n); const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); Expr ret = (*f)(n->expr, static_cast<int>(kQInput)); return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput)); } return e; }; return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref); }); // ============= // realize pass Expr QRealizeIntExprNode::Realize() const { const auto& cfg = QConfig::Current(); Expr data = this->data; if (cfg->store_lowbit_output) { data = Cast(data, cfg->dtype_input); } // dequantize data = Cast(data, Float(32)); data = Multiply(data, this->dom_scale); return data; } QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { NodePtr<QRealizeIntExprNode> n = make_node<QRealizeIntExprNode>(); n->data = std::move(data); n->dom_scale = std::move(dom_scale); n->dtype = std::move(dtype); return QRealizeIntExpr(n); } inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) { return CallNode::make(ref_call->op, args, ref_call->attrs, ref_call->type_args); } /* calculate `data * s1 / s2`, use shift if possible */ inline Expr MulAndDiv(Expr data, float s1, float s2) { // here we assume the dtype of data is dtype activation const QConfig& cfg = QConfig::Current(); if (s1 == s2) return data; float factor = s1 / s2; float shift_factor = std::log2(factor); CHECK_GT(shift_factor, 0); if (static_cast<int>(shift_factor) == shift_factor) { return LeftShift(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(shift_factor))); } else if (static_cast<int>(factor) == factor) { return Multiply(data, MakeConstantScalar(cfg->dtype_activation, factor)); } else { LOG(FATAL) << "fall back to float computation"; data = Cast(data, Float(32)); return Multiply(data, MakeConstantScalar(Float(32), factor)); } } Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) { 
Expr QuantizeRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // do not handle data type cast
  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
  CHECK_EQ(param->rounding, "round");

  Expr dom_scale = new_args[1];
  Expr clip_min = new_args[2];
  Expr clip_max = new_args[3];

  float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
  float clip_min_imm = GetScalarFromConstant<float>(clip_min);
  float clip_max_imm = GetScalarFromConstant<float>(clip_max);

  // x * idom_scale = y * odom_scale
  // => y = x * idom_scale / odom_scale
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    // int32->int8
    Expr data = n->data;
    float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
    float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
    if (idom_scale_imm == odom_scale_imm) {
      // same domain scale, only clip
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
    }

    float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
    CHECK_GT(shift_nbit, 0);
    if (static_cast<int>(shift_nbit) == shift_nbit) {
      // use right shift
      if (cfg->round_for_shift) {
        // add 2^(k-1) before shifting by k so the shift rounds to nearest
        // instead of truncating, e.g. (x + 4) >> 3 ~= round(x / 8.0)
        float round_bias = std::pow(2.0, shift_nbit - 1);
        data = Add(data, MakeConstantScalar(cfg->dtype_activation,
                                            static_cast<int>(round_bias)));
      }
      data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
                                                 static_cast<int>(shift_nbit)));
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
    } else {
      // float computation
      data = Cast(data, Float(32));
      Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
      Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
    }
  }

  // quantize from real
  CHECK(!new_args[0]->derived_from<TempExprNode>());
  Expr data = new_args[0];
  Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
  return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);

Expr Conv2dRealize(const Call& ref_call,
                   const Array<Expr>& new_args,
                   const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
    return Expr(nullptr);
  }
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
  auto attrs = make_node<Conv2DAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = CallNode::make(ref_call->op,
    {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
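// Scale bookkeeping example (illustrative numbers): if the input holds a_q
// with real value a = a_q * 0.5, and the weight holds w_q with
// w = w_q * 0.25, then because convolution is linear in both arguments,
// conv(a, w) = conv(a_q, w_q) * (0.5 * 0.25). The integer conv result
// therefore carries dom_scale 0.125, and FoldConstant collapses the scale
// product into a single constant at rewrite time.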
Expr MulRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
    // execute the operation with activation data type.
    const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
    const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
    Expr ldata = lhs->data;
    Expr rdata = rhs->data;

    DataType dtype = cfg->dtype_activation;
    if (lhs->dtype == Float(32)) {
      ldata = Cast(ldata, dtype);
    } else {
      CHECK_EQ(lhs->dtype, dtype);
    }
    if (rhs->dtype == Float(32)) {
      rdata = Cast(rdata, dtype);
    } else {
      CHECK_EQ(rhs->dtype, dtype);
    }

    Expr ret = ForwardOp(ref_call, {ldata, rdata});
    Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
  }
  CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
  return Expr(nullptr);
}

RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);

float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
  if (nptrs.size() == 2) {
    // x = a * s1, y = b * s2
    // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
    //       = (a + b * s2 / s1) * s1, if s2 > s1
    float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
    float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
    return s1 > s2 ? s2 : s1;
  } else {
    const QConfig& cfg = QConfig::Current();
    float scale = cfg->global_scale;
    return scale / std::pow(2.0, cfg->nbit_activation - 1);
  }
}

/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
                            const Array<Expr>& args,
                            DataType* dtype_ptr,
                            Expr* scale_ptr) {
  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
  const QConfig& cfg = QConfig::Current();

  std::vector<const QRealizeIntExprNode*> nptrs;
  Array<Expr> ret;
  for (auto arg : args) {
    const auto* nptr = arg.as<QRealizeIntExprNode>();
    CHECK(nptr);
    nptrs.push_back(nptr);
    ret.push_back(nptr->data);
  }

  // unify the data type
  CHECK_EQ(ref_args.size(), args.size());
  DataType dtype = cfg->dtype_activation;
  for (size_t i = 0; i < ret.size(); ++i) {
    auto ref_arg = ref_args[i].as<CallNode>();
    if (nptrs[i]->dtype != dtype) {
      ret.Set(i, Cast(ret[i], dtype));
    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
      auto new_arg = Cast(ret[i], cfg->dtype_input);
      if (cfg->use_stop_fusion) {
        new_arg = StopFusion(new_arg);
      }
      ret.Set(i, Cast(new_arg, dtype));
    }
  }

  // unify the dom_scale
  float s = ChooseDomScale(nptrs);
  Expr dom_scale = MakeConstantScalar(Float(32), s);
  for (size_t i = 0; i < ret.size(); ++i) {
    float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
    ret.Set(i, MulAndDiv(ret[i], cur_s, s));
  }

  *dtype_ptr = dtype;
  *scale_ptr = dom_scale;
  return ret;
}

Expr AddRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
    DataType dtype;
    Expr dom_scale;
    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
    Expr ret = ForwardOp(ref_call, ret_args);
    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
  }
  CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
  return Expr(nullptr);
}

RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
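// Worked example for add: with x = a_q * 0.5 and y = b_q * 0.25,
// ChooseDomScale picks the smaller scale s = 0.25 (the finer grid loses no
// precision on the coarser operand), and UnifyDTypeScale rewrites
// a_q -> MulAndDiv(a_q, 0.5, 0.25) = a_q << 1. The sum (a_q << 1) + b_q
// then again carries a single dom_scale of 0.25, since
// x + y = (2 * a_q + b_q) * 0.25.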
Expr ConcatenateRealize(const Call& ref_call,
                        const Array<Expr>& new_args,
                        const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  CHECK_EQ(ref_call->args.size(), 1);

  const auto* tuple = new_args[0].as<TupleNode>();
  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
  CHECK(tuple);
  CHECK(ref_tuple);
  const Array<Expr>& arr = tuple->fields;
  const Array<Expr>& ref_arr = ref_tuple->fields;

  if (arr[0].as<QRealizeIntExprNode>()) {
    DataType dtype;
    Expr dom_scale;
    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
    Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
  } else {
    for (auto arg : new_args) {
      CHECK(!arg->derived_from<TempExprNode>());
    }
    return Expr(nullptr);
  }
}

RELAY_REGISTER_OP("concatenate")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);

/*! \brief forward the original operator */
Expr IdentityRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    Expr ret = ForwardOp(ref_call, {n->data});
    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
  }
  CHECK(!new_args[0]->derived_from<TempExprNode>());
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

Expr MaxPoolRealize(const Call& ref_call,
                    const Array<Expr>& new_args,
                    const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    Expr data = Cast(n->data, cfg->dtype_input);
    Expr ret = ForwardOp(ref_call, {data});
    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
  }
  CHECK(!new_args[0]->derived_from<TempExprNode>());
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MaxPoolRealize);

Expr AvgPoolRealize(const Call& ref_call,
                    const Array<Expr>& new_args,
                    const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    Expr data = n->data;
    if (n->dtype != cfg->dtype_activation) {
      data = Cast(n->data, cfg->dtype_activation);
    }
    Expr ret = ForwardOp(ref_call, {data});
    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
  }
  CHECK(!new_args[0]->derived_from<TempExprNode>());
  return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);

TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
  Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
  return ret;
});
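// Putting the pieces together, the Python driver (python/tvm/relay/quantize)
// is expected to run roughly annotate -> calibrate -> realize. A hedged C++
// equivalent of the final step, assuming `e` is an annotated expression
// whose scales have already been calibrated to constants:
//
//   Expr realized = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
//
// After realization every QRealizeIntExpr temporary has been resolved, and
// the graph computes directly in the configured integer dtypes.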
// =============
// qconfig

QConfig qconfig() {
  return QConfig(make_node<QConfigNode>());
}

/*! \brief Entry to hold the QConfig context stack. */
struct TVMQConfigThreadLocalEntry {
  /*! \brief The default qconfig if the stack is empty */
  QConfig default_config;

  /*! \brief The current qconfig context */
  std::stack<QConfig> context_stack;

  TVMQConfigThreadLocalEntry() :
    default_config(qconfig()) {
  }
};

/*! \brief Thread local store to hold the QConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMQConfigThreadLocalEntry> TVMQConfigThreadLocalStore;

void QConfig::EnterQConfigScope(const QConfig& config) {
  TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.push(config);
}

void QConfig::ExitQConfigScope() {
  TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.pop();
}

QConfig& QConfig::Current() {
  TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
  if (entry->context_stack.size() > 0) {
    return entry->context_stack.top();
  }
  return entry->default_config;
}

TVM_REGISTER_NODE_TYPE(QConfigNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<QConfigNode>([](const QConfigNode *op, IRPrinter *p) {
  p->stream << "qconfig(";
  p->stream << "nbit_input=" << op->nbit_input << ", ";
  p->stream << "nbit_weight=" << op->nbit_weight << ", ";
  p->stream << "nbit_activation=" << op->nbit_activation << ", ";
  p->stream << "global_scale=" << op->global_scale << ", ";
  p->stream << "skip_k_conv=" << op->skip_k_conv << ", ";
  p->stream << "round_for_shift=" << op->round_for_shift << ", ";
  p->stream << "store_lowbit_output=" << op->store_lowbit_output << ", ";
  p->stream << "debug_enabled_ops=" << op->debug_enabled_ops << ", ";
  p->stream << "use_stop_fusion=" << op->use_stop_fusion;
  p->stream << ")";
});

TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = QConfig::Current();
});

TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  QConfig target = args[0];
  QConfig::EnterQConfigScope(target);
});

TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  QConfig::ExitQConfigScope();
});

}  // namespace quantize
}  // namespace relay
}  // namespace tvm