/*!
 *  Copyright (c) 2016 by Contributors
 * \file op.cc
 * \brief Support for operator registry.
 */
#include <nnvm/base.h>
#include <nnvm/op.h>

#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace dmlc {
// enable registry
tqchen committed
16
DMLC_REGISTRY_ENABLE(nnvm::Op);
tqchen committed
17 18
}  // namespace dmlc
namespace nnvm {
// single manager of operator information.
struct OpManager {
  // mutex to avoid registration from multiple threads.
24 25
  // recursive is needed for trigger(which calls UpdateAttrMap)
  std::recursive_mutex mutex;
tqchen committed
26 27 28
  // global operator counter
  std::atomic<int> op_counter{0};
  // storage of additional attribute table.
29
  std::unordered_map<std::string, std::unique_ptr<any> > attr;
30 31 32 33
  // storage of existing triggers
  std::unordered_map<std::string, std::vector<std::function<void(Op*)>  > > tmap;
  // group of each operator.
  std::vector<std::unordered_set<std::string> > op_group;
tqchen committed
34 35 36 37 38 39 40 41 42 43 44 45 46
  // get singleton of the
  static OpManager* Global() {
    static OpManager inst;
    return &inst;
  }
};

// Constructor: each Op receives the current value of the global operator
// counter as its index, and the counter is advanced atomically.
Op::Op() {
  index_ = OpManager::Global()->op_counter.fetch_add(1);
}
// Register `alias` as another name for this operator in the dmlc operator
// registry. Returns *this so registration calls can be chained.
Op& Op::add_alias(const std::string& alias) {  // NOLINT(*)
  auto registry = dmlc::Registry<Op>::Get();
  registry->AddAlias(this->name, alias);
  return *this;
}
// Look up a registered operator by name.
// An unknown name fails the CHECK below instead of returning nullptr.
const Op* Op::Get(const std::string& name) {
  const Op* registered = dmlc::Registry<Op>::Find(name);
  CHECK(registered != nullptr) << "Operator " << name << " is not registered";
  return registered;
}

// Get attribute map by key
61 62
const any* Op::GetAttrMap(const std::string& key) {
  auto& dict =  OpManager::Global()->attr;
tqchen committed
63
  auto it = dict.find(key);
64 65 66 67 68
  if (it != dict.end()) {
    return it->second.get();
  } else {
    return nullptr;
  }
tqchen committed
69 70
}
// update attribute map
tqchen committed
72 73 74
void Op::UpdateAttrMap(const std::string& key,
                       std::function<void(any*)> updater) {
  OpManager* mgr = OpManager::Global();
75
  std::lock_guard<std::recursive_mutex>(mgr->mutex);
76 77 78
  std::unique_ptr<any>& value = mgr->attr[key];
  if (value.get() == nullptr) value.reset(new any());
  if (updater != nullptr) updater(value.get());
tqchen committed
79 80
}
// Register `trigger` for every operator belonging to group `group_name`.
//
// The trigger is stored so that operators which include() the group later
// also receive it. It is also applied immediately to every already-registered
// operator whose recorded groups contain `group_name`. The lock is recursive,
// so a trigger may call back into functions such as UpdateAttrMap that take
// the same mutex.
void Op::AddGroupTrigger(const std::string& group_name,
                         std::function<void(Op*)> trigger) {
  OpManager* mgr = OpManager::Global();
  // Named guard: keeps the mutex held for the whole function. An unnamed
  // temporary would unlock it at the end of its own statement.
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
  auto& tvec = mgr->tmap[group_name];
  tvec.push_back(trigger);
  auto& op_group = mgr->op_group;
  for (const Op* op : dmlc::Registry<Op>::List()) {
    // op_group only grows when an operator include()s a group, so operators
    // that never joined any group may have an index past its end.
    if (op->index_ < op_group.size() &&
        op_group[op->index_].count(group_name) != 0) {
      // The registry hands out const pointers, but triggers mutate the Op.
      trigger(const_cast<Op*>(op));
    }
  }
}

// Add this operator to group `group_name`; returns *this for chaining.
//
// Every trigger already registered for the group (via AddGroupTrigger) is run
// on this operator. The group name is then recorded in op_group[index_], so
// triggers registered later are applied too. The whole update runs while
// holding OpManager's recursive mutex.
Op& Op::include(const std::string& group_name) {
  OpManager* mgr = OpManager::Global();
  // Named guard: an unnamed temporary would release the mutex at the end of
  // its own statement, leaving the trigger calls and op_group update unlocked.
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
  auto it = mgr->tmap.find(group_name);
  if (it != mgr->tmap.end()) {
    for (auto& trigger : it->second) {
      trigger(this);
    }
  }
  auto& op_group = mgr->op_group;
  // Grow on demand: only operators that join a group get an entry.
  if (index_ >= op_group.size()) {
    op_group.resize(index_ + 1);
  }
  op_group[index_].insert(group_name);
  return *this;
}
}  // namespace nnvm