op.cc 7.21 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/ir/op.cc
 * \brief Primitive operators and intrinsics.
 */
24 25
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
26 27 28 29 30 31 32 33
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace dmlc {
// enable registry
34
DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
35 36 37 38
}  // namespace dmlc

namespace tvm {

39 40 41 42
using runtime::TVMRetValue;
using runtime::TVMArgs;
using runtime::PackedFunc;

43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
/*!
 * \brief Access the global dmlc registry of operators.
 * \return Pointer to the singleton ::dmlc::Registry<OpRegistry>.
 */
::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
  auto* registry = ::dmlc::Registry<OpRegistry>::Get();
  return registry;
}

// Process-wide bookkeeping shared by every OpRegistry entry.
struct OpManager {
  // Serializes reads and writes of the attribute tables in `attr`.
  std::mutex mutex;
  // Incremented once per constructed OpRegistry; yields each op's index.
  std::atomic<int> op_counter{0};
  // Attribute tables, keyed by attribute name.
  std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
  // Heap copies of PackedFuncs handed in by the frontend; never deleted here.
  std::vector<PackedFunc*> frontend_funcs;
  // Returns the singleton. It is heap-allocated on first use and never
  // deleted, so it is never destroyed.
  static OpManager* Global() {
    static OpManager* const instance = new OpManager();
    return instance;
  }
};

/*!
 * \brief Look up a registered operator by name.
 * \param name The operator name.
 * \return Reference to the Op held by its registry entry.
 * \note Aborts via CHECK if no operator with this name is registered.
 */
const Op& Op::Get(const std::string& name) {
  const OpRegistry* const entry = ::dmlc::Registry<OpRegistry>::Find(name);
  CHECK(entry != nullptr) << "Operator " << name << " is not registered";
  return entry->op();
}

// Create a fresh OpNode and give it the next index from the global counter.
OpRegistry::OpRegistry() {
  ObjectPtr<OpNode> node = make_object<OpNode>();
  node->index_ = OpManager::Global()->op_counter++;
  op_ = Op(node);
}

/*!
 * \brief Get the generic attribute table registered under \p key.
 * \param key The attribute name.
 * \return Reference to the (never freed) attribute table.
 * \note LOG(FATAL)s if no table is registered for \p key.
 */
const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
  OpManager* mgr = OpManager::Global();
  std::lock_guard<std::mutex> lock(mgr->mutex);
  auto it = mgr->attr.find(key);
  // A key may map to a null table (an entry created by a lookup through
  // unordered_map::operator[]); treat it exactly like a missing key instead
  // of dereferencing a null pointer below.
  if (it == mgr->attr.end() || it->second == nullptr) {
    LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
  }
  return *it->second;
}

89 90 91 92 93 94 95 96 97 98 99
/*!
 * \brief Check whether an attribute table is registered under \p key.
 * \param key The attribute name.
 * \return true iff a non-null table exists, i.e. iff GetGenericAttr(key)
 *         would succeed.
 */
const bool Op::HasGenericAttr(const std::string& key) {
  OpManager* mgr = OpManager::Global();
  std::lock_guard<std::mutex> lock(mgr->mutex);
  auto it = mgr->attr.find(key);
  // A null table (inserted by a lookup via operator[]) is not a registration;
  // reporting it as present would send callers into a null dereference.
  return it != mgr->attr.end() && it->second != nullptr;
}

100 101 102 103 104 105 106 107 108 109 110 111 112 113
/*!
 * \brief Clear this operator's value for attribute \p key.
 *
 * The slot is restored to an empty value with plevel 0. Does nothing if the
 * attribute table does not exist or has no slot for this operator.
 * \param key The attribute name.
 */
void OpRegistry::reset_attr(const std::string& key) {
  OpManager* mgr = OpManager::Global();
  std::lock_guard<std::mutex> lock(mgr->mutex);
  // Use find() rather than operator[]: operator[] would insert a null table
  // for a never-registered key, which HasGenericAttr would then report as
  // present and GetGenericAttr would dereference.
  auto it = mgr->attr.find(key);
  if (it == mgr->attr.end() || it->second == nullptr) {
    return;
  }
  GenericOpMap* op_map = it->second.get();
  const uint32_t index = op_->index_;
  if (op_map->data_.size() > index) {
    op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
  }
}

114 115
void OpRegistry::UpdateAttr(const std::string& key,
                            TVMRetValue value,
116 117 118 119 120 121
                            int plevel) {
  OpManager* mgr = OpManager::Global();
  std::lock_guard<std::mutex> lock(mgr->mutex);
  std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
  if (op_map == nullptr) {
    op_map.reset(new GenericOpMap());
122
    op_map->attr_name_ = key;
123 124 125 126 127 128 129 130 131
  }
  uint32_t index = op_->index_;
  if (op_map->data_.size() <= index) {
    op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
  }
  std::pair<TVMRetValue, int>& p = op_map->data_[index];
  CHECK(p.second != plevel)
      << "Attribute " << key << " of operator " << this->name
      << " is already registered with same plevel=" << plevel;
132
  CHECK(value.type_code() != kTVMNullptr)
133 134
      << "Registered packed_func is Null for " << key
      << " of operator " << this->name;
135
  if (p.second < plevel && value.type_code() != kTVMNullptr) {
136 137 138 139 140
    op_map->data_[index] = std::make_pair(value, plevel);
  }
}

// Frontend APIs
141
TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
.set_body_typed([]() {
    // Collect the name of every registered operator into an Array.
    const auto all_names = dmlc::Registry<OpRegistry>::ListAllNames();
    Array<tvm::PrimExpr> result;
    for (size_t i = 0; i < all_names.size(); ++i) {
      result.push_back(tvm::PrimExpr(all_names[i]));
    }
    return result;
  });
150

151
TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get);
152

153
// Frontend getter: args = (op, attr_name). Leaves *rv unset when the operator
// has no value for the attribute.
TVM_REGISTER_GLOBAL("relay.op._OpGetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Op op = args[0];
    std::string attr_name = args[1];
    auto attr_map = Op::GetAttr<TVMRetValue>(attr_name);
    if (!attr_map.count(op)) {
      return;
    }
    *rv = attr_map[op];
  });

163
// Frontend setter: args = (op, attr_name, value, plevel).
TVM_REGISTER_GLOBAL("relay.op._OpSetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Op op = args[0];
    std::string attr_name = args[1];
    runtime::TVMArgValue value = args[2];
    int plevel = args[3];
    OpRegistry::Registry()
        ->__REGISTER_OR_GET__(op->name)
        .set_name()
        .set_attr(attr_name, value, plevel);
  });
173

174
// Frontend reset: args = (op, attr_name); clears the op's value for that attribute.
TVM_REGISTER_GLOBAL("relay.op._OpResetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Op op = args[0];
    std::string attr_name = args[1];
    OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).reset_attr(attr_name);
  });

183
// Frontend registration: args = (op_name, attr_key, value, plevel).
// With plevel > 128, "num_inputs" overrides the operator's input count and
// "attrs_type_key" is rejected; every other key goes to the attribute table.
TVM_REGISTER_GLOBAL("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    std::string op_name = args[0];
    std::string attr_key = args[1];
    runtime::TVMArgValue value = args[2];
    int plevel = args[3];
    auto& reg =
        OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
    // enable registration and override of certain properties
    if (attr_key == "num_inputs" && plevel > 128) {
      reg.set_num_inputs(value);
    } else if (attr_key == "attrs_type_key" && plevel > 128) {
      LOG(FATAL) << "attrs type key no longer supported";
    } else {
      // normal attr table override.
      if (value.type_code() == kTVMPackedFuncHandle) {
        // do an eager copy of the PackedFunc
        PackedFunc f = value;
        // If we get a function from frontend, avoid deleting it.
        // frontend_funcs is shared OpManager state, so guard it with the
        // manager mutex. The lock is scoped to the push only: set_attr below
        // reaches UpdateAttr, which takes the same non-recursive mutex.
        {
          OpManager* mgr = OpManager::Global();
          std::lock_guard<std::mutex> lock(mgr->mutex);
          mgr->frontend_funcs.push_back(new PackedFunc(f));
        }
        reg.set_attr(attr_key, f, plevel);
      } else {
        reg.set_attr(attr_key, value, plevel);
      }
    }
  });
209

210
// helper to get internal dev function in objectref.
211 212 213
struct Op2ObjectPtr : public ObjectRef {
  static ObjectPtr<Object> Get(const Op& op) {
    return GetDataPtr<Object>(op);
214 215 216
  }
};

217
// Node creator for OpNode: instead of allocating a new node, return the
// object pointer of the already-registered Op with this name.
ObjectPtr<Object> CreateOp(const std::string& name) {
  const Op& op = Op::Get(name);
  CHECK(op.defined()) << "Cannot find op \'" << name << '\'';
  return Op2ObjectPtr::Get(op);
}

// Register OpNode with the reflection system: it is created via CreateOp and
// identified globally by its operator name.
TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp)
.set_global_key([](const Object* n) {
    const auto* op_node = static_cast<const OpNode*>(n);
    return op_node->name;
  });

230 231
// Print an operator as "Op(<name>)".
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, NodePrinter* p) {
    const OpNode* op_node = static_cast<const OpNode*>(ref.get());
    p->stream << "Op(" << op_node->name << ")";
  });

236
}  // namespace tvm