/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file annotate.cc
 *
 * \brief Annotating the graph with simulated quantize operators.
 */

#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include "./quantize.h"

namespace tvm {
namespace relay {
namespace quantize {

using namespace relay::transform;

class QAnnotateExpr;
class QAnnotateExprNode : public TempExprNode {
 public:
  Expr expr;
  QAnnotateKind kind;

43
  void VisitAttrs(tvm::AttrVisitor* v) {
44 45 46 47 48 49 50
    v->Visit("expr", &expr);
    v->Visit("kind", &kind);
  }

  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.QAnnotateExpr";
51
  TVM_DECLARE_FINAL_OBJECT_INFO(QAnnotateExprNode, TempExprNode);
52 53
};

54 55
class QAnnotateExpr : public TempExpr {
 public:
56 57 58 59 60 61 62
  /*!
   * \brief The constructor
   * \param expr The original relay expression.
   * \param kind The annotation kind.
   */
  TVM_DLL QAnnotateExpr(Expr expr, QAnnotateKind kind);

63 64
  TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode);
};
65 66 67 68 69 70


/*!
 * \brief Realize the temporary expression.
 * \return The wrapped expression stored in `expr`; the annotation kind is
 *  not carried over.
 */
Expr QAnnotateExprNode::Realize() const {
  return expr;
}

/*!
 * \brief Build a QAnnotateExpr by allocating a fresh QAnnotateExprNode and
 *  filling in its fields.
 * \param expr The original relay expression (moved into the node).
 * \param kind The annotation kind.
 */
QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) {
  auto rnode = make_object<QAnnotateExprNode>();
  rnode->expr = std::move(expr);
  rnode->kind = kind;
  data_ = std::move(rnode);
}

// Registers a global function that constructs a QAnnotateExpr. The kind is
// received as a plain int and cast to QAnnotateKind.
TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr")
.set_body_typed([](Expr expr, int kind) {
  return QAnnotateExpr(expr, static_cast<QAnnotateKind>(kind));
});


/*!
 * \brief Create the function pass that annotates a function through the
 *  "FQAnnotateRewrite" forward rewrite.
 *
 * After rewriting, any free variables of the resulting function are appended
 * to its parameter list and a new Function is built from the rewritten body.
 *
 * \return The pass, created with CreateFunctionPass under the name
 *  "QuantizeAnnotate".
 */
Pass QuantizeAnnotate() {
  // TODO(tvm-teams): since partition has added cast_hint in different
  // branches, try to remove this in the future.
  // Callback handed to ForwardRewrite as its last argument. For a temporary
  // expression it attaches a simulated quantize with kind kQInput to the
  // wrapped expression and re-wraps the result; other expressions pass through.
  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
    if (e->IsInstance<TempExprNode>()) {
      const auto* n = e.as<QAnnotateExprNode>();
      CHECK(n);
      const PackedFunc* f =
          runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
      // Registry::Get yields a null pointer when the name is not registered;
      // fail with a clear message instead of dereferencing it.
      CHECK(f != nullptr)
          << "relay.quantize.attach_simulated_quantize is not registered";
      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
      return static_cast<Expr>(QAnnotateExpr(ret, kQInput));
    }
    return e;
  };

  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
      auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
      // Turn variables left free by the rewrite into explicit parameters.
      auto new_params = func->params;
      for (const auto& x : FreeVars(func)) {
        new_params.push_back(x);
      }
      return Function(new_params,
                      func->body,
                      func->ret_type,
                      func->type_params,
                      func->attrs);
  };
  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}

// Registers QuantizeAnnotate() as a global function under this name.
TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);

// Registers the QAnnotateExprNode type with the node system.
TVM_REGISTER_NODE_TYPE(QAnnotateExprNode);

}  // namespace quantize
}  // namespace relay
}  // namespace tvm