calibrate.cc 3.37 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 *
 * \file calibrate.cc
 *
 * \brief Create profile graph and calibrate on dataset
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include "./quantize.h"


namespace tvm {
namespace relay {
namespace quantize {

/*!
 * \brief Mutator that turns an annotated graph into a profile graph.
 *
 * Every call to the simulated_quantize op is rewritten into identity mode
 * (QAnnotateKind::kQIdentity). Rewritten calls whose original kind is not
 * kQWeight are recorded and become the outputs of the profile graph.
 */
class StatsCollector : private ExprMutator {
 public:
  /*!
   * \brief Build the profile graph for an annotated function.
   *
   * \param expr The annotated expression; its mutated form must be a Function.
   * \return A new Function whose body is a tuple of the recorded expressions
   *         and whose parameters are the free variables of that tuple. The
   *         type parameters and attrs are taken from the mutated function.
   */
  Expr Collect(const Expr& expr) {
    const Expr mutated = this->Mutate(expr);
    const FunctionNode* func = mutated.as<FunctionNode>();
    CHECK(func) << "Input should be Function";
    // The mutated body itself is dropped: only the recorded expressions are
    // kept as outputs of the profile graph.
    Expr new_body = TupleNode::make(std::move(profile_data_));
    return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
                              func->attrs);
  }

 private:
  /*! \brief Identity-mode simulated_quantize calls recorded while mutating. */
  Array<Expr> profile_data_;

  /*!
   * \brief Rewrite simulated_quantize calls to identity mode and record them.
   *
   * \param call The call node being visited.
   * \return The default-mutated call when the op is not simulated_quantize,
   *         otherwise the identity-mode replacement call.
   */
  Expr VisitExpr_(const CallNode* call) {
    static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
    // Mutate the arguments first so nested simulated_quantize calls are rewritten too.
    Expr new_e = ExprMutator::VisitExpr_(call);
    const CallNode* new_call = new_e.as<CallNode>();
    CHECK(new_call);
    if (!new_call->op.same_as(simulated_quantize)) {
      return new_e;
    }

    const auto* attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
    // as<>() yields nullptr on a type mismatch; fail loudly instead of
    // dereferencing a null pointer below.
    CHECK(attrs) << "simulated_quantize call must carry SimulatedQuantizeAttrs";

    // rewrite the annotation
    auto new_attrs = make_node<SimulatedQuantizeAttrs>();
    const Expr& quantize_input = new_call->args[0];  // expression being quantized
    auto placeholder = MakeConstantScalar(Float(32), 0.);  // unused argument
    Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
    new_attrs->kind = QAnnotateKind::kQIdentity;
    new_attrs->sign = attrs->sign;
    new_attrs->rounding = attrs->rounding;
    Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});

    // add non-const expressions to profile data
    if (attrs->kind != QAnnotateKind::kQWeight) {
      CHECK(!quantize_input.as<ConstantNode>());
      profile_data_.push_back(identity_quantize);
    }
    return identity_quantize;
  }
};

/*!
 * \brief Create the profile graph used to gather calibration statistics.
 *
 * Starting from an annotated graph, every simulated_quantize op is switched
 * to identity mode and gathered into a tuple, which becomes the output of the
 * profile graph. The pass takes a relay::Function and produces a
 * relay::Function.
 *
 * \param expr The simulation graph after annotation.
 * \return The profile graph.
 */
Expr CollectStats(const Expr& expr) {
  StatsCollector collector;
  return collector.Collect(expr);
}

// Register CollectStats under the API name "relay._quantize.CollectStats",
// with its typed signature used as the function body.
// NOTE(review): presumably looked up by name from the Python quantization
// frontend -- confirm against the caller.
TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);

}  // namespace quantize
}  // namespace relay
}  // namespace tvm