/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file registry.cc
 * \brief The global registry of packed functions.
 */
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
#include <mutex>
#include <memory>
#include <array>
#include <utility>
#include "runtime_base.h"

namespace tvm {
namespace runtime {

36
struct Registry::Manager {
37 38 39 40 41
  // map storing the functions.
  // We delibrately used raw pointer
  // This is because PackedFunc can contain callbacks into the host languge(python)
  // and the resource can become invalid because of indeterminstic order of destruction.
  // The resources will only be recycled during program exit.
42
  std::unordered_map<std::string, Registry*> fmap;
43
  // mutex
44
  std::mutex mutex;
45

46 47 48
  Manager() {
  }

49
  static Manager* Global() {
50 51 52 53 54
    // We deliberately leak the Manager instance, to avoid leak sanitizers
    // complaining about the entries in Manager::fmap being leaked at program
    // exit.
    static Manager* inst = new Manager();
    return inst;
55 56 57
  }
};

58 59 60 61 62
/*!
 * \brief Set (or replace) the function stored in this registry entry.
 * \param f The function to store. It is already a by-value copy, so it is
 *          moved into place instead of being copied a second time.
 * \return A reference to this entry, so calls can be chained.
 */
Registry& Registry::set_body(PackedFunc f) {  // NOLINT(*)
  func_ = std::move(f);
  return *this;
}

/*!
 * \brief Find or create the registry entry for a global name.
 *
 * If \p name is unknown, a new entry is allocated and recorded. If it is
 * already registered, the call fails through CHECK unless \p override is
 * true. With override, the existing entry is returned so its body can be
 * replaced. The whole lookup/insert runs while holding the manager's mutex.
 */
Registry& Registry::Register(const std::string& name, bool override) {  // NOLINT(*)
  Manager* manager = Manager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  auto existing = manager->fmap.find(name);
  if (existing != manager->fmap.end()) {
    CHECK(override)
      << "Global PackedFunc " << name << " is already registered";
    return *existing->second;
  }
  // Allocated with new and never freed here; see the note on Registry::Manager.
  Registry* entry = new Registry();
  entry->name_ = name;
  manager->fmap[name] = entry;
  return *entry;
}

bool Registry::Remove(const std::string& name) {
  Manager* m = Manager::Global();
81
  std::lock_guard<std::mutex> lock(m->mutex);
82 83 84 85
  auto it = m->fmap.find(name);
  if (it == m->fmap.end()) return false;
  m->fmap.erase(it);
  return true;
86 87
}

88 89
/*!
 * \brief Look up a global function by name.
 * \param name The global name to look up.
 * \return A pointer to the registered function, or nullptr if \p name is
 *         not registered.
 *
 * The mutex is released when this returns, so the caller uses the returned
 * pointer without holding the registry lock.
 */
const PackedFunc* Registry::Get(const std::string& name) {
  Manager* manager = Manager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  auto found = manager->fmap.find(name);
  return found == manager->fmap.end() ? nullptr : &found->second->func_;
}

std::vector<std::string> Registry::ListNames() {
  Manager* m = Manager::Global();
98
  std::lock_guard<std::mutex> lock(m->mutex);
99
  std::vector<std::string> keys;
100 101
  keys.reserve(m->fmap.size());
  for (const auto &kv : m->fmap) {
102 103 104 105 106 107 108 109
    keys.push_back(kv.first);
  }
  return keys;
}

}  // namespace runtime
}  // namespace tvm

/*! \brief Scratch space used to hold values returned through the C API. */
struct TVMFuncThreadLocalEntry {
  /*! \brief Owns the strings whose character pointers are handed to callers. */
  std::vector<std::string> ret_vec_str;
  /*! \brief C string pointers returned to callers as a C array. */
  std::vector<const char *> ret_vec_charp;
};

/*! \brief Thread-local store holding a TVMFuncThreadLocalEntry. */
using TVMFuncThreadLocalStore = dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry>;

int TVMFuncRegisterGlobal(
    const char* name, TVMFunctionHandle f, int override) {
123
  API_BEGIN();
124
  tvm::runtime::Registry::Register(name, override != 0)
125
      .set_body(*static_cast<tvm::runtime::PackedFunc*>(f));
126 127 128 129 130
  API_END();
}

int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
  API_BEGIN();
131 132
  const tvm::runtime::PackedFunc* fp =
      tvm::runtime::Registry::Get(name);
133 134 135 136 137
  if (fp != nullptr) {
    *out = new tvm::runtime::PackedFunc(*fp);  // NOLINT(*)
  } else {
    *out = nullptr;
  }
138 139 140 141 142 143 144
  API_END();
}

int TVMFuncListGlobalNames(int *out_size,
                           const char*** out_array) {
  API_BEGIN();
  TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get();
145
  ret->ret_vec_str = tvm::runtime::Registry::ListNames();
146 147 148 149 150 151
  ret->ret_vec_charp.clear();
  for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
    ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
  }
  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
  *out_size = static_cast<int>(ret->ret_vec_str.size());
152 153
  API_END();
}