/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file module.cc
 * \brief TVM module system
 */
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <cstring>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "file_util.h"

namespace tvm {
namespace runtime {

34
void ModuleNode::Import(Module other) {
35
  // specially handle rpc
36
  if (!std::strcmp(this->type_key(), "rpc")) {
37 38
    static const PackedFunc* fimport_ = nullptr;
    if (fimport_ == nullptr) {
39
      fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
40 41
      CHECK(fimport_ != nullptr);
    }
42
    (*fimport_)(GetRef<Module>(this), other);
43 44
    return;
  }
45
  // cyclic detection.
46 47
  std::unordered_set<const ModuleNode*> visited{other.operator->()};
  std::vector<const ModuleNode*> stack{other.operator->()};
48 49 50 51
  while (!stack.empty()) {
    const ModuleNode* n = stack.back();
    stack.pop_back();
    for (const Module& m : n->imports_) {
52
      const ModuleNode* next = m.operator->();
53 54 55 56 57
      if (visited.count(next)) continue;
      visited.insert(next);
      stack.push_back(next);
    }
  }
58
  CHECK(!visited.count(this))
59
      << "Cyclic dependency detected during import";
60 61 62 63 64 65 66 67 68 69 70 71 72 73
  this->imports_.emplace_back(std::move(other));
}

/*!
 * \brief Look up a function by name in this module and, optionally, its imports.
 *
 * Search order:
 *   1. this module's own functions;
 *   2. (query_imports only) each direct import's own functions, in import order;
 *   3. (query_imports only) recursively, each direct import's imports.
 *
 * Step 3 is new. Previously the search stopped after step 2, so a function
 * defined two or more levels down the import tree was never found. Step 2 is
 * kept ahead of step 3, so every lookup that used to succeed still resolves
 * to the same function. The recursion terminates because Import() rejects
 * cyclic imports.
 *
 * \param name The function name.
 * \param query_imports Whether to search imported modules when this module
 *        does not define `name`.
 * \return The function, or a null PackedFunc when it is not found.
 */
PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
  PackedFunc pf = this->GetFunction(name, GetObjectPtr<Object>(this));
  if (pf != nullptr || !query_imports) return pf;
  // Direct imports first, preserving the original lookup precedence.
  for (Module& m : imports_) {
    pf = m->GetFunction(name, m.data_);
    if (pf != nullptr) return pf;
  }
  // Then descend into the imports' own imports. This re-checks each direct
  // import's own table, which is harmless and only happens on the miss path.
  for (Module& m : imports_) {
    pf = m->GetFunction(name, true);
    if (pf != nullptr) return pf;
  }
  return pf;
}

/*!
 * \brief Load a module from a file by dispatching to a registered loader.
 *
 * The format is resolved by GetFileFormat(file_name, format). Judging by the
 * error message below, it may deduce the format from the file when `format`
 * does not give one; TODO confirm against file_util. "dll", "dylib" and "dso"
 * are normalized to "so". The loader is the global function
 * "runtime.module.loadfile_<fmt>".
 *
 * \param file_name Path of the file to load.
 * \param format Caller-supplied format string. It is forwarded unchanged to
 *        the loader.
 * \return The module produced by the loader.
 */
Module Module::LoadFromFile(const std::string& file_name,
                            const std::string& format) {
  std::string fmt = GetFileFormat(file_name, format);
  CHECK(!fmt.empty())
      << "Cannot deduce format of file " << file_name;
  // All shared-library flavours share the single "so" loader.
  if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
    fmt = "so";
  }
  const std::string load_f_name = "runtime.module.loadfile_" + fmt;
  const PackedFunc* f = Registry::Get(load_f_name);
  // Report the resolved format. `format` may be empty when the format was
  // deduced, which used to leave this message blank.
  CHECK(f != nullptr)
      << "Loader of " << fmt << "("
      << load_f_name << ") is not presented.";
  Module m = (*f)(file_name, format);
  return m;
}

93 94 95 96 97 98 99 100 101 102 103 104 105 106
void ModuleNode::SaveToFile(const std::string& file_name,
                            const std::string& format) {
  LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}

/*!
 * \brief Base-class fallback for serializing a module to a stream.
 *
 * This implementation does not support serialization. It reports a fatal
 * error via LOG(FATAL) that names the module's type key.
 */
void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
  const char* key = this->type_key();
  LOG(FATAL) << "Module[" << key << "] does not support SaveToBinary";
}

std::string ModuleNode::GetSource(const std::string& format) {
  LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
  return "";
}

107 108 109 110 111
const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
  auto it = import_cache_.find(name);
  if (it != import_cache_.end()) return it->second.get();
  PackedFunc pf;
  for (Module& m : this->imports_) {
112
    pf = m.GetFunction(name, true);
113 114 115 116 117 118 119 120 121
    if (pf != nullptr) break;
  }
  if (pf == nullptr) {
    const PackedFunc* f = Registry::Get(name);
    CHECK(f != nullptr)
        << "Cannot find function " << name
        << " in the imported modules or global registry";
    return f;
  } else {
122
    import_cache_.insert(std::make_pair(name, std::make_shared<PackedFunc>(pf)));
123 124 125 126
    return import_cache_.at(name).get();
  }
}

127
bool RuntimeEnabled(const std::string& target) {
128
  std::string f_name;
129 130 131
  if (target == "cpu") {
    return true;
  } else if (target == "cuda" || target == "gpu") {
132
    f_name = "device_api.gpu";
133
  } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
134
    f_name = "device_api.opencl";
135 136
  } else if (target == "gl" || target == "opengl") {
    f_name = "device_api.opengl";
137 138
  } else if (target == "mtl" || target == "metal") {
    f_name = "device_api.metal";
139 140
  } else if (target == "vulkan") {
    f_name = "device_api.vulkan";
141
  } else if (target == "stackvm") {
142
    f_name = "target.build.stackvm";
143 144
  } else if (target == "rpc") {
    f_name = "device_api.rpc";
145 146
  } else if (target == "micro_dev") {
    f_name = "device_api.micro_dev";
147
  } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
148
    f_name = "device_api.gpu";
149
  } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
150
    f_name = "device_api.rocm";
151 152 153 154
  } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
    const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
    if (pf == nullptr) return false;
    return (*pf)(target);
155 156 157
  } else {
    LOG(FATAL) << "Unknown optional runtime " << target;
  }
158
  return runtime::Registry::Get(f_name) != nullptr;
159 160
}

// Global-registry entry points for the module API. Each name below is looked
// up by that exact string, so the names must not change.
TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled")
.set_body_typed(RuntimeEnabled);

TVM_REGISTER_GLOBAL("runtime.ModuleGetSource")
.set_body_typed([](Module module, std::string format) -> std::string {
  std::string source = module->GetSource(format);
  return source;
});

TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize")
.set_body_typed([](Module module) -> int64_t {
  const auto& deps = module->imports();
  return static_cast<int64_t>(deps.size());
});

// Bounds-checked via at(): an out-of-range index throws std::out_of_range.
TVM_REGISTER_GLOBAL("runtime.ModuleGetImport")
.set_body_typed([](Module module, int index) -> Module {
  const auto& deps = module->imports();
  return deps.at(index);
});

TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey")
.set_body_typed([](Module module) -> std::string {
  const char* key = module->type_key();
  return std::string(key);
});

TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile")
.set_body_typed(Module::LoadFromFile);

TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module module, std::string file_name, std::string format) {
  module->SaveToFile(file_name, format);
});
}  // namespace runtime
}  // namespace tvm