op.cc 4.68 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
/*!
 *  Copyright (c) 2018 by Contributors
 * \file src/tvm/relay/op.cc
 * \brief Resolve incomplete types to complete types.
 */
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <memory>
#include <mutex>

namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
}  // namespace dmlc

namespace tvm {
namespace relay {

::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
  return ::dmlc::Registry<OpRegistry>::Get();
}

// single manager of operator information.
struct OpManager {
  // mutex to avoid registration from multiple threads.
  std::mutex mutex;
  // global operator counter
  std::atomic<int> op_counter{0};
  // storage of additional attribute table.
  std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
  // frontend functions
  std::vector<PackedFunc*> frontend_funcs;
  // get singleton of the op manager
  static OpManager* Global() {
    static OpManager inst;
    return &inst;
  }
};

// find operator by name
const Op& Op::Get(const std::string& name) {
  const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
  CHECK(reg != nullptr) << "Operator " << name << " is not registered";
  return reg->op();
}

OpRegistry::OpRegistry() {
  OpManager* mgr = OpManager::Global();
52
  NodePtr<OpNode> n = make_node<OpNode>();
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
  n->index_ = mgr->op_counter++;
  op_ = Op(n);
}

// Get attribute map by key
const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
  OpManager* mgr = OpManager::Global();
  std::lock_guard<std::mutex> lock(mgr->mutex);
  auto it = mgr->attr.find(key);
  if (it == mgr->attr.end()) {
    LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
  }
  return *it->second.get();
}

68 69
void OpRegistry::UpdateAttr(const std::string& key,
                            TVMRetValue value,
70 71 72 73 74 75
                            int plevel) {
  OpManager* mgr = OpManager::Global();
  std::lock_guard<std::mutex> lock(mgr->mutex);
  std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
  if (op_map == nullptr) {
    op_map.reset(new GenericOpMap());
76
    op_map->attr_name_ = key;
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
  }
  uint32_t index = op_->index_;
  if (op_map->data_.size() <= index) {
    op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
  }
  std::pair<TVMRetValue, int>& p = op_map->data_[index];
  CHECK(p.second != plevel)
      << "Attribute " << key << " of operator " << this->name
      << " is already registered with same plevel=" << plevel;
  if (p.second < plevel) {
    op_map->data_[index] = std::make_pair(value, plevel);
  }
}

// Frontend APIs
TVM_REGISTER_API("relay.op._ListOpNames")
93 94 95 96 97 98 99 100
.set_body_typed<Array<tvm::Expr>()>([]() {
    Array<tvm::Expr> ret;
    for (const std::string& name :
             dmlc::Registry<OpRegistry>::ListAllNames()) {
      ret.push_back(tvm::Expr(name));
    }
    return ret;
  });
101 102 103 104 105 106 107 108 109 110 111 112 113 114

TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);

TVM_REGISTER_API("relay.op._OpGetAttr")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      Op op = args[0];
      std::string attr_name = args[1];
      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
      if (op_map.count(op)) {
        *rv = op_map[op];
      }
    });

TVM_REGISTER_API("relay.op._Register")
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
.set_body([](TVMArgs args, TVMRetValue* rv) {
    std::string op_name = args[0];
    std::string attr_key = args[1];
    runtime::TVMArgValue value = args[2];
    int plevel = args[3];
    auto& reg =
        OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
    // enable resgiteration and override of certain properties
    if (attr_key == "num_inputs" && plevel > 128) {
      reg.set_num_inputs(value);
    } else if (attr_key == "attrs_type_key" && plevel > 128) {
      reg.set_attrs_type_key(value);
    } else {
      // normal attr table override.
      if (args[2].type_code() == kFuncHandle) {
        // do an eager copy of the PackedFunc
        PackedFunc f = args[2];
        // If we get a function from frontend, avoid deleting it.
        OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
        reg.set_attr(attr_key, f, plevel);
135
      } else {
136
        reg.set_attr(attr_key, args[2], plevel);
137
      }
138 139
    }
  });
140

141
NodePtr<Node> CreateOp(const std::string& name) {
142
  auto op = Op::Get(name);
143
  CHECK(op.defined()) << "Cannot find op \'" << name << '\'';
144
  return op.node_;
145 146 147 148 149 150 151 152
}

TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp)
.set_global_key([](const Node* n) {
    return static_cast<const OpNode*>(n)->name;
  });

153 154 155 156 157
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<OpNode>([](const OpNode* node, tvm::IRPrinter* p) {
    p->stream << "Op(" << node->name << ")";
  });

158 159
}  // namespace relay
}  // namespace tvm