Unverified Commit d756d3ca by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Introduce include/tvm/target (#4721)

As part of Unified IR infra.
Introduce target folder to store all the compilation target related information.
parent 51b5153f
...@@ -65,7 +65,7 @@ docs/_build/ ...@@ -65,7 +65,7 @@ docs/_build/
docs/gen_modules docs/gen_modules
# PyBuilder # PyBuilder
target/ /target/
# IPython Notebook # IPython Notebook
.ipynb_checkpoints .ipynb_checkpoints
......
...@@ -127,6 +127,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) ...@@ -127,6 +127,7 @@ assign_source_group("Include" ${GROUP_INCLUDE})
file(GLOB COMPILER_SRCS file(GLOB COMPILER_SRCS
src/node/*.cc src/node/*.cc
src/ir/*.cc src/ir/*.cc
src/target/*.cc
src/api/*.cc src/api/*.cc
src/arithmetic/*.cc src/arithmetic/*.cc
src/autotvm/*.cc src/autotvm/*.cc
......
...@@ -24,11 +24,14 @@ ...@@ -24,11 +24,14 @@
#ifndef TVM_BUILD_MODULE_H_ #ifndef TVM_BUILD_MODULE_H_
#define TVM_BUILD_MODULE_H_ #define TVM_BUILD_MODULE_H_
#include <tvm/target/target.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "runtime/packed_func.h" #include "runtime/packed_func.h"
#include "schedule_pass.h" #include "schedule_pass.h"
#include "lowered_func.h" #include "lowered_func.h"
...@@ -36,146 +39,6 @@ ...@@ -36,146 +39,6 @@
namespace tvm { namespace tvm {
/*! /*!
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
class TargetNode : public Object {
public:
/*! \brief The name of the target device */
std::string target_name;
/*! \brief The name of the target device */
std::string device_name;
/*! \brief The type of the target device */
int device_type;
/*! \brief The maximum threads that a schedule should use for this device */
int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<PrimExpr> keys_array;
/*! \brief Options for this target */
Array<PrimExpr> options_array;
/*! \brief Collection of imported libs */
Array<PrimExpr> libs_array;
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array);
}
/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> keys() const;
/*! \brief Get the options for this target as a vector of string */
TVM_DLL std::vector<std::string> options() const;
/*! \brief Get the keys for this target as an unordered_set of string */
TVM_DLL std::unordered_set<std::string> libs() const;
static constexpr const char* _type_key = "Target";
TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
private:
/*! \brief Internal string repr. */
mutable std::string str_repr_;
};
/*! \brief reference cpass to the target. */
class Target : public ObjectRef {
public:
Target() {}
explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
*/
TVM_DLL static Target Create(const std::string& target_str);
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
const TargetNode* operator->() const {
return static_cast<const TargetNode*>(get());
}
using ContainerType = TargetNode;
class Internal;
private:
// enable with syntax.
friend class Internal;
friend class With<Target>;
/*!
* \brief Push a new target context onto the thread local stack.
* The Target on top of the stack is used to determine which
* specialization to use when invoking a GenericFunc.
*/
TVM_DLL void EnterWithScope();
/*!
* \brief Pop a target off the thread local context stack,
* restoring the previous target as the current context.
*/
TVM_DLL void ExitWithScope();
};
/*! \brief This namespace provides functions to construct Target instances */
namespace target {
/*! \return A target for LLVM */
TVM_DLL Target llvm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for CUDA */
TVM_DLL Target cuda(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for ROCm */
TVM_DLL Target rocm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for OpenCL */
TVM_DLL Target opencl(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Metal */
TVM_DLL Target metal(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for rasp */
TVM_DLL Target rasp(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Mali */
TVM_DLL Target mali(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Intel Graphics */
TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for stackvm */
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for external device */
TVM_DLL Target ext_dev(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target
/*!
* \brief Container for build configuration options * \brief Container for build configuration options
*/ */
class BuildConfigNode : public Object { class BuildConfigNode : public Object {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/target/target.h
* \brief Compilation target object.
*/
#ifndef TVM_TARGET_TARGET_H_
#define TVM_TARGET_TARGET_H_
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <string>
#include <vector>
#include <unordered_set>
namespace tvm {
/*!
* \brief Compilation target.
* \note Use target::llvm, target::cuda etc functions.
* \sa Target
*/
class TargetNode : public Object {
public:
/*! \brief The name of the target device */
std::string target_name;
/*! \brief The name of the target device */
std::string device_name;
/*! \brief The type of the target device */
int device_type;
/*! \brief The maximum threads that a schedule should use for this device */
int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<PrimExpr> keys_array;
/*! \brief Options for this target */
Array<PrimExpr> options_array;
/*! \brief Collection of imported libs */
Array<PrimExpr> libs_array;
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array);
}
/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> keys() const;
/*! \brief Get the options for this target as a vector of string */
TVM_DLL std::vector<std::string> options() const;
/*! \brief Get the keys for this target as an unordered_set of string */
TVM_DLL std::unordered_set<std::string> libs() const;
static constexpr const char* _type_key = "Target";
TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);
private:
/*! \brief Internal string repr. */
mutable std::string str_repr_;
};
/*!
* \brief Managed reference class to TargetNode.
* \sa TargetNode
*/
class Target : public ObjectRef {
public:
Target() {}
explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
*/
TVM_DLL static Target Create(const std::string& target_str);
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);
const TargetNode* operator->() const {
return static_cast<const TargetNode*>(get());
}
using ContainerType = TargetNode;
class Internal;
private:
// enable with syntax.
friend class Internal;
friend class With<Target>;
/*!
* \brief Push a new target context onto the thread local stack.
* The Target on top of the stack is used to determine which
* specialization to use when invoking a GenericFunc.
*/
TVM_DLL void EnterWithScope();
/*!
* \brief Pop a target off the thread local context stack,
* restoring the previous target as the current context.
*/
TVM_DLL void ExitWithScope();
};
/*! \brief This namespace provides functions to construct Target instances */
namespace target {
/*! \return A target for LLVM */
TVM_DLL Target llvm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for CUDA */
TVM_DLL Target cuda(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for ROCm */
TVM_DLL Target rocm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for OpenCL */
TVM_DLL Target opencl(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Metal */
TVM_DLL Target metal(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for rasp */
TVM_DLL Target rasp(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Mali */
TVM_DLL Target mali(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for Intel Graphics */
TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for stackvm */
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());
/*! \return A target for external device */
TVM_DLL Target ext_dev(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target
} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
...@@ -18,14 +18,14 @@ ...@@ -18,14 +18,14 @@
*/ */
/*! /*!
* \file tvm/target_info.h * \file tvm/target/target_info.h
* \brief Various information about target. * \brief Various information about target.
*/ */
#ifndef TVM_TARGET_INFO_H_ #ifndef TVM_TARGET_TARGET_INFO_H_
#define TVM_TARGET_INFO_H_ #define TVM_TARGET_TARGET_INFO_H_
#include <tvm/ir/expr.h>
#include <string> #include <string>
#include "expr.h"
namespace tvm { namespace tvm {
...@@ -33,7 +33,8 @@ namespace tvm { ...@@ -33,7 +33,8 @@ namespace tvm {
* \brief Memory information of special memory region. * \brief Memory information of special memory region.
* Use MemoryInfo as its container type * Use MemoryInfo as its container type
*/ */
struct MemoryInfoNode : public Object { class MemoryInfoNode : public Object {
public:
/*! \brief The addressable unit */ /*! \brief The addressable unit */
int unit_bits; int unit_bits;
/*! \brief Maximum number of bits supported in the memory */ /*! \brief Maximum number of bits supported in the memory */
...@@ -71,4 +72,4 @@ class MemoryInfo : public ObjectRef { ...@@ -71,4 +72,4 @@ class MemoryInfo : public ObjectRef {
TVM_DLL MemoryInfo GetMemoryInfo(const std::string& scope); TVM_DLL MemoryInfo GetMemoryInfo(const std::string& scope);
} // namespace tvm } // namespace tvm
#endif // TVM_TARGET_INFO_H_ #endif // TVM_TARGET_TARGET_INFO_H_
...@@ -38,288 +38,8 @@ using runtime::TVMArgs; ...@@ -38,288 +38,8 @@ using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
using runtime::PackedFunc; using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_REGISTER_NODE_TYPE(GenericFuncNode); TVM_REGISTER_NODE_TYPE(GenericFuncNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str();
});
/*!
* \brief Construct a Target node from the given name and options.
* \param target_name The major target name. Should be one of
* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hybrid", "llvm", "metal",
* "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"}
* \param options Additional options appended to the target
* \return The constructed Target
*/
Target CreateTarget(const std::string& target_name,
const std::vector<std::string>& options) {
auto t = make_object<TargetNode>();
t->target_name = target_name;
std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImmNode::make(item));
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
t->libs_array.push_back(ir::StringImmNode::make(lib_item));
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImmNode::make(key_item));
}
}
}
if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") {
t->keys_array.push_back(ir::StringImmNode::make("cpu"));
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImmNode::make("cuda"));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
if (target_name == "opencl") {
t->device_type = kDLOpenCL;
} else {
t->device_type = kDLROCM;
}
t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
}
} else if (target_name == "metal" || target_name == "vulkan") {
if (target_name == "metal") {
t->device_type = kDLMetal;
} else {
t->device_type = kDLVulkan;
}
t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL;
t->keys_array.push_back(ir::StringImmNode::make("aocl"));
t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImmNode::make("opengl"));
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
t->device_type = kDLExtDev;
} else if (target_name == "hybrid") {
t->device_type = kDLCPU;
} else {
LOG(ERROR) << "Unknown target name " << target_name;
return target::stackvm();
}
return Target(t);
}
TVM_REGISTER_GLOBAL("_TargetCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_name = args[0];
std::vector<std::string> options;
for (int i = 1; i < args.num_args; ++i) {
std::string arg = args[i];
options.push_back(arg);
}
*ret = CreateTarget(target_name, options);
});
TVM_REGISTER_GLOBAL("_TargetFromString")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_str = args[0];
*ret = Target::Create(target_str);
});
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
result.insert(expr.as<ir::StringImmNode>()->value);
}
return result;
}
const std::string& TargetNode::str() const {
if (str_repr_.length() != 0) return str_repr_;
std::ostringstream result;
result << target_name;
for (const auto &x : options()) {
result << " " << x;
}
str_repr_ = result.str();
return str_repr_;
}
bool StartsWith(const std::string& str, const std::string& pattern) {
return str.compare(0, pattern.length(), pattern) == 0;
}
std::string GetDeviceName(const std::string& target_str) {
std::istringstream ss(target_str);
std::string target_name;
ss >> target_name;
std::string item;
while (ss >> item) {
if (StartsWith(item, "-device=")) {
return item.substr(std::string("-device=").length());
}
}
return "";
}
Target Target::Create(const std::string& target_str) {
if (target_str.length() == 0) {
LOG(ERROR) << "target_str must not be empty";
}
std::istringstream ss(target_str);
std::string target_name;
ss >> target_name;
auto device_name = GetDeviceName(target_str);
std::vector<std::string> options;
std::string item;
while (ss >> item) {
options.push_back(item);
}
return CreateTarget(target_name, options);
}
/*! \brief Entry to hold the Target context stack. */
struct TVMTargetThreadLocalEntry {
/*! \brief The current target context */
std::stack<tvm::Target> context_stack;
};
/*! \brief Thread local store to hold the Target context stack. */
typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
void Target::EnterWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void Target::ExitWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
tvm::Target Target::Current(bool allow_not_defined) {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
CHECK(allow_not_defined)
<< "Target context required. Please set it by constructing a TargetContext";
return Target();
}
namespace target {
std::vector<std::string> MergeOptions(std::vector<std::string> opts,
const std::vector<std::string>& new_opts) {
opts.insert(opts.end(), new_opts.begin(), new_opts.end());
return opts;
}
Target llvm(const std::vector<std::string>& options) {
return CreateTarget("llvm", options);
}
Target cuda(const std::vector<std::string>& options) {
return CreateTarget("cuda", options);
}
Target rocm(const std::vector<std::string>& options) {
return CreateTarget("rocm", options);
}
Target opencl(const std::vector<std::string>& options) {
return CreateTarget("opencl", options);
}
Target metal(const std::vector<std::string>& options) {
return CreateTarget("metal", options);
}
Target mali(const std::vector<std::string>& options) {
return CreateTarget("opencl", MergeOptions(options, {
"-device=mali"
}));
}
Target intel_graphics(const std::vector<std::string>& options) {
return CreateTarget("opencl", MergeOptions(options, {
"-device=intel_graphics"
}));
}
Target stackvm(const std::vector<std::string>& options) {
return CreateTarget("stackvm", options);
}
Target ext_dev(const std::vector<std::string>& options) {
return CreateTarget("ext_dev", options);
}
} // namespace target
bool LLVMEnabled() { bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm"); const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm");
return pf != nullptr; return pf != nullptr;
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* \file storage_access.cc * \file storage_access.cc
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/target_info.h> #include <tvm/target/target_info.h>
#include <string> #include <string>
#include <utility> #include <utility>
#include "ir_util.h" #include "ir_util.h"
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/target_info.h> #include <tvm/target/target_info.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <unordered_map> #include <unordered_map>
#include "ir_util.h" #include "ir_util.h"
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h> #include <tvm/ir_functor_ext.h>
#include <tvm/target_info.h> #include <tvm/target/target_info.h>
#include <map> #include <map>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/target_info.h> #include <tvm/target/target_info.h>
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <unordered_map> #include <unordered_map>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Compile executable modules.
* \file src/target/target.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/target/target.h>
#include <tvm/ir.h>
#include <algorithm>
#include <stack>
namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str();
});
/*!
* \brief Construct a Target node from the given name and options.
* \param target_name The major target name. Should be one of
* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hybrid", "llvm", "metal",
* "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"}
* \param options Additional options appended to the target
* \return The constructed Target
*/
Target CreateTarget(const std::string& target_name,
const std::vector<std::string>& options) {
auto t = make_object<TargetNode>();
t->target_name = target_name;
std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImmNode::make(item));
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
t->libs_array.push_back(ir::StringImmNode::make(lib_item));
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImmNode::make(key_item));
}
}
}
if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") {
t->keys_array.push_back(ir::StringImmNode::make("cpu"));
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImmNode::make("cuda"));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
if (target_name == "opencl") {
t->device_type = kDLOpenCL;
} else {
t->device_type = kDLROCM;
}
t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
}
} else if (target_name == "metal" || target_name == "vulkan") {
if (target_name == "metal") {
t->device_type = kDLMetal;
} else {
t->device_type = kDLVulkan;
}
t->keys_array.push_back(ir::StringImmNode::make(target_name));
t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL;
t->keys_array.push_back(ir::StringImmNode::make("aocl"));
t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImmNode::make("opengl"));
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
t->device_type = kDLExtDev;
} else if (target_name == "hybrid") {
t->device_type = kDLCPU;
} else {
LOG(ERROR) << "Unknown target name " << target_name;
return target::stackvm();
}
return Target(t);
}
TVM_REGISTER_GLOBAL("_TargetCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_name = args[0];
std::vector<std::string> options;
for (int i = 1; i < args.num_args; ++i) {
std::string arg = args[i];
options.push_back(arg);
}
*ret = CreateTarget(target_name, options);
});
TVM_REGISTER_GLOBAL("_TargetFromString")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_str = args[0];
*ret = Target::Create(target_str);
});
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
result.insert(expr.as<ir::StringImmNode>()->value);
}
return result;
}
const std::string& TargetNode::str() const {
if (str_repr_.length() != 0) return str_repr_;
std::ostringstream result;
result << target_name;
for (const auto &x : options()) {
result << " " << x;
}
str_repr_ = result.str();
return str_repr_;
}
bool StartsWith(const std::string& str, const std::string& pattern) {
return str.compare(0, pattern.length(), pattern) == 0;
}
std::string GetDeviceName(const std::string& target_str) {
std::istringstream ss(target_str);
std::string target_name;
ss >> target_name;
std::string item;
while (ss >> item) {
if (StartsWith(item, "-device=")) {
return item.substr(std::string("-device=").length());
}
}
return "";
}
Target Target::Create(const std::string& target_str) {
if (target_str.length() == 0) {
LOG(ERROR) << "target_str must not be empty";
}
std::istringstream ss(target_str);
std::string target_name;
ss >> target_name;
auto device_name = GetDeviceName(target_str);
std::vector<std::string> options;
std::string item;
while (ss >> item) {
options.push_back(item);
}
return CreateTarget(target_name, options);
}
/*! \brief Entry to hold the Target context stack. */
struct TVMTargetThreadLocalEntry {
/*! \brief The current target context */
std::stack<tvm::Target> context_stack;
};
/*! \brief Thread local store to hold the Target context stack. */
typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
void Target::EnterWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void Target::ExitWithScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
tvm::Target Target::Current(bool allow_not_defined) {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
CHECK(allow_not_defined)
<< "Target context required. Please set it by constructing a TargetContext";
return Target();
}
namespace target {
std::vector<std::string> MergeOptions(std::vector<std::string> opts,
const std::vector<std::string>& new_opts) {
opts.insert(opts.end(), new_opts.begin(), new_opts.end());
return opts;
}
Target llvm(const std::vector<std::string>& options) {
return CreateTarget("llvm", options);
}
Target cuda(const std::vector<std::string>& options) {
return CreateTarget("cuda", options);
}
Target rocm(const std::vector<std::string>& options) {
return CreateTarget("rocm", options);
}
Target opencl(const std::vector<std::string>& options) {
return CreateTarget("opencl", options);
}
Target metal(const std::vector<std::string>& options) {
return CreateTarget("metal", options);
}
Target mali(const std::vector<std::string>& options) {
return CreateTarget("opencl", MergeOptions(options, {
"-device=mali"
}));
}
Target intel_graphics(const std::vector<std::string>& options) {
return CreateTarget("opencl", MergeOptions(options, {
"-device=intel_graphics"
}));
}
Target stackvm(const std::vector<std::string>& options) {
return CreateTarget("stackvm", options);
}
Target ext_dev(const std::vector<std::string>& options) {
return CreateTarget("ext_dev", options);
}
} // namespace target
} // namespace tvm
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
*/ */
/*! /*!
* \file target_info.cc * \file target/target_info.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/target_info.h> #include <tvm/node/printer.h>
#include <tvm/packed_func_ext.h> #include <tvm/target/target_info.h>
namespace tvm { namespace tvm {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment