quantize.h 5.08 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/pass/quantize/quantize.h
 * \brief Header of definitions for quantization
 */
24 25
#ifndef TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_
#define TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_
26 27 28 29

#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <string>
30
#include "../pattern_util.h"
31 32 33 34 35 36 37

namespace tvm {
namespace relay {
namespace quantize {

/*!
 * \brief Kind of annotate field.
 *
 * The underlying type is fixed to int, matching the plain int
 * SimulatedQuantizeAttrs::kind field. The input/weight/activation names
 * parallel the nbit_* and dtype_* members of QConfigNode.
 */
enum QAnnotateKind : int {
  kQIdentity = 0,
  kQInput = 1,
  kQWeight = 2,
  kQActivation = 3
};

/*!
 * \brief Attribute for simulated quantize operator.
 *
 * The fields are declared through TVM_DECLARE_ATTRS under the type key
 * "relay.attrs.SimulatedQuantizeAttrs".
 */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
  /*!
   * \brief Kind of field, used as a hint for nbit/dtype configuration.
   * No default is declared. Presumably holds a QAnnotateKind value -- confirm
   * against the callers that fill it in.
   */
  int kind;
  /*! \brief Whether to use a signed data type; defaults to true. */
  bool sign;
  /*! \brief Rounding mode: 'floor', 'ceil' or 'round'; defaults to "round". */
  std::string rounding;

  TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
    TVM_ATTR_FIELD(kind)
        .describe("kind of field, hint for nbit/dtype configuration.");
    TVM_ATTR_FIELD(sign).set_default(true)
        .describe("whether to use signed data type.");
    TVM_ATTR_FIELD(rounding).set_default("round")
        .describe("rounding mode. Can be 'floor', 'ceil', 'round'");
  }
};


class QConfig;
/*!
* \brief Container for build configuration options
*/
class QConfigNode : public Node {
 public:
  int nbit_input = 8;
  int nbit_weight = 8;
  int nbit_activation = 32;
  DataType dtype_input = Int(8);
  DataType dtype_weight = Int(8);
  DataType dtype_activation = Int(32);
73
  std::string calibrate_mode = "global_scale";
74
  double global_scale = 8.0;
75
  std::string weight_scale = "power2";
76
  Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
77
  bool do_simulation = false;
78 79
  bool round_for_shift = true;
  Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
80
  std::string rounding = "UPWARD";
81

82
  void VisitAttrs(AttrVisitor* v) {
83 84 85 86 87 88
    v->Visit("nbit_input", &nbit_input);
    v->Visit("nbit_weight", &nbit_weight);
    v->Visit("nbit_activation", &nbit_activation);
    v->Visit("dtype_input", &dtype_input);
    v->Visit("dtype_weight", &dtype_weight);
    v->Visit("dtype_activation", &dtype_activation);
89
    v->Visit("calibrate_mode", &calibrate_mode);
90
    v->Visit("global_scale", &global_scale);
91
    v->Visit("weight_scale", &weight_scale);
92
    v->Visit("skip_conv_layers", &skip_conv_layers);
93
    v->Visit("do_simulation", &do_simulation);
94 95
    v->Visit("round_for_shift", &round_for_shift);
    v->Visit("debug_enabled_ops", &debug_enabled_ops);
96
    v->Visit("rounding", &rounding);
97 98 99 100 101 102 103 104 105 106 107 108
  }

  static constexpr const char* _type_key = "relay.quantize.QConfig";
  TVM_DECLARE_NODE_TYPE_INFO(QConfigNode, Node);
};

/*!
* \brief Container for build configuration options
*/
class QConfig : public NodeRef {
 public:
  QConfig() {}
109
  explicit QConfig(ObjectPtr<Object> n) : NodeRef(n) {}
110 111

  const QConfigNode* operator->() const {
112
    return static_cast<const QConfigNode*>(get());
113 114 115
  }

  QConfigNode* operator->() {
116
    return static_cast<QConfigNode*>(get_mutable());
117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
  }

  /*!
   * \brief Push a new BuildConfig context onto the thread local stack.
   * \param build_config The configuration to set as the current context.
   */
  static void EnterQConfigScope(const QConfig& qconfig);

  /*!
   * \brief Pop a build config off the thread local context stack, restoring the previous
   * configuration as the current context.
   */
  static void ExitQConfigScope();

  /*!
   * \brief Get the current BuildConfig context from thread local storage, or a default
   * configuration if a BuildConfig scope has not been entered.
   * \return The configuration that is the current context.
   */
136
  static QConfig& Current();
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163

  using ContainerType = QConfigNode;
};

/*!
 * \brief RAII container to provide a scoped QConfig context. Pushes a configuration onto the
 * context stack when constructed, and pops it when destructed.
 *
 * The guard is not copyable: each instance stands for exactly one
 * EnterQConfigScope / ExitQConfigScope pair.
 */
struct QConfigContext {
  /*!
   * \brief Enter a new QConfig context. The given QConfig becomes the new current
   * context. When the QConfigContext is destructed, the previous context is restored.
   * \param qconfig The QConfig to set as the new current context.
   */
  explicit QConfigContext(const QConfig& qconfig) {
    QConfig::EnterQConfigScope(qconfig);
  }

  /*! \brief Destructor. Pops the context off the thread local stack. */
  ~QConfigContext() {
    QConfig::ExitQConfigScope();
  }

  // A copy would run ExitQConfigScope a second time for a single
  // EnterQConfigScope, so copying (and the implicit move) is disabled.
  QConfigContext(const QConfigContext&) = delete;
  QConfigContext& operator=(const QConfigContext&) = delete;
};

}  // namespace quantize
}  // namespace relay
}  // namespace tvm
164
#endif  // TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_