system_lib_module.cc 3.47 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24
/*!
 * \file system_lib_module.cc
 * \brief SystemLib module.
 */
#include <tvm/runtime/registry.h>
25
#include <tvm/runtime/memory.h>
26
#include <tvm/runtime/c_backend_api.h>
27
#include <mutex>
28
#include "module_util.h"
29 30 31 32 33 34

namespace tvm {
namespace runtime {

class SystemLibModuleNode : public ModuleNode {
 public:
35 36
  SystemLibModuleNode() = default;

37 38 39 40 41 42
  const char* type_key() const final {
    return "system_lib";
  }

  PackedFunc GetFunction(
      const std::string& name,
43
      const ObjectPtr<Object>& sptr_to_self) final {
44
    std::lock_guard<std::mutex> lock(mutex_);
45 46 47 48 49 50 51

    if (module_blob_ != nullptr) {
      // If we previously recorded submodules, load them now.
      ImportModuleBlob(reinterpret_cast<const char*>(module_blob_), &imports_);
      module_blob_ = nullptr;
    }

52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
    auto it = tbl_.find(name);
    if (it != tbl_.end()) {
      return WrapPackedFunc(
          reinterpret_cast<BackendPackedCFunc>(it->second), sptr_to_self);
    } else {
      return PackedFunc();
    }
  }

  void RegisterSymbol(const std::string& name, void* ptr) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (name == symbol::tvm_module_ctx) {
      void** ctx_addr = reinterpret_cast<void**>(ptr);
      *ctx_addr = this;
    } else if (name == symbol::tvm_dev_mblob) {
67 68 69 70 71 72 73 74
      // Record pointer to content of submodules to be loaded.
      // We defer loading submodules to the first call to GetFunction().
      // The reason is that RegisterSymbol() gets called when initializing the
      // syslib (i.e. library loading time), and the registeries aren't ready
      // yet. Therefore, we might not have the functionality to load submodules
      // now.
      CHECK(module_blob_ == nullptr) << "Resetting mobule blob?";
      module_blob_ = ptr;
75 76
    } else {
      auto it = tbl_.find(name);
nhynes committed
77 78 79 80
      if (it != tbl_.end() && ptr != it->second) {
        LOG(WARNING) << "SystemLib symbol " << name
                     << " get overriden to a different address "
                     << ptr << "->" << it->second;
81
      }
nhynes committed
82
      tbl_[name] = ptr;
83 84 85
    }
  }

86 87
  static const ObjectPtr<SystemLibModuleNode>& Global() {
    static auto inst = make_object<SystemLibModuleNode>();
88 89 90 91 92 93 94 95
    return inst;
  }

 private:
  // Internal mutex
  std::mutex mutex_;
  // Internal symbol table
  std::unordered_map<std::string, void*> tbl_;
96 97
  // Module blob to be imported
  void* module_blob_{nullptr};
98 99 100 101 102 103 104 105 106 107 108 109 110
};

// Registers the global function "module._GetSystemLib", which returns the
// shared SystemLibModuleNode instance wrapped in a runtime Module.
TVM_REGISTER_GLOBAL("module._GetSystemLib")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // The incoming arguments are not inspected.
    const ObjectPtr<SystemLibModuleNode>& syslib = SystemLibModuleNode::Global();
    *rv = Module(syslib);
  });
}  // namespace runtime
}  // namespace tvm

int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) {
  tvm::runtime::SystemLibModuleNode::Global()->RegisterSymbol(name, ptr);
  return 0;
}