op.cc 3.79 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

tqchen committed
20 21 22 23 24
/*!
 *  Copyright (c) 2016 by Contributors
 * \file op.cc
 * \brief Support for operator registry.
 */
tqchen committed
25 26
#include <nnvm/base.h>
#include <nnvm/op.h>
tqchen committed
27

28
#include <memory>
tqchen committed
29 30
#include <atomic>
#include <mutex>
31
#include <unordered_set>
tqchen committed
32 33 34

namespace dmlc {
// enable registry
tqchen committed
35
DMLC_REGISTRY_ENABLE(nnvm::Op);
tqchen committed
36 37
}  // namespace dmlc

tqchen committed
38
namespace nnvm {
tqchen committed
39 40 41 42

// single manager of operator information.
struct OpManager {
  // mutex to avoid registration from multiple threads.
43 44
  // recursive is needed for trigger(which calls UpdateAttrMap)
  std::recursive_mutex mutex;
tqchen committed
45 46 47
  // global operator counter
  std::atomic<int> op_counter{0};
  // storage of additional attribute table.
48
  std::unordered_map<std::string, std::unique_ptr<any> > attr;
49 50 51 52
  // storage of existing triggers
  std::unordered_map<std::string, std::vector<std::function<void(Op*)>  > > tmap;
  // group of each operator.
  std::vector<std::unordered_set<std::string> > op_group;
tqchen committed
53 54 55 56 57 58 59 60 61 62 63 64 65
  // get singleton of the
  static OpManager* Global() {
    static OpManager inst;
    return &inst;
  }
};

/*!
 * \brief Construct an operator and give it the next value of the
 *  global operator counter as its index.
 */
Op::Op() {
  // fetch_add returns the value held before the increment.
  index_ = OpManager::Global()->op_counter.fetch_add(1);
}

66 67 68 69 70
/*!
 * \brief Register an additional name for this operator in the registry.
 * \param alias The alias to add for this operator's name.
 * \return Reference to this operator, so calls can be chained.
 */
Op& Op::add_alias(const std::string& alias) {  // NOLINT(*)
  auto* registry = dmlc::Registry<Op>::Get();
  registry->AddAlias(name, alias);
  return *this;
}

tqchen committed
71 72 73 74 75 76 77 78 79
/*!
 * \brief Find a registered operator by name.
 * \param name Name (as known to the registry) of the operator to look up.
 * \return Pointer to the operator found in the registry.
 * \note When the registry has no such entry the CHECK below fails with
 *  "Operator <name> is not registered" (assumes dmlc's CHECK is fatal,
 *  so a null pointer is not expected to reach the caller).
 */
const Op* Op::Get(const std::string& name) {
  const Op* op = dmlc::Registry<Op>::Find(name);
  CHECK(op != nullptr)
      << "Operator " << name << " is not registered";
  return op;
}

// Get attribute map by key
80 81
const any* Op::GetAttrMap(const std::string& key) {
  auto& dict =  OpManager::Global()->attr;
tqchen committed
82
  auto it = dict.find(key);
83 84 85 86 87
  if (it != dict.end()) {
    return it->second.get();
  } else {
    return nullptr;
  }
tqchen committed
88 89
}

90
// update attribute map
tqchen committed
91 92 93
void Op::UpdateAttrMap(const std::string& key,
                       std::function<void(any*)> updater) {
  OpManager* mgr = OpManager::Global();
94
  std::lock_guard<std::recursive_mutex>(mgr->mutex);
95 96 97
  std::unique_ptr<any>& value = mgr->attr[key];
  if (value.get() == nullptr) value.reset(new any());
  if (updater != nullptr) updater(value.get());
tqchen committed
98 99
}

100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
/*!
 * \brief Register a trigger for a group and run it on every registered
 *  operator that is already recorded as a member of that group.
 * \param group_name Name of the group the trigger belongs to.
 * \param trigger Callback invoked with each matching operator; it is
 *  also stored so that Op::include can run it later.
 */
void Op::AddGroupTrigger(const std::string& group_name,
                         std::function<void(Op*)> trigger) {
  OpManager* mgr = OpManager::Global();
  // The guard must be a named object: an unnamed temporary is destroyed
  // at the end of its own statement and would release the mutex at once.
  // The mutex is recursive, so a trigger may call Op::UpdateAttrMap.
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
  mgr->tmap[group_name].push_back(trigger);
  auto& op_group = mgr->op_group;
  for (const Op* op : dmlc::Registry<Op>::List()) {
    // Operators that never called include() may have no slot in op_group.
    if (op->index_ < op_group.size() &&
        op_group[op->index_].count(group_name) != 0) {
      // The registry lists operators as const; triggers take a mutable Op.
      trigger(const_cast<Op*>(op));
    }
  }
}

/*!
 * \brief Add this operator to a group: run every trigger already
 *  registered for the group on it, then record the membership.
 * \param group_name Name of the group to join.
 * \return Reference to this operator, so calls can be chained.
 */
Op& Op::include(const std::string& group_name) {
  OpManager* mgr = OpManager::Global();
  // The guard must be a named object: an unnamed temporary is destroyed
  // at the end of its own statement and would release the mutex at once.
  // The mutex is recursive, so a trigger may call Op::UpdateAttrMap.
  std::lock_guard<std::recursive_mutex> lock(mgr->mutex);
  auto it = mgr->tmap.find(group_name);
  if (it != mgr->tmap.end()) {
    for (auto& trigger : it->second) {
      trigger(this);
    }
  }
  auto& op_group = mgr->op_group;
  // Grow the table on demand so that this operator's index is valid.
  if (index_ >= op_group.size()) {
    op_group.resize(index_ + 1);
  }
  op_group[index_].insert(group_name);
  return *this;
}

tqchen committed
132
}  // namespace nnvm