Unverified Commit ac3f5bd9 by Tianqi Chen Committed by GitHub

[RELAY] Hotfix build_module creation (#3198)

parent 493f90ff
...@@ -18,12 +18,11 @@ ...@@ -18,12 +18,11 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file relay/backend/build_module.cc * \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime. * \brief Code generation for TVM's graph runtime.
*/ */
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
...@@ -41,31 +40,6 @@ namespace backend { ...@@ -41,31 +40,6 @@ namespace backend {
using TargetsMap = Map<tvm::Integer, tvm::Target>; using TargetsMap = Map<tvm::Integer, tvm::Target>;
/*! /*!
* \brief Context index to Target
*/
struct ContextTargetMap {
static const std::unordered_map<int, tvm::Target> mask2str;
static tvm::Target Mask2Str(int mask) {
CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
return mask2str.at(mask);
}
};
const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
{1, tvm::Target::create("llvm")},
{2, tvm::Target::create("cuda")},
{4, tvm::Target::create("opencl")},
{5, tvm::Target::create("aocl")},
{6, tvm::Target::create("sdaccel")},
{7, tvm::Target::create("vulkan")},
{8, tvm::Target::create("metal")},
{9, tvm::Target::create("vpi")},
{10, tvm::Target::create("rocm")},
{11, tvm::Target::create("opengl")},
{12, tvm::Target::create("ext_dev")}
};
/*!
* \brief A data structure to map the names of specific optimizations to * \brief A data structure to map the names of specific optimizations to
* numeric optimization levels * numeric optimization levels
* *
...@@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* *
* \return Array<StringImm> names of params * \return Array<StringImm> names of params
*/ */
Array<HalideIR::Expr> ListParamNames() { Array<tvm::Expr> ListParamNames() {
Array<HalideIR::Expr> ret; Array<tvm::Expr> ret;
for (const auto& kv : params_) { for (const auto& kv : params_) {
ret.push_back(ir::StringImm::make(kv.first)); ret.push_back(ir::StringImm::make(kv.first));
} }
...@@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (cfg.pass_enabled("AlterOpLayout")) { if (cfg.pass_enabled("AlterOpLayout")) {
if (targets.size() == 1) { if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
auto enter_pf = GetPackedFunc("_EnterTargetScope");
auto exit_pf = GetPackedFunc("_ExitTargetScope");
for (const auto& kv : targets) { for (const auto& kv : targets) {
(*enter_pf)(kv.second); TargetContext tctx(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
(*exit_pf)();
} }
} else { } else {
LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
...@@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode {
} }
return func; return func;
} }
/*!
* \brief Create a default type.
* \param device_type The device type index.
* \return the default target for the device.
*/
Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target::create("llvm");
if (name == "gpu") return Target::create("cuda");
return Target::create(name);
}
/*! /*!
* \brief Update the target and fallback device required for heterogeneous * \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided. * compilation. CPU is used as the fallback device if it wasn't provided.
...@@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (tmp_map.count(cfg.fallback_device) == 0) { if (tmp_map.count(cfg.fallback_device) == 0) {
device_target.Set( device_target.Set(
cfg.fallback_device, cfg.fallback_device,
ContextTargetMap::Mask2Str(cfg.fallback_device)); CreateDefaultTarget(cfg.fallback_device));
} }
return device_target; return device_target;
} }
...@@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param targets_map_ptr * \param targets_map_ptr
* \return Function * \return Function
*/ */
Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg, Function RunDeviceAnnotationPass(Function func,
const RelayBuildConfig& cfg,
TargetsMap* targets_map_ptr) { TargetsMap* targets_map_ptr) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
...@@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode {
"relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr); "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
if (annotation_map.size() == 0) { if (annotation_map.size() == 0) {
targets_map_ptr->Set( targets_map_ptr->Set(
0, ContextTargetMap::Mask2Str(cfg.fallback_device)); 0, CreateDefaultTarget(cfg.fallback_device));
} else { } else {
int64_t dev_type = -1; int64_t dev_type = -1;
for (auto kv : annotation_map) { for (auto kv : annotation_map) {
...@@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode {
<< "found. Please check the " << "found. Please check the "
<< "RewriteAnnotation pass."; << "RewriteAnnotation pass.";
} }
targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type)); targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
} }
} }
return func; return func;
...@@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() { ...@@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() {
return runtime::Module(exec); return runtime::Module(exec);
} }
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RelayBuildCreate(); *rv = RelayBuildCreate();
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment