Unverified Commit 11ee1a0e by Ruizhe Zhao Committed by GitHub

Return empty CSourceModule when no lowered_funcs exists in Relay mod (#4847)

* Use dummy func when no lowered_funcs exists in Relay mod

* Dummy func -> CSourceModule with empty code str

* Added comments describing the empty CSouceModule

* Always import external modules w/o assertions

* Use CSourceModule as a fallback for LLVMModule

* Changed cond for target == llvm

* Create an empty LLVM module w/o using dummy func

* Avoid using IR str concat to create LLVM module

* Improved comments for codegen.LLVMModuleCreate

* Satisfy the linter for LLVMModuleCreate
parent 1c347315
...@@ -28,8 +28,10 @@ ...@@ -28,8 +28,10 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
#include <tvm/relay/qnn/transform.h> #include <tvm/relay/qnn/transform.h>
#include <tvm/tir/ir_pass.h>
#include <memory> #include <memory>
#include "../../target/source/codegen_source_base.h"
#include "utils.h" #include "utils.h"
namespace tvm { namespace tvm {
...@@ -451,28 +453,51 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -451,28 +453,51 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.params = graph_codegen_->GetParams(); ret_.params = graph_codegen_->GetParams();
auto lowered_funcs = graph_codegen_->GetLoweredFunc(); auto lowered_funcs = graph_codegen_->GetLoweredFunc();
// When there is no lowered_funcs due to reasons such as optimization.
if (lowered_funcs.size() == 0) { if (lowered_funcs.size() == 0) {
LOG(WARNING) << "no lowered funcs exist in the compiled module"; Target target_host = GetTargetHost();
// If no target_host has been set, we choose a default one, which is
// llvm if "codegen.LLVMModuleCreate" is accessible.
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
if (!target_host.defined())
target_host = (pf != nullptr) ? target::llvm() : target::stackvm();
if (target_host.defined() && target_host->target_name == "llvm") {
// If we can decide the target is LLVM, we then create an empty LLVM module.
ret_.mod = (*pf)(target_host->str(), "empty_module");
} else {
// If we cannot decide the target is LLVM, we create an empty CSourceModule.
// The code content is initialized with ";" to prevent complaining
// from CSourceModuleNode::SaveToFile.
ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
}
} else { } else {
ret_.mod = tvm::build( ret_.mod = tvm::build(
lowered_funcs, lowered_funcs,
target_host_, target_host_,
BuildConfig::Current()); BuildConfig::Current());
} }
Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules(); Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
if (!ext_mods.empty()) {
CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1)
<< "Expect to have a TVM DSOModule when multiple external runtime modules exist";
if (lowered_funcs.size() == 0) {
// Execute the whole module using external runtime.
ret_.mod = ext_mods[0];
} else {
// Import all external runtime modules. // Import all external runtime modules.
for (const auto& it : ext_mods) { for (const auto& it : ext_mods)
ret_.mod.Import(it); ret_.mod.Import(it);
} }
private:
Target GetTargetHost() {
Target target_host = target_host_;
if (!target_host_.defined()) {
for (const auto &it : targets_) {
if (it.second->device_type == kDLCPU) {
target_host = it.second;
break;
}
} }
} }
return target_host;
} }
protected: protected:
......
...@@ -356,6 +356,28 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm") ...@@ -356,6 +356,28 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm")
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
.set_body([](TVMArgs args, TVMRetValue *rv) {
auto n = make_object<LLVMModuleNode>();
auto target = args[0].operator std::string();
auto module_name = args[1].operator std::string();
// Generate a LLVM module from an input target string
InitializeLLVM();
auto tm = GetLLVMTargetMachine(target);
auto ctx = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
// Use a default data layout and target triple
auto triple = tm->getTargetTriple();
module->setTargetTriple(triple.str());
module->setDataLayout(tm->createDataLayout());
n->Init(std::move(module), ctx);
*rv = runtime::Module(n);
});
TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0])); *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment