registry.cc 4.7 KB
Newer Older
1 2
/*!
 *  Copyright (c) 2017 by Contributors
 * \file registry.cc
 * \brief The global registry of packed functions.
 */
#include <dmlc/logging.h>
7
#include <dmlc/thread_local.h>
8
#include <tvm/runtime/registry.h>
9
#include <unordered_map>
10
#include <mutex>
11
#include <memory>
12
#include <array>
13
#include "runtime_base.h"
14 15 16 17

namespace tvm {
namespace runtime {

18
/*!
 * \brief Process-wide storage behind Registry and ExtTypeVTable.
 *
 * Holds the name -> Registry map, the table of extension-type vtables and
 * the mutex that the registry operations in this file lock before touching
 * them.
 */
struct Registry::Manager {
  // Map storing the functions.
  // We deliberately use raw pointers here.
  // This is because a PackedFunc can contain callbacks into the host
  // language (python), and the resource can become invalid because of the
  // indeterministic order of destruction.
  // The resources will only be recycled during program exit.
  std::unordered_map<std::string, Registry*> fmap;
  // vtable for extension types, indexed by type code.
  // A slot whose destroy is nullptr is treated as "not registered"
  // by ExtTypeVTable::Get.
  std::array<ExtTypeVTable, kExtEnd> ext_vtable;
  // mutex
  std::mutex mutex;

  /*! \brief Start with every extension slot marked as unregistered. */
  Manager() {
    for (auto& x : ext_vtable) {
      x.destroy = nullptr;
    }
  }

  /*!
   * \brief Access the single Manager instance, creating it on first use.
   * \return Pointer to the global Manager; it is never deleted.
   */
  static Manager* Global() {
    // We deliberately leak the Manager instance, to avoid leak sanitizers
    // complaining about the entries in Manager::fmap being leaked at program
    // exit.
    static Manager* inst = new Manager();
    return inst;
  }
};

45 46 47 48 49
/*!
 * \brief Assign the function stored in this registry entry.
 * \param f The packed function to store (copied into the entry).
 * \return Reference to this entry, so calls can be chained.
 */
Registry& Registry::set_body(PackedFunc f) {  // NOLINT(*)
  this->func_ = f;
  return *this;
}

50
/*!
 * \brief Look up or create the registry entry for a global name.
 *
 * Runs under the Manager mutex. A new entry is heap-allocated and stored in
 * the global map; it is not freed by this file (see the note on
 * Manager::fmap).
 *
 * \param name The global function name.
 * \param override Whether an existing entry with the same name may be reused.
 *        If the name is already registered and this is false, the CHECK fails.
 * \return Reference to the entry registered under \p name.
 */
Registry& Registry::Register(const std::string& name, bool override) {  // NOLINT(*)
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  auto it = m->fmap.find(name);
  if (it != m->fmap.end()) {
    // Already present: only legal when the caller asked to override.
    CHECK(override)
      << "Global PackedFunc " << name << " is already registered";
    return *it->second;
  }
  Registry* r = new Registry();
  r->name_ = name;
  // The key is known to be absent, so emplace inserts without a second
  // default-construct-then-assign step.
  m->fmap.emplace(name, r);
  return *r;
}

66 67
/*!
 * \brief Remove the entry registered under a global name.
 *
 * Only the map entry is erased; the Registry object it pointed to is not
 * deleted here.
 * NOTE(review): presumably intentional, since pointers handed out by
 * Registry::Get may still be in use -- confirm.
 *
 * \param name The global function name.
 * \return true if an entry existed and was removed, false otherwise.
 */
bool Registry::Remove(const std::string& name) {
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  // erase(key) reports how many elements were removed (0 or 1 for a map).
  return m->fmap.erase(name) != 0;
}

75 76
/*!
 * \brief Find the function registered under a global name.
 *
 * The lookup runs under the Manager mutex; the returned pointer refers to
 * the func_ member of the stored Registry entry and is used by the caller
 * after the lock has been released.
 *
 * \param name The global function name.
 * \return Pointer to the registered PackedFunc, or nullptr if \p name is
 *         not registered.
 */
const PackedFunc* Registry::Get(const std::string& name) {
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  const auto it = m->fmap.find(name);
  if (it == m->fmap.end()) {
    return nullptr;
  }
  return &(it->second->func_);
}

83 84
/*!
 * \brief Collect the names of all currently registered global functions.
 *
 * The snapshot is taken under the Manager mutex. The order follows the
 * iteration order of the underlying unordered_map, so it is unspecified.
 *
 * \return A vector holding a copy of every registered name.
 */
std::vector<std::string> Registry::ListNames() {
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  std::vector<std::string> keys;
  keys.reserve(m->fmap.size());
  for (const auto& kv : m->fmap) {
    keys.push_back(kv.first);
  }
  return keys;
}

94 95 96 97 98 99 100 101 102 103 104 105 106
/*!
 * \brief Fetch the vtable registered for an extension type code.
 *
 * The type code must lie strictly between kExtBegin and kExtEnd, and the
 * slot must have a non-null destroy function; either CHECK fails otherwise.
 * Unlike RegisterInternal, this read does not take the Manager mutex.
 *
 * \param type_code The extension type code.
 * \return Pointer to the vtable slot inside the global Manager.
 */
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd);
  ExtTypeVTable* vt = &(Registry::Manager::Global()->ext_vtable[type_code]);
  // A null destroy is how the Manager constructor marks an empty slot.
  CHECK(vt->destroy != nullptr)
      << "Extension type not registered";
  return vt;
}

/*!
 * \brief Install the vtable for an extension type code.
 *
 * The type code must lie strictly between kExtBegin and kExtEnd, otherwise
 * the CHECK fails. The slot is overwritten under the Manager mutex; an
 * existing registration for the same code is replaced without complaint.
 *
 * \param type_code The extension type code.
 * \param vt The vtable to copy into the global table.
 * \return Pointer to the stored vtable slot inside the global Manager.
 */
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
    int type_code, const ExtTypeVTable& vt) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd);
  Registry::Manager* m = Registry::Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);
  *pvt = vt;
  return pvt;
}
112 113 114
}  // namespace runtime
}  // namespace tvm

115 116 117 118 119 120 121 122 123 124 125
/*!
 * \brief Scratch storage used to hand string results back through the C API.
 *
 * TVMFuncListGlobalNames fills ret_vec_str with the names and ret_vec_charp
 * with the matching c_str() pointers.
 */
struct TVMFuncThreadLocalEntry {
  /*! \brief owns the strings being returned */
  std::vector<std::string> ret_vec_str;
  /*! \brief raw C-string pointers being returned */
  std::vector<const char*> ret_vec_charp;
};

/*! \brief Per-thread store holding the return-value scratch entry. */
using TVMFuncThreadLocalStore = dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry>;

126 127 128 129 130 131
int TVMExtTypeFree(void* handle, int type_code) {
  API_BEGIN();
  tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
  API_END();
}

132 133
int TVMFuncRegisterGlobal(
    const char* name, TVMFunctionHandle f, int override) {
134
  API_BEGIN();
135
  tvm::runtime::Registry::Register(name, override != 0)
136
      .set_body(*static_cast<tvm::runtime::PackedFunc*>(f));
137 138 139 140 141
  API_END();
}

int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
  API_BEGIN();
142 143
  const tvm::runtime::PackedFunc* fp =
      tvm::runtime::Registry::Get(name);
144 145 146 147 148
  if (fp != nullptr) {
    *out = new tvm::runtime::PackedFunc(*fp);  // NOLINT(*)
  } else {
    *out = nullptr;
  }
149 150 151 152 153 154 155
  API_END();
}

int TVMFuncListGlobalNames(int *out_size,
                           const char*** out_array) {
  API_BEGIN();
  TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get();
156
  ret->ret_vec_str = tvm::runtime::Registry::ListNames();
157 158 159 160 161 162
  ret->ret_vec_charp.clear();
  for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
    ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
  }
  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
  *out_size = static_cast<int>(ret->ret_vec_str.size());
163 164
  API_END();
}