mac_count.cc 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31 32
/*!
 * Copyright (c) 2019 by Contributors
 *
 * \file mac_count.cc
 * \brief Pass to roughly count the number of MACs (Multiply-Accumulate) 
 * operations of a model. Only MACs in CONV and Dense ops are counted.
 * This pass is valid after the type infer pass is called,
 * otherwise the count is 0.
 */

#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
33
#include <tvm/data_layout.h>
34 35 36 37

namespace tvm {
namespace relay {

38
namespace mac_count {
39

40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
inline int64_t GetCartesianProd(Array<IndexExpr> arr) {
  int64_t ret = 1;
  for (size_t i = 0; i < arr.size(); i++) {
    const auto* intImm = arr[i].as<IntImm>();
    ret *= static_cast<int64_t>(intImm->value);
  }
  return ret;
}

/*!
 * \brief Signature of a per-operator MAC counting function. Functions of this
 *        type are registered on operators under the "FMacCount" attribute.
 * \param call_node The call node.
 * \return The number of MACs.
 */
using FMacCount = runtime::TypedPackedFunc<
  int64_t(const Call& call_node)>;

//----------------------------------------------
// Per operator defs for MAC count
//----------------------------------------------

int64_t ConvMacCount(const Call& call_node) {
  if (!call_node->checked_type_.defined()) {
    LOG(WARNING) << "The infer type pass should be called before the mac count pass";
    return 0;
  }
  Array<Expr> args = call_node->args;
  CHECK(args.size() == 2)
      << "The number of input arguments of a CONV 2D node should be 2.";
  const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
  const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
  Array<IndexExpr> data_shape = data_type->shape;
  std::string data_layout = conv_2d_attr->data_layout;
73 74
  int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
  int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
75 76 77 78 79 80 81 82 83 84 85 86 87 88
  CHECK(C_ind != -1)
      << "There is no input channel dimension.";
  int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
  if (c_ind != -1)
    input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
  Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
  CHECK(kernel_size.size() == 2)
      << "The dimension of the kernel size in Conv 2D should be 2.";
  const auto* expr = call_node->checked_type().as<TensorTypeNode>();
  Array<IndexExpr> output_tensor = expr->shape;
  CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
      << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
  int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
  return count;
89 90
}

91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
/*!
 * \brief Count the MACs performed by a dense call.
 *
 * With data of shape [d1, d2] and weight of shape [d3, d4], where d2 must
 * equal d4, the count is d1 * d2 * d3.
 *
 * \param call_node The dense call node.
 * \return The number of MACs, or 0 (with a warning) when the call has no
 *         checked type, i.e. type inference has not been run.
 */
int64_t DenseMacCount(const Call& call_node) {
  if (!call_node->checked_type_.defined()) {
    LOG(WARNING) << "The infer type pass should be called before the mac count pass";
    return 0;
  }
  Array<Expr> args = call_node->args;
  CHECK(args.size() == 2)
      << "The number of input arguments of a Dense node should be 2.";
  const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
  const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
  CHECK(data_type != nullptr && weight_type != nullptr)
      << "The input arguments of a Dense node should be tensors.";
  Array<IndexExpr> data_shape = data_type->shape;
  Array<IndexExpr> weight_shape = weight_type->shape;
  CHECK(data_shape.size() == 2 && weight_shape.size() == 2)
      << "The dimension of an input tensor to Dense node should be 2.";
  // as<IntImm>() yields nullptr for symbolic dimensions; verify all four
  // extents are constants before reading their values.
  const auto* data_rows = data_shape[0].as<IntImm>();
  const auto* data_cols = data_shape[1].as<IntImm>();
  const auto* weight_rows = weight_shape[0].as<IntImm>();
  const auto* weight_cols = weight_shape[1].as<IntImm>();
  CHECK(data_rows != nullptr && data_cols != nullptr &&
        weight_rows != nullptr && weight_cols != nullptr)
      << "The shapes of the input arguments of a Dense node should be constant integers.";
  int64_t d1 = static_cast<int64_t>(data_rows->value);
  int64_t d2 = static_cast<int64_t>(data_cols->value);
  int64_t d3 = static_cast<int64_t>(weight_rows->value);
  int64_t d4 = static_cast<int64_t>(weight_cols->value);
  CHECK(d2 == d4)
      << "The dimensions of input arguments do not match.";
  int64_t count = d1 * d2 * d3;
  return count;
}

115 116 117 118 119
// Register ConvMacCount as the "FMacCount" attribute of the nn.conv2d op.
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FMacCount>("FMacCount", ConvMacCount);

// Register DenseMacCount as the "FMacCount" attribute of the nn.dense op.
RELAY_REGISTER_OP("nn.dense")
.set_attr<FMacCount>("FMacCount", DenseMacCount);
120 121 122 123 124 125 126 127 128 129 130 131 132 133 134

class MacCounter : private ExprVisitor {
 public:
  MacCounter() {
    count_ = 0;
  }
  static int64_t GetTotalMacNumber(const Expr& expr) {
    LOG(INFO) << "This pass only counts MACs in direct CONV 2D and Dense ops";
    MacCounter counter;
    counter(expr);
    return counter.count_;
  }

 private:
  void VisitExpr_(const CallNode* call_node) final {
135 136 137 138
    static const auto& fprep =
        Op::GetAttr<FMacCount>("FMacCount");
    auto f = fprep.get(call_node->op, nullptr);
    if (f != nullptr) count_ += f(GetRef<Call>(call_node));
139 140 141 142 143 144 145 146 147 148 149
    ExprVisitor::VisitExpr_(call_node);
  }

  int64_t count_;
};

/*!
 * \brief Entry point of the pass: total MAC count of an expression.
 * \param expr The expression to analyze.
 * \return The number of MACs reported by MacCounter for the expression.
 */
int64_t GetTotalMacNumber(const Expr& expr) {
  const int64_t total_macs = MacCounter::GetTotalMacNumber(expr);
  return total_macs;
}

// Expose the pass through the TVM API registry.
TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
.set_body_typed(GetTotalMacNumber);
151

152
}  // namespace mac_count
153 154
}  // namespace relay
}  // namespace tvm