/*!
 *  Copyright (c) 2017 by Contributors
 * \file module.cc
 * \brief TVM module system
 */
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <cstring>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#ifndef _LIBCPP_SGX_CONFIG
#include "file_util.h"
#endif

namespace tvm {
namespace runtime {

// Append `other` to this module's import list.
//
// RPC modules are special-cased: the import is forwarded to the globally
// registered "rpc._ImportRemoteModule" function and the local import list is
// left untouched. For every other module the import fails with a CHECK if
// `other` is undefined, or if it would create a cycle (i.e. this module is
// already reachable from `other` through imports, or is `other` itself).
void Module::Import(Module other) {
  // specially handle rpc
  if (!std::strcmp((*this)->type_key(), "rpc")) {
    // Looked up lazily; a null lookup result leaves fimport_ null, so a later
    // call looks it up again.
    static const PackedFunc* fimport_ = nullptr;
    if (fimport_ == nullptr) {
      fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
      CHECK(fimport_ != nullptr);
    }
    (*fimport_)(*this, other);
    return;
  }
  // The traversal below dereferences other's node; reject an undefined module
  // with a clear message instead of dereferencing a null pointer.
  CHECK(other.node_.get() != nullptr)
      << "Cannot import an undefined module";
  // cyclic detection: iterative DFS collecting every node reachable from
  // `other`, including `other` itself.
  std::unordered_set<const ModuleNode*> visited{other.node_.get()};
  std::vector<const ModuleNode*> stack{other.node_.get()};
  while (!stack.empty()) {
    const ModuleNode* n = stack.back();
    stack.pop_back();
    for (const Module& m : n->imports_) {
      const ModuleNode* next = m.node_.get();
      // insert().second is false when `next` has already been visited.
      if (visited.insert(next).second) {
        stack.push_back(next);
      }
    }
  }
  CHECK(!visited.count(node_.get()))
      << "Cyclic dependency detected during import";
  node_->imports_.emplace_back(std::move(other));
}

// Load a module from `file_name` by dispatching to the registered loader
// "module.loadfile_<fmt>".
//
// `fmt` is computed by GetFileFormat(file_name, format). The shared-library
// aliases "dll", "dylib" and "dso" are normalized to "so". The loader is
// called with the caller's original `format` argument. Fails with a CHECK if
// no format can be determined or no loader is registered for it. Not
// supported in SGX builds (_LIBCPP_SGX_CONFIG).
Module Module::LoadFromFile(const std::string& file_name,
                            const std::string& format) {
#ifndef _LIBCPP_SGX_CONFIG
  std::string fmt = GetFileFormat(file_name, format);
  CHECK(fmt.length() != 0)
      << "Cannot deduce format of file " << file_name;
  if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
    fmt = "so";
  }
  std::string load_f_name = "module.loadfile_" + fmt;
  const PackedFunc* f = Registry::Get(load_f_name);
  // Report the resolved format `fmt`. The raw `format` argument may be empty
  // when GetFileFormat deduced the format (presumably from the file name).
  CHECK(f != nullptr)
      << "Loader of " << fmt << "("
      << load_f_name << ") is not presented.";
  Module m = (*f)(file_name, format);
  return m;
#else
  LOG(FATAL) << "SGX does not support LoadFromFile";
  // Only reached if LOG(FATAL) returns; keeps this non-void function from
  // falling off its end.
  return Module();
#endif
}

68 69 70 71 72 73 74 75 76 77 78 79 80 81
void ModuleNode::SaveToFile(const std::string& file_name,
                            const std::string& format) {
  LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}

// Base-class SaveToBinary: writes nothing to `stream` and reports a fatal
// error naming this module's type_key(). Presumably overridden by module
// types that support binary serialization.
void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
  LOG(FATAL) << "Module[" << type_key()
             << "] does not support SaveToBinary";
}

// Base-class GetSource: reports a fatal error naming this module's
// type_key(). Presumably overridden by module types that carry source code.
std::string ModuleNode::GetSource(const std::string& format) {
  LOG(FATAL) << "Module[" << type_key()
             << "] does not support GetSource";
  // Only reached if LOG(FATAL) returns.
  return std::string();
}

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
  auto it = import_cache_.find(name);
  if (it != import_cache_.end()) return it->second.get();
  PackedFunc pf;
  for (Module& m : this->imports_) {
    pf = m.GetFunction(name, false);
    if (pf != nullptr) break;
  }
  if (pf == nullptr) {
    const PackedFunc* f = Registry::Get(name);
    CHECK(f != nullptr)
        << "Cannot find function " << name
        << " in the imported modules or global registry";
    return f;
  } else {
    std::unique_ptr<PackedFunc> f(new PackedFunc(pf));
    import_cache_[name] = std::move(f);
    return import_cache_.at(name).get();
  }
}

103
bool RuntimeEnabled(const std::string& target) {
104
  std::string f_name;
105 106 107
  if (target == "cpu") {
    return true;
  } else if (target == "cuda" || target == "gpu") {
108
    f_name = "device_api.gpu";
109
  } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
110
    f_name = "device_api.opencl";
111 112
  } else if (target == "gl" || target == "opengl") {
    f_name = "device_api.opengl";
113 114
  } else if (target == "mtl" || target == "metal") {
    f_name = "device_api.metal";
115 116
  } else if (target == "vulkan") {
    f_name = "device_api.vulkan";
117 118
  } else if (target == "stackvm") {
    f_name = "codegen.build_stackvm";
119 120
  } else if (target == "rpc") {
    f_name = "device_api.rpc";
121 122
  } else if (target == "vpi" || target == "verilog") {
    f_name = "device_api.vpi";
123
  } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
124
    f_name = "device_api.gpu";
125
  } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
126
    f_name = "device_api.rocm";
127 128 129 130
  } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
    const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
    if (pf == nullptr) return false;
    return (*pf)(target);
131 132 133
  } else {
    LOG(FATAL) << "Unknown optional runtime " << target;
  }
134
  return runtime::Registry::Get(f_name) != nullptr;
135 136
}

137
TVM_REGISTER_GLOBAL("module._Enabled")
138 139 140 141
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = RuntimeEnabled(args[0]);
    });

142
TVM_REGISTER_GLOBAL("module._GetSource")
143 144 145 146
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = args[0].operator Module()->GetSource(args[1]);
    });

147
TVM_REGISTER_GLOBAL("module._ImportsSize")
148 149 150 151 152
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = static_cast<int64_t>(
        args[0].operator Module()->imports().size());
    });

153
TVM_REGISTER_GLOBAL("module._GetImport")
154 155 156 157 158
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = args[0].operator Module()->
        imports().at(args[1].operator int());
    });

159
TVM_REGISTER_GLOBAL("module._GetTypeKey")
160
.set_body([](TVMArgs args, TVMRetValue *ret) {
161
    *ret = std::string(args[0].operator Module()->type_key());
162 163
    });

164
TVM_REGISTER_GLOBAL("module._LoadFromFile")
165 166 167 168
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = Module::LoadFromFile(args[0], args[1]);
    });

169
TVM_REGISTER_GLOBAL("module._SaveToFile")
170 171 172 173 174 175
.set_body([](TVMArgs args, TVMRetValue *ret) {
    args[0].operator Module()->
        SaveToFile(args[1], args[2]);
    });
}  // namespace runtime
}  // namespace tvm