partition.cc 2.74 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file partition.cc
 *
 * \brief Partition a graph into sections for quantization.
 */

#include <tvm/relay/transform.h>

#include <utility>

#include "../transforms/pattern_util.h"
#include "./quantize.h"

namespace tvm {
namespace relay {
namespace quantize {

using namespace relay::transform;

37

38 39 40 41 42 43
class QPartitionExpr;
class QPartitionExprNode : public TempExprNode {
 public:
  /*! \brief The original expression */
  Expr expr;

44
  void VisitAttrs(tvm::AttrVisitor* v) {
45 46 47 48 49 50
    v->Visit("expr", &expr);
  }

  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.QPartitionExpr";
51
  TVM_DECLARE_FINAL_OBJECT_INFO(QPartitionExprNode, TempExprNode);
52 53
};

54 55
class QPartitionExpr : public TempExpr {
 public:
56 57 58 59 60 61
  /*!
   * \brief  The constructor
   * \param expr The original relay expression.
   */
  TVM_DLL explicit QPartitionExpr(Expr expr);

62 63
  TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode);
};
64 65 66 67 68 69 70 71 72


/*!
 * \brief Realize the wrapped expression.
 *
 * Applies CastHint with the `dtype_input` of the current QConfig to the
 * wrapped expression, then passes the result through StopFusion.
 *
 * \return The expression returned by StopFusion.
 */
Expr QPartitionExprNode::Realize() const {
  const QConfig& config = QConfig::Current();
  // Cast hint first, fusion barrier second: the barrier wraps the hinted value.
  return StopFusion(CastHint(this->expr, config->dtype_input));
}

73
/*!
 * \brief Construct a QPartitionExpr backed by a freshly allocated
 *  QPartitionExprNode.
 * \param expr The original relay expression; it is moved into the node.
 */
QPartitionExpr::QPartitionExpr(Expr expr) {
  auto node = make_object<QPartitionExprNode>();
  node->expr = std::move(expr);
  // Hand ownership of the node to this reference.
  data_ = std::move(node);
}

79
// Register a global function, "relay._quantize.make_partition_expr", that wraps
// an expression in a QPartitionExpr. The argument is taken by value, so it is
// moved into the constructor rather than copied again.
TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr")
.set_body_typed([](Expr expr) {
  return QPartitionExpr(std::move(expr));
});
83 84

/*!
 * \brief Create the function pass that partitions a graph for quantization.
 *
 * The pass runs ForwardRewrite over each function using the rewrite rules
 * looked up under the "FQPartitionRewrite" attribute name. The module and
 * pass-context arguments are part of the pass callback signature but are not
 * used by the rewrite.
 *
 * \return A function pass named "QuantizePartition", created with
 *  CreateFunctionPass(pass_func, 1, "QuantizePartition", {}).
 */
Pass QuantizePartition() {
  // The callback reads no enclosing state, so the lambda needs no captures.
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
      };
  return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
}

94
// Register QuantizePartition() as the global function
// "relay._quantize.QuantizePartition".
TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition")
.set_body_typed(QuantizePartition);

97 98
// Register the QPartitionExprNode type via TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(QPartitionExprNode);

99 100 101
}  // namespace quantize
}  // namespace relay
}  // namespace tvm