/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2018 by Contributors. * * \file tvm/relay/pass/quantize.h * \brief Header of definitions for quantization */ #ifndef TVM_RELAY_PASS_QUANTIZE_H_ #define TVM_RELAY_PASS_QUANTIZE_H_ #include <tvm/relay/op.h> #include <tvm/relay/expr.h> #include <string> #include "pattern_util.h" namespace tvm { namespace relay { namespace quantize { /*! \brief Kind of annotate field */ enum QAnnotateKind : int { kQInput = 1, kQWeight = 2, kQActivation = 3, }; /*! * \brief TempExpr used during annotate forward rewrite. */ class QAnnotateExpr; /*! * \brief TempExprNode used during annotate forward rewrite. */ class QAnnotateExprNode : public TempExprNode { public: /*! \brief The original expression */ Expr expr; /*! \brief The kind of annotate field */ QAnnotateKind kind; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("expr", &expr); v->Visit("kind", &kind); } TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind); Expr Realize() const final; static constexpr const char* _type_key = "relay.QAnnotateExpr"; TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode); }; RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr); /*! \brief TempExpr used during realize forward rewrite. */ class QRealizeExpr; /*! \brief TempExpr representing integer. */ class QRealizeIntExpr; class QRealizeExprNode : public TempExprNode { public: /*! \brief The original expression */ Expr data; static constexpr const char* _type_key = "relay.quantize.QRealizeExpr"; TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode); }; RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr); class QRealizeIntExprNode : public QRealizeExprNode { public: Expr dom_scale; /*! \brief current data type */ DataType dtype; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); v->Visit("dom_scale", &dom_scale); v->Visit("dtype", &dtype); } Expr Realize() const final; TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode); }; RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr); class QConfig; /*! * \brief Container for build configuration options */ class QConfigNode : public Node { public: int nbit_input = 8; int nbit_weight = 8; int nbit_activation = 32; DataType dtype_input = Int(8); DataType dtype_weight = Int(8); DataType dtype_activation = Int(32); double global_scale = 8.0; int skip_k_conv = 1; Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr)); bool round_for_shift = true; bool store_lowbit_output = true; Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr)); bool use_stop_fusion = true; void VisitAttrs(AttrVisitor* v) final { v->Visit("nbit_input", &nbit_input); v->Visit("nbit_weight", &nbit_weight); v->Visit("nbit_activation", &nbit_activation); v->Visit("dtype_input", &dtype_input); v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_activation", &dtype_activation); v->Visit("global_scale", &global_scale); v->Visit("skip_k_conv", &skip_k_conv); v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("debug_enabled_ops", &debug_enabled_ops); v->Visit("use_stop_fusion", &use_stop_fusion); } static constexpr const char* _type_key = "relay.quantize.QConfig"; TVM_DECLARE_NODE_TYPE_INFO(QConfigNode, Node); }; /*! * \brief Container for build configuration options */ class QConfig : public NodeRef { public: QConfig() {} explicit QConfig(NodePtr<Node> n) : NodeRef(n) {} const QConfigNode* operator->() const { return static_cast<const QConfigNode*>(node_.get()); } QConfigNode* operator->() { return static_cast<QConfigNode*>(node_.get()); } /*! * \brief Push a new BuildConfig context onto the thread local stack. * \param build_config The configuration to set as the current context. */ static void EnterQConfigScope(const QConfig& qconfig); /*! * \brief Pop a build config off the thread local context stack, restoring the previous * configuration as the current context. */ static void ExitQConfigScope(); /*! * \brief Get the current BuildConfig context from thread local storage, or a default * configuration if a BuildConfig scope has not been entered. * \return The configuration that is the current context. */ static QConfig& Current(); using ContainerType = QConfigNode; }; /*! * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the * context stack when constructed, and pops it when destructed. */ struct QConfigContext { /*! * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current * context. When the BuildConfigContext is destructed, the previous context is restored. * \param build_config The BuildConfig to set as the new current context. */ explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); } /*! \brief Destructor. Pops the context off the thread local stack. */ ~QConfigContext() { QConfig::ExitQConfigScope(); } }; /*! * \brief Construct a BuildConfig containing a new BuildConfigNode * \return The new BuildConfig */ TVM_DLL QConfig qconfig(); } // namespace quantize } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_QUANTIZE_H_