Unverified Commit ac3f5bd9 by Tianqi Chen Committed by GitHub

[RELAY] Hotfix build_module creation (#3198)

parent 493f90ff
......@@ -18,12 +18,11 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime.
*/
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
......@@ -41,31 +40,6 @@ namespace backend {
using TargetsMap = Map<tvm::Integer, tvm::Target>;
/*!
* \brief Context index to Target
*/
struct ContextTargetMap {
static const std::unordered_map<int, tvm::Target> mask2str;
static tvm::Target Mask2Str(int mask) {
CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
return mask2str.at(mask);
}
};
const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
{1, tvm::Target::create("llvm")},
{2, tvm::Target::create("cuda")},
{4, tvm::Target::create("opencl")},
{5, tvm::Target::create("aocl")},
{6, tvm::Target::create("sdaccel")},
{7, tvm::Target::create("vulkan")},
{8, tvm::Target::create("metal")},
{9, tvm::Target::create("vpi")},
{10, tvm::Target::create("rocm")},
{11, tvm::Target::create("opengl")},
{12, tvm::Target::create("ext_dev")}
};
/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*
......@@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \return Array<StringImm> names of params
*/
Array<HalideIR::Expr> ListParamNames() {
Array<HalideIR::Expr> ret;
Array<tvm::Expr> ListParamNames() {
Array<tvm::Expr> ret;
for (const auto& kv : params_) {
ret.push_back(ir::StringImm::make(kv.first));
}
......@@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (cfg.pass_enabled("AlterOpLayout")) {
if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
auto enter_pf = GetPackedFunc("_EnterTargetScope");
auto exit_pf = GetPackedFunc("_ExitTargetScope");
for (const auto& kv : targets) {
(*enter_pf)(kv.second);
TargetContext tctx(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
(*exit_pf)();
}
} else {
LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
......@@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode {
}
return func;
}
/*!
* \brief Create a default type.
* \param device_type The device type index.
* \return the default target for the device.
*/
Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target::create("llvm");
if (name == "gpu") return Target::create("cuda");
return Target::create(name);
}
/*!
* \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided.
......@@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (tmp_map.count(cfg.fallback_device) == 0) {
device_target.Set(
cfg.fallback_device,
ContextTargetMap::Mask2Str(cfg.fallback_device));
CreateDefaultTarget(cfg.fallback_device));
}
return device_target;
}
......@@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param targets_map_ptr
* \return Function
*/
Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
Function RunDeviceAnnotationPass(Function func,
const RelayBuildConfig& cfg,
TargetsMap* targets_map_ptr) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
......@@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode {
"relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
if (annotation_map.size() == 0) {
targets_map_ptr->Set(
0, ContextTargetMap::Mask2Str(cfg.fallback_device));
0, CreateDefaultTarget(cfg.fallback_device));
} else {
int64_t dev_type = -1;
for (auto kv : annotation_map) {
......@@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode {
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
}
}
return func;
......@@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() {
return runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RelayBuildCreate();
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment