Commit 6292204e by alex-weaver, committed by Tianqi Chen

Implement C++ registry to back Python target.generic_func (#892)

parent 6588662f
......@@ -8,81 +8,146 @@
#include <string>
#include <vector>
#include "./tvm/runtime/packed_func.h"
#include "./tvm/schedule_pass.h"
#include "./tvm/lowered_func.h"
#include "./runtime/packed_func.h"
#include "./schedule_pass.h"
#include "./lowered_func.h"
namespace tvm {
using namespace tvm::runtime;
/*!
* \brief Container for target device information.
* Use the target::llvm, target::cuda etc. functions instead of constructing directly.
*/
struct Target {
class TargetNode : public Node {
public:
/*! \brief The name of the target device */
std::string target_name;
/*! \brief The type of the target device */
DLDeviceType device_type;
int device_type;
/*! \brief The maximum threads that a schedule should use for this device */
int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
std::unordered_set<std::string> keys;
Array<Expr> keys_array;
/*! \brief Options for this target */
std::vector<std::string> options;
/*! \brief Set of imported libs */
std::unordered_set<std::string> libs;
Target(const std::string& target_name,
DLDeviceType device_type,
int max_num_threads,
int thread_warp_size,
const std::unordered_set<std::string>& keys,
const std::vector<std::string>& options,
const std::unordered_set<std::string>& libs =
std::unordered_set<std::string>()) :
target_name(target_name),
device_type(device_type),
max_num_threads(max_num_threads),
thread_warp_size(thread_warp_size),
keys(keys),
options(options),
libs(libs) {
}
Array<Expr> options_array;
/*! \brief Collection of imported libs */
Array<Expr> libs_array;
/*! \return the full device string to pass to codegen::Build */
EXPORT std::string str() const;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("target_name", &target_name);
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array);
}
/*! \brief Get the keys for this target as a vector of string */
EXPORT std::vector<std::string> keys() const;
/*! \brief Get the options for this target as a vector of string */
EXPORT std::vector<std::string> options() const;
/*! \brief Get the keys for this target as an unordered_set of string */
EXPORT std::unordered_set<std::string> libs() const;
static constexpr const char* _type_key = "Target";
TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
};
class Target : public NodeRef {
public:
Target() {}
explicit Target(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
*/
EXPORT static Target create(const std::string& target_str);
/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
EXPORT static void EnterTargetScope(const tvm::Target& target);
/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
EXPORT static void ExitTargetScope();
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
EXPORT static tvm::Target current_target(bool allow_not_defined = true);
inline const TargetNode* operator->() const {
return static_cast<const TargetNode*>(node_.get());
}
using ContainerType = TargetNode;
};
/*!
* \brief RAII container to provide a scoped target context. Pushes a target onto the
* context stack when constructed, and pops it when destructed.
*/
struct TargetContext {
/*!
* \brief Enter a new target context. The given target becomes the new current context.
* When the TargetContext is destructed, the previous context is restored.
* \param target The target to set as the new current context.
*/
explicit TargetContext(const tvm::Target& target) {
Target::EnterTargetScope(target);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~TargetContext() {
Target::ExitTargetScope();
}
};
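/*!
 * An illustrative usage sketch of the RAII guard above:
 *
 * \code
 *   {
 *     TargetContext ctx(Target::create("cuda"));
 *     // current_target() now returns the cuda target
 *     CHECK(Target::current_target().defined());
 *   }
 *   // the previous target (if any) is the current context again here
 * \endcode
 */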
/*! \brief This namespace provides functions to construct Target instances */
namespace target {
/*! \return A target for LLVM */
EXPORT Target llvm();
EXPORT Target llvm(const std::unordered_set<std::string>& options = {});
/*! \return A target for CUDA */
EXPORT Target cuda();
EXPORT Target cuda(const std::unordered_set<std::string>& options = {});
/*! \return A target for ROCm */
EXPORT Target rocm();
EXPORT Target rocm(const std::unordered_set<std::string>& options = {});
/*! \return A target for OpenCL */
EXPORT Target opencl(const std::unordered_set<std::string>& options = {});
/*! \return A target for Metal */
EXPORT Target metal();
EXPORT Target metal(const std::unordered_set<std::string>& options = {});
/*! \return A target for rasp */
EXPORT Target rasp();
EXPORT Target rasp(const std::unordered_set<std::string>& options = {});
/*! \return A target for Mali */
EXPORT Target mali();
EXPORT Target mali(const std::unordered_set<std::string>& options = {});
/*! \return A target for stackvm */
EXPORT Target stackvm();
EXPORT Target stackvm(const std::unordered_set<std::string>& options = {});
} // namespace target
......@@ -174,15 +239,147 @@ EXPORT Array<LoweredFunc> lower(Schedule sch,
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. If null, a suitable default will be used.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The built module.
*/
EXPORT runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
Target* target_host,
const Target& target_host,
const BuildConfig& config);
class GenericFuncNode;
/*!
* \brief Generic function that can be specialized on a per-target basis.
*/
class GenericFunc : public NodeRef {
public:
GenericFunc() {}
explicit GenericFunc(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief Set the default function implementation.
* \param value The default function
* \param allow_override If true, this call may override a previously registered function. If
* false, an error will be logged if the call would override a previously registered function.
* \return reference to self.
*/
TVM_DLL GenericFunc& set_default(const PackedFunc value,
bool allow_override = false);
/*!
* \brief Register a specialized function
* \param tags The tags for this specialization
* \param value The specialized function
* \param allow_override If true, this call may override previously registered tags. If false,
* an error will be logged if the call would override previously registered tags.
* \return reference to self.
*/
TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
const PackedFunc value,
bool allow_override = false);
/*!
* \brief Call generic function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
*
* \code
* // Example code on how to call generic function
* void CallGeneric(GenericFunc f) {
*   // call like normal functions by passing in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template<typename... Args>
inline TVMRetValue operator()(Args&& ...args) const;
/*!
* \brief Invoke the relevant function for the current target context, as set by
* EnterTargetScope or a TargetContext.
* Arguments are passed in packed format.
* \param args The arguments to pass to the function.
* \param ret The return value
*/
TVM_DLL void CallPacked(TVMArgs args, TVMRetValue* ret) const;
/*!
* \brief Find or register the GenericFunc instance corresponding to the given name
* \param name The name of the registered GenericFunc
* \return The GenericFunc instance
*/
TVM_DLL static GenericFunc Get(const std::string& name);
/*!
* \brief Add a GenericFunc instance to the registry
* \param func The GenericFunc instance
* \param name The name of the registered GenericFunc
*/
TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline GenericFuncNode* operator->();
// declare container type
using ContainerType = GenericFuncNode;
// Internal class.
struct Manager;
private:
friend struct Manager;
};
template<typename... Args>
inline TVMRetValue GenericFunc::operator()(Args&& ...args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv;
CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
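// Note: this mirrors PackedFunc::operator(), packing the C++ arguments into
// TVMValue/type-code arrays via TVMArgsSetter before dispatching to CallPacked.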
/*!
* \brief Represents a generic function that can be specialized on a per-target basis.
*/
class GenericFuncNode : public Node {
public:
/*! \brief name of the function */
std::string name_;
/*! \brief the generic builder */
PackedFunc generic_func_;
/*! \brief map from keys to registered functions */
std::unordered_map<std::string, PackedFunc> dispatch_dict_;
static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
};
inline GenericFuncNode* GenericFunc::operator->() {
return static_cast<GenericFuncNode*>(node_.get());
}
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM
/*!
* \def TVM_REGISTER_GENERIC_FUNC
* \brief Register a new generic function, or set a device-specific variant
* of the corresponding function.
*
* \param name The name of the function
*/
#define TVM_REGISTER_GENERIC_FUNC(name) \
TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::GenericFunc::Get(#name)
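/*!
 * An illustrative registration sketch ("my_func" and both packed functions are
 * hypothetical; the topi registrations later in this commit follow the same
 * pattern):
 *
 * \code
 *   TVM_REGISTER_GENERIC_FUNC(my_func)
 *   .set_default(PackedFunc([](TVMArgs args, TVMRetValue* rv) {
 *       *rv = args[0].operator int() + 1;  // fallback for any target
 *     }))
 *   .register_func({ "cuda", "gpu" }, PackedFunc([](TVMArgs args, TVMRetValue* rv) {
 *       *rv = args[0].operator int() + 2;  // specialization for cuda targets
 *     }));
 * \endcode
 */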
} // namespace tvm
#endif // TVM_BUILD_MODULE_H_
......@@ -40,8 +40,9 @@ We can also use other specific function in this module to create specific target
"""
from __future__ import absolute_import
import warnings
from ._ffi.base import _LIB_NAME
from ._ffi.node import NodeBase, register_node
from . import _api_internal
try:
from decorator import decorate
......@@ -62,17 +63,10 @@ def _merge_opts(opts, new_opts):
return opts
class Target(object):
@register_node
class Target(NodeBase):
"""Target device information, use through TVM API.
Parameters
----------
target_name : {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "opengl", "ext_dev"}
The major target name.
options : list of str, optional
Additional arguments appended to the target.
Note
----
Do not use the class constructor; you can create targets using the following functions
......@@ -83,68 +77,190 @@ class Target(object):
- :any:`tvm.target.rocm` create ROCM target
- :any:`tvm.target.mali` create Mali target
"""
current = None
def __init__(self,
target_name,
options=None):
self.target_name = target_name
self.options = _merge_opts([], options)
self.device_name = ""
self.libs = []
# Parse device option
for item in self.options:
if item.startswith("-libs="):
libs = item.split("=")[1]
self.libs += libs.split(",")
elif item.startswith("-device="):
self.device_name = item.split("=")[1]
# Target query searches device name first
if self.device_name:
self.keys = (self.device_name,)
else:
self.keys = ()
# Target configuration handling
self.thread_warp_size = 1
if target_name in ("llvm", ):
self.keys += ("cpu",)
elif target_name in ("cuda", "nvptx"):
self.keys += ("cuda", "gpu")
self.max_num_threads = 512
self.thread_warp_size = 32
elif target_name in ("rocm", "opencl"):
# For now assume rocm schedule for opencl
self.keys += ("rocm", "gpu")
self.max_num_threads = 256
elif target_name in ("metal", "vulkan"):
self.keys += (target_name, "gpu",)
self.max_num_threads = 256
elif target_name in ("opengl",):
self.keys += ("opengl",)
elif target_name in ("stackvm", "ext_dev"):
# No additional keys for stackvm or ext_dev
pass
else:
raise ValueError("Unknown target name %s" % target_name)
def __str__(self):
return " ".join([self.target_name] + self.options)
def __repr__(self):
return self.__str__()
def __init__(self, handle):
super(Target, self).__init__(handle)
self._keys = None
self._options = None
self._libs = None
@property
def keys(self):
if not self._keys:
self._keys = [k.value for k in self.keys_array]
return self._keys
@property
def options(self):
if not self._options:
self._options = [o.value for o in self.options_array]
return self._options
@property
def libs(self):
if not self._libs:
self._libs = [l.value for l in self.libs_array]
return self._libs
def __enter__(self):
self._old_target = Target.current
if self._old_target is not None and str(self) != str(self._old_target):
warnings.warn(
"Override target '%s' with new target scope '%s'" % (
self._old_target, self))
Target.current = self
_api_internal._EnterTargetScope(self)
return self
def __exit__(self, ptype, value, trace):
Target.current = self._old_target
_api_internal._ExitTargetScope()
@register_node
class GenericFunc(NodeBase):
"""GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is
called, a specialization is chosen based on the current target.
Note
----
Do not construct an instance of this object; it should only ever be
used as a return value from calling into C++.
"""
def __call__(self, *args):
return _api_internal._GenericFuncCallFunc(self, *args)
def set_default(self, func, allow_override=False):
"""Set the default function to be used if no specializations match
the current target.
Parameters
----------
func : function
The default function
allow_override : bool
Whether to allow the current default to be overridden
"""
_api_internal._GenericFuncSetDefault(self, func, allow_override)
def register(self, func, key_list, allow_override=False):
"""Register a specialization for this GenericFunc.
Parameters
----------
func : function
The function to be registered.
key_list : str or list of str
The key to be registered.
allow_override : bool, optional
Whether to allow existing keys to be overridden.
"""
key_list = [key_list] if isinstance(key_list, str) else key_list
_api_internal._GenericFuncRegisterFunc(self, func, key_list, allow_override)
def get_native_generic_func(name):
"""Get a generic function from the global registry. If no
function is registered under the given name, a new generic
function is created.
Parameters
----------
name : string
The name of the generic function to get
Returns
-------
func : GenericFunc
The generic function for the given name
"""
return _api_internal._GenericFuncGetGlobal(name)
def override_native_generic_func(func_name):
"""Override a generic function defined in C++
A generic function allows registration of further functions that can be
dispatched on the current target context.
If no registered dispatch matches, the default function (fdefault) is called.
Parameters
----------
func_name : string
The name of the generic func to be overridden
Returns
-------
fgeneric : function
A wrapped generic function.
Example
-------
.. code-block:: python
import tvm
# wrap function as target generic
@tvm.target.override_native_generic_func("my_func")
def my_func(a):
return a + 1
# register specialization of my_func under target cuda
@my_func.register("cuda")
def my_func_cuda(a):
return a + 2
# displays 3, because my_func is called
print(my_func(2))
# displays 4, because my_func_cuda is called
with tvm.target.cuda():
print(my_func(2))
"""
generic_func_node = get_native_generic_func(func_name)
def fdecorate(fdefault):
"""Wrap a target generic function, overriding the previous
default that was set for the generic function.
Parameters
----------
fdefault : function
The default function.
Returns
-------
fgeneric : function
A wrapped generic function.
"""
generic_func_node.set_default(fdefault, allow_override=True)
def register(key, func=None, override=True):
"""Register function to be the dispatch function.
Parameters
----------
key : str or list of str
The key to be registered.
func : function
The function to be registered.
override : bool, optional
Whether to override an existing registration.
Returns
-------
register : function
The registration function; it can be used directly or as a decorator when func is not supplied.
"""
def _do_reg(myf):
generic_func_node.register(myf, key, override)
return myf
if func:
return _do_reg(func)
return _do_reg
def dispatch_func(func, *args, **kwargs):
#pylint: disable=unused-argument
"""The wrapped dispath function"""
if kwargs:
raise RuntimeError(
"Keyword arguments cannot be used when invoking generic_func %s" % func_name)
return generic_func_node(*args)
fresult = decorate(fdefault, dispatch_func)
fresult.register = register
return fresult
return fdecorate
def generic_func(fdefault):
"""Wrap a target generic function.
......@@ -228,7 +344,6 @@ def generic_func(fdefault):
fdecorate.register = register
return fdecorate
def cuda(options=None):
"""Returns a cuda target.
......@@ -237,7 +352,8 @@ def cuda(options=None):
options : list of str
Additional options
"""
return Target("cuda", options)
options = options if options else []
return _api_internal._TargetCreate("cuda", *options)
def rocm(options=None):
......@@ -248,7 +364,8 @@ def rocm(options=None):
options : list of str
Additional options
"""
return Target("rocm", options)
options = options if options else []
return _api_internal._TargetCreate("rocm", *options)
def rasp(options=None):
......@@ -264,7 +381,7 @@ def rasp(options=None):
"-mcpu=cortex-a53",
"-mattr=+neon"]
opts = _merge_opts(opts, options)
return Target("llvm", opts)
return _api_internal._TargetCreate("llvm", *opts)
def mali(options=None):
......@@ -277,7 +394,7 @@ def mali(options=None):
"""
opts = ["-device=mali"]
opts = _merge_opts(opts, options)
return Target("opencl", opts)
return _api_internal._TargetCreate("opencl", *opts)
def opengl(options=None):
......@@ -288,7 +405,8 @@ def opengl(options=None):
options : list of str
Additional options
"""
return Target("opengl", options)
options = options if options else []
return _api_internal._TargetCreate("opengl", *options)
def create(target_str):
......@@ -312,17 +430,8 @@ def create(target_str):
return target_str
if not isinstance(target_str, str):
raise ValueError("target_str has to be string type")
arr = target_str.split()
# Parse device option
device_name = ""
for item in arr[1:]:
if item.startswith("-device="):
device_name = item.split("=")[1]
if device_name == "rasp":
return rasp(arr[1:])
if device_name == "mali":
return mali(arr[1:])
return Target(arr[0], arr[1:])
return _api_internal._TargetFromString(target_str)
def current_target(allow_none=True):
......@@ -337,10 +446,5 @@ def current_target(allow_none=True):
------
ValueError if current target is not set.
"""
if Target.current:
return Target.current
if not allow_none:
raise RuntimeError(
"Requires a current target in generic function, but it is not set. "
"Please set it using `with TargetObject:`")
return Target.current
target_str = _api_internal._GetCurrentTarget(allow_none)
return create(target_str) if target_str is not None else None
......@@ -3,40 +3,147 @@
* Compile executable modules.
* \file build_module.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/build_module.h>
#include <tvm/operation.h>
#include <tvm/ir_pass.h>
#include <tvm/codegen.h>
#include <algorithm>
#include <mutex>
#include <stack>
namespace tvm {
std::string Target::str() const {
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TargetNode>([](const TargetNode *op, IRPrinter *p) {
p->stream << op->str();
});
/*!
* \brief Construct a Target node from the given name and options.
* \param target_name The major target name. Should be one of
* {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "opengl", "ext_dev"}
* \param options Additional options appended to the target
* \return The constructed Target
*/
Target CreateTarget(const std::string& target_name,
const std::unordered_set<std::string>& options) {
auto target = Target(std::make_shared<TargetNode>());
auto t = static_cast<TargetNode*>(target.node_.get());
t->target_name = target_name;
std::string device_name = "";
std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImm::make(item));
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
t->libs_array.push_back(ir::StringImm::make(lib_item));
}
} else if (item.find(device_flag) == 0) {
device_name = item.substr(device_flag.length());
}
}
if (device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "llvm") {
t->keys_array.push_back(ir::StringImm::make("cpu"));
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 512;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
t->device_type = static_cast<int>(target_name == "rocm" ? kDLROCM : kDLOpenCL);
t->keys_array.push_back(ir::StringImm::make("rocm"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "metal" || target_name == "vulkan") {
t->device_type = static_cast<int>(target_name == "metal" ? kDLMetal : kDLVulkan);
t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "opengl") {
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("opengl"));
} else if (target_name == "stackvm" || target_name == "ext_dev") {
} else {
LOG(ERROR) << "Unknown target name " << target_name;
return target::stackvm();
}
return target;
}
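// Illustrative example: CreateTarget("cuda", {"-libs=cublas"}) produces a node
// with target_name = "cuda", device_type = kDLGPU, keys_array = ["cuda", "gpu"],
// libs_array = ["cublas"], max_num_threads = 512 and thread_warp_size = 32.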
TVM_REGISTER_API("_TargetCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_name = args[0];
std::unordered_set<std::string> options;
for (int i = 1; i < args.num_args; ++i) {
std::string arg = args[i];
options.insert(arg);
}
*ret = CreateTarget(target_name, options);
});
TVM_REGISTER_API("_TargetFromString")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string target_str = args[0];
*ret = Target::create(target_str);
});
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
result.push_back(expr.as<ir::StringImm>()->value);
}
return result;
}
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
result.push_back(expr.as<ir::StringImm>()->value);
}
return result;
}
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
result.insert(expr.as<ir::StringImm>()->value);
}
return result;
}
std::string TargetNode::str() const {
std::ostringstream result;
result << target_name;
for (const auto &x : options) {
for (const auto &x : options()) {
result << " " << x;
}
return result.str();
}
Target TargetFromName(const std::string& name) {
if (name == "llvm") {
return target::llvm();
} else if (name == "cuda" || name == "nvptx") {
return target::cuda();
} else if (name == "rocm" || name == "opencl") {
/* For now, assume rocm schedule for opencl */
return target::rocm();
} else if (name == "metal") {
return target::metal();
} else if (name == "stackvm" || name == "ext_dev") {
return target::stackvm();
} else {
LOG(ERROR) << "Unknown target name " << name;
return target::stackvm();
}
}
bool StartsWith(const std::string& str, const std::string& pattern) {
return str.compare(0, pattern.length(), pattern) == 0;
......@@ -68,74 +175,99 @@ Target Target::create(const std::string& target_str) {
ss >> target_name;
auto device_name = GetDeviceName(target_str);
auto result = device_name == "rasp" ?
target::rasp() :
(device_name == "mali" ? target::mali() :
TargetFromName(target_name));
std::unordered_set<std::string> options;
std::string item;
while (ss >> item) {
result.options.push_back(item);
options.insert(item);
}
return result;
if (device_name == "rasp") {
return target::rasp(options);
  } else if (device_name == "mali") {
return target::mali(options);
} else {
return CreateTarget(target_name, options);
}
}
/*! \brief Entry to hold the Target context stack. */
struct TVMTargetThreadLocalEntry {
/*! \brief The current target context */
std::stack<tvm::Target> context_stack;
TVMTargetThreadLocalEntry() {
}
};
/*! \brief Thread local store to hold the Target context stack. */
typedef dmlc::ThreadLocalStore<TVMTargetThreadLocalEntry> TVMTargetThreadLocalStore;
void Target::EnterTargetScope(const tvm::Target& target) {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
entry->context_stack.push(target);
}
void Target::ExitTargetScope() {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
entry->context_stack.pop();
}
tvm::Target Target::current_target(bool allow_not_defined) {
TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
CHECK(allow_not_defined)
<< "Target context required. Please set it by constructing a TargetContext";
return Target();
}
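// The context stack is thread-local: nested TargetContext scopes on one thread
// never affect another thread, and current_target() always reflects the
// innermost active scope.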
namespace target {
Target llvm() {
std::unordered_set<std::string> keys({ "llvm", "cpu" });
std::vector<std::string> options;
return Target("llvm", kDLCPU, 512, 1, keys, options,
std::unordered_set<std::string>());
std::unordered_set<std::string> MergeOptions(std::unordered_set<std::string> opts,
const std::unordered_set<std::string>& new_opts) {
opts.insert(new_opts.begin(), new_opts.end());
return opts;
}
Target llvm(const std::unordered_set<std::string>& options) {
return CreateTarget("llvm", options);
}
Target cuda() {
std::unordered_set<std::string> keys({ "cuda", "gpu" });
std::vector<std::string> options;
return Target("cuda", kDLGPU, 512, 32, keys, options,
std::unordered_set<std::string>());
Target cuda(const std::unordered_set<std::string>& options) {
return CreateTarget("cuda", options);
}
Target rocm() {
std::unordered_set<std::string> keys({ "rocm", "gpu" });
std::vector<std::string> options;
return Target("rocm", kDLROCM, 256, 1, keys, options,
std::unordered_set<std::string>());
Target rocm(const std::unordered_set<std::string>& options) {
return CreateTarget("rocm", options);
}
Target metal() {
std::unordered_set<std::string> keys({ "gpu" });
std::vector<std::string> options;
return Target("metal", kDLMetal, 256, 1, keys, options,
std::unordered_set<std::string>());
Target opencl(const std::unordered_set<std::string>& options) {
return CreateTarget("opencl", options);
}
Target rasp() {
std::unordered_set<std::string> keys({ "llvm", "cpu" });
std::vector<std::string> options({
Target metal(const std::unordered_set<std::string>& options) {
return CreateTarget("metal", options);
}
Target rasp(const std::unordered_set<std::string>& options) {
return CreateTarget("llvm", MergeOptions(options, {
"-device=rasp",
"-mtriple=armv7l-none-linux-gnueabihf",
"-mcpu=cortex-a53",
"-mattr=+neon"
});
return Target("llvm", kDLCPU, 512, 1, keys, options,
std::unordered_set<std::string>());
}));
}
Target mali() {
std::unordered_set<std::string> keys({ "rocm", "gpu" });
std::vector<std::string> options({
Target mali(const std::unordered_set<std::string>& options) {
return CreateTarget("opencl", MergeOptions(options, {
"-device=mali"
});
return Target("opencl", kDLOpenCL, 256, 1, keys, options);
}));
}
Target stackvm() {
std::unordered_set<std::string> keys({ "stackvm", "cpu" });
std::vector<std::string> options;
return Target("stackvm", kDLCPU, 512, 1, keys, options,
std::unordered_set<std::string>());
Target stackvm(const std::unordered_set<std::string>& options) {
return CreateTarget("stackvm", options);
}
} // namespace target
......@@ -146,7 +278,7 @@ bool LLVMEnabled() {
/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
if (target.device_type == kDLCPU) {
if (target->device_type == kDLCPU) {
return target;
} else {
if (LLVMEnabled()) {
......@@ -254,7 +386,7 @@ Array<LoweredFunc> lower(Schedule sch,
runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
Target* target_host,
const Target& target_host,
const BuildConfig& config) {
std::unordered_set<std::string> all_names;
for (const auto &x : funcs) {
......@@ -262,15 +394,13 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
all_names.insert(x->name);
}
Target target_host_val = target_host == nullptr ?
DefaultTargetHost(target) :
*target_host;
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice;
for (const auto& x : funcs) {
CHECK(ir::VerifyMemory(x, target.device_type))
CHECK(ir::VerifyMemory(x, target->device_type))
<< "Direct host side access to device memory is detected in " << x->func_name()
<< ". Did you forget to bind?";
......@@ -281,7 +411,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
func = ir::ThreadSync(func, "shared");
func = ir::LowerThreadAllreduce(func, target.thread_warp_size);
func = ir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = ir::SplitHostDevice(func);
fhost.push_back(fsplits[0]);
for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
......@@ -296,14 +426,17 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
}
if (target.keys.count("gpu") > 0 && fdevice.size() == 0) {
LOG(WARNING) << "Specified target " + target.str() +
auto keys = target->keys();
bool target_is_gpu =
std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && fdevice.size() == 0) {
LOG(WARNING) << "Specified target " + target->str() +
" but cannot find device code. Did you forget to bind?";
}
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = ir::BindDeviceType(func, target.device_type);
func = ir::BindDeviceType(func, target->device_type);
func = ir::LowerTVMBuiltin(func);
fhost.Set(i, func);
}
......@@ -311,21 +444,21 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = ir::LowerIntrin(func, target.target_name);
func = ir::LowerIntrin(func, target->target_name);
fdevice.Set(i, func);
}
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = ir::LowerIntrin(func, target_host_val.target_name);
func = ir::LowerIntrin(func, target_host_val->target_name);
func = ir::CombineContextCall(func);
fhost.Set(i, func);
}
auto mhost = codegen::Build(fhost, target_host_val.str());
auto mhost = codegen::Build(fhost, target_host_val->str());
if (fdevice.size() > 0) {
auto mdev = codegen::Build(fdevice, target.str());
auto mdev = codegen::Build(fdevice, target->str());
mhost.Import(mdev);
}
......@@ -354,4 +487,160 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ")";
});
struct GenericFunc::Manager {
std::unordered_map<std::string, std::shared_ptr<Node> > fmap;
// mutex
std::mutex mutex;
Manager() {
}
static Manager* Global() {
static Manager inst;
return &inst;
}
};
GenericFunc GenericFunc::Get(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);  // named guard; an unnamed temporary would unlock immediately
auto it = m->fmap.find(name);
if (it == m->fmap.end()) {
auto f = std::make_shared<GenericFuncNode>();
f->name_ = name;
m->fmap[name] = f;
return GenericFunc(f);
} else {
return GenericFunc(it->second);
}
}
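// Get-or-create semantics: looking up an unregistered name creates an empty
// GenericFuncNode, so Python's get_native_generic_func and the C++
// TVM_REGISTER_GENERIC_FUNC macro resolve to the same instance regardless of
// which side first uses the name.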
void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex> lock(m->mutex);
auto it = m->fmap.find(name);
CHECK(it == m->fmap.end()) << "GenericFunc already registered " << name;
func->name_ = name;
m->fmap[name] = func.node_;
}
GenericFunc& GenericFunc::set_default(const PackedFunc value,
bool allow_override) {
auto node = static_cast<GenericFuncNode*>(node_.get());
if (!allow_override) {
CHECK(node->generic_func_ == nullptr)
<< "Generic function already registered for " << node->name_;
}
node->generic_func_ = value;
return *this;
}
GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
const PackedFunc value,
bool allow_override) {
for (auto &t : tags) {
if (!allow_override) {
auto iter = (*this)->dispatch_dict_.find(t);
CHECK(iter == (*this)->dispatch_dict_.end())
<< "Tag " << t << " already registered for schedule factory " << (*this)->name_;
}
(*this)->dispatch_dict_[t] = value;
}
return *this;
}
void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
auto node = static_cast<GenericFuncNode*>(node_.get());
auto target = Target::current_target(true);
PackedFunc func;
if (target.defined()) {
for (auto &k : target->keys()) {
auto iter = node->dispatch_dict_.find(k);
if (iter != node->dispatch_dict_.end()) {
func = iter->second;
break;
}
}
}
if (func == nullptr) {
CHECK(node->generic_func_ != nullptr) << "No generic function registered for " << node->name_;
func = node->generic_func_;
}
func.CallPacked(args, ret);
}
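// Dispatch order: for a target whose keys() are {"cuda", "gpu"}, the lookup
// tries "cuda" first, then "gpu"; only when no tag is found in dispatch_dict_
// does the call fall back to generic_func_.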
TVM_REGISTER_API("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(std::make_shared<GenericFuncNode>());
});
TVM_REGISTER_API("_GenericFuncGetGlobal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string func_name = args[0];
*ret = GenericFunc::Get(func_name);
});
TVM_REGISTER_API("_GenericFuncSetDefault")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copied and never deallocated, to avoid freeing the PyObject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
bool allow_override = args[2];
generic_func
.set_default(*func, allow_override);
});
TVM_REGISTER_API("_GenericFuncRegisterFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copied and never deallocated, to avoid freeing the PyObject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
Array<Expr> tags = args[2];
bool allow_override = args[3];
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
tags_vector.push_back(tag.as<tvm::ir::StringImm>()->value);
}
generic_func
.register_func(tags_vector, *func, allow_override);
});
TVM_REGISTER_API("_GenericFuncCallFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1);
generic_func
.CallPacked(func_args, ret);
});
TVM_REGISTER_API("_GetCurrentTarget")
.set_body([](TVMArgs args, TVMRetValue* ret) {
bool allow_not_defined = args[0];
*ret = Target::current_target(allow_not_defined);
});
TVM_REGISTER_API("_EnterTargetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Target target = args[0];
auto current = Target::current_target();
if (current.defined() && target->str() != current->str()) {
LOG(WARNING) << "Overriding target " << current->str()
<< " with new target scope " << target->str();
}
Target::EnterTargetScope(target);
});
TVM_REGISTER_API("_ExitTargetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Target::ExitTargetScope();
});
} // namespace tvm
......@@ -6,6 +6,7 @@
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <thread>
#include <algorithm>
#if defined(__linux__)
#include <sched.h>
#endif
......
......@@ -31,7 +31,7 @@ TEST(BuildModule, Basic) {
auto target = target::llvm();
auto lowered = lower(s, args, "func", binds, config);
auto module = build(lowered, target, nullptr, config);
auto module = build(lowered, target, Target(), config);
}
......
......@@ -34,11 +34,16 @@ def test_target_dispatch():
with tvm.target.create("metal"):
assert mygeneric(1) == 3
try:
mygeneric(0)
raise AssertionError("not reached")  # must not be swallowed by the except below
except RuntimeError:
pass
assert tvm.target.current_target() is None
def test_target_string_parse():
target = tvm.target.create("cuda -libs=cublas,cudnn")
assert target.target_name == "cuda"
assert target.options == ['-libs=cublas,cudnn']
assert target.keys == ['cuda', 'gpu']
assert target.libs == ['cublas', 'cudnn']
if __name__ == "__main__":
test_target_dispatch()
test_target_string_parse()
......@@ -24,31 +24,30 @@ namespace cuda {
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_cuda(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor* bias) {
const tvm::Tensor& bias) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias != nullptr) {
CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
if (bias.defined()) {
CHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
}
auto batch = data->shape[0];
auto in_dim = data->shape[1];
auto out_dim = weight->shape[0];
if (target.libs.count("cublas") > 0) {
if (target->libs().count("cublas")) {
auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
if (bias != nullptr) {
auto bias_val = *bias;
if (bias.defined()) {
mm = tvm::compute({ batch, out_dim },
[&](Var i, Var j) {
return mm(i, j) + bias_val(j);
return mm(i, j) + bias(j);
}, "tensor", kBroadcast);
}
......@@ -67,8 +66,8 @@ inline tvm::Tensor dense_cuda(const Target& target,
* \return A schedule for the given ops.
*/
inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
if (target.target_name == "cuda" &&
target.libs.count("cublas") > 0) {
if (target->target_name == "cuda" &&
target->libs().count("cublas")) {
return topi::generic::schedule_extern(target, outs);
}
......
......@@ -28,7 +28,7 @@ namespace cuda {
inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
auto x = op.output(0);
auto fused = detail::Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
auto num_thread = target.max_num_threads;
auto num_thread = target->max_num_threads;
IterVar bx, tx;
sch[x].split(fused, num_thread, &bx, &tx);
sch[x].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
......
......@@ -25,7 +25,7 @@ namespace cuda {
inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
auto x = op.output(0);
auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
auto num_thread = target.max_num_threads;
auto num_thread = target->max_num_threads;
IterVar bx, tx;
s[x].split(fused, num_thread, &bx, &tx);
s[x].bind(bx, thread_axis(Range(), "blockIdx.x"));
......
......@@ -34,7 +34,7 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
s[padded_input].compute_inline();
auto num_thread = target.max_num_threads;
auto num_thread = target->max_num_threads;
Tensor out;
Tensor OL;
if (detail::contains(s->outputs, pool->op)) {
......
......@@ -51,7 +51,7 @@ Schedule ScheduleReduce(const Target& target,
if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
all_reduce = false;
num_thread = 32;
if (target.target_name == "opencl") {
if (target->target_name == "opencl") {
// Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
// Don't know why.
num_thread = 16;
......@@ -61,7 +61,7 @@ Schedule ScheduleReduce(const Target& target,
thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
} else {
all_reduce = true;
num_thread = target.max_num_threads;
num_thread = target->max_num_threads;
thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
}
......
......@@ -20,17 +20,17 @@ using namespace tvm;
*
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense(const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor* bias) {
const tvm::Tensor& bias) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias != nullptr) {
CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
if (bias.defined()) {
CHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
}
auto batch = data->shape[0];
......@@ -44,12 +44,11 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
return tvm::sum(data(i, k) * weight(j, k), { k });
}, "tensor", "dense");
if (bias != nullptr) {
auto bias_val = *bias;
if (bias.defined()) {
matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return matmul(i, j) + bias_val(j);
return matmul(i, j) + bias(j);
}, "tensor", kBroadcast);
}
......
......@@ -25,31 +25,30 @@ namespace rocm {
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_rocm(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor* bias) {
const tvm::Tensor& bias) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias != nullptr) {
CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
if (bias.defined()) {
CHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
}
auto batch = data->shape[0];
auto in_dim = data->shape[1];
auto out_dim = weight->shape[0];
if (target.libs.count("rocblas") > 0) {
if (target->libs().count("rocblas")) {
auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
if (bias != nullptr) {
auto bias_val = *bias;
if (bias.defined()) {
mm = tvm::compute({ batch, out_dim },
[&](Var i, Var j) {
return mm(i, j) + bias_val(j);
return mm(i, j) + bias(j);
}, "tensor", kBroadcast);
}
......@@ -68,8 +67,8 @@ inline tvm::Tensor dense_rocm(const Target& target,
* \return A schedule for the given ops.
*/
inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
if (target.target_name == "rocm" &&
target.libs.count("rocblas") > 0) {
if (target->target_name == "rocm" &&
target->libs().count("rocblas")) {
return topi::generic::schedule_extern(target, outs);
}
......
......@@ -11,6 +11,10 @@ from __future__ import absolute_import as _abs
from tvm._ffi.libinfo import __version__
# Ensure C++ schedules get registered first, so python schedules can
# override them.
from . import cpp
from .math import *
from .tensor import *
from .reduction import *
......@@ -24,7 +28,6 @@ from . import mali
from . import opengl
from . import util
from . import rocm
from . import cpp
from . import vision
# not import testing by default
# because testing can have extra deps that are not necessary
......
......@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import tvm
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_injective")
def schedule_injective(outs):
"""Schedule for injective op.
......
......@@ -106,7 +106,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
return _default_schedule(outs, False)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_reduce")
def schedule_reduce(outs):
"""Schedule for reduction
......@@ -124,7 +124,7 @@ def schedule_reduce(outs):
return _default_schedule(outs, True)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_softmax")
def schedule_softmax(outs):
"""Schedule for softmax
......@@ -142,7 +142,7 @@ def schedule_softmax(outs):
return _default_schedule(outs, False)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_dense")
def schedule_dense(outs):
"""Schedule for dense
......@@ -160,7 +160,7 @@ def schedule_dense(outs):
return _default_schedule(outs, False)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_pool")
def schedule_pool(outs):
"""Schedule for pool
......@@ -178,7 +178,7 @@ def schedule_pool(outs):
return _default_schedule(outs, False)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_global_pool")
def schedule_global_pool(outs):
"""Schedule for global pool
......@@ -195,7 +195,7 @@ def schedule_global_pool(outs):
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_binarize_pack")
def schedule_binarize_pack(outs):
"""Schedule for binarize_pack
......@@ -213,7 +213,7 @@ def schedule_binarize_pack(outs):
return _default_schedule(outs, False)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_binary_dense")
def schedule_binary_dense(outs):
"""Schedule for binary_dense
......
......@@ -39,7 +39,7 @@ def dense_default(data, weight, bias=None):
return matmul
@tvm.target.generic_func
@tvm.target.override_native_generic_func("dense")
def dense(data, weight, bias=None):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
......
......@@ -51,6 +51,7 @@ struct extension_class_info<tvm::Target> {
} // namespace runtime
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
......@@ -281,15 +282,7 @@ TVM_REGISTER_GLOBAL("topi.nn.binary_dense")
/* Ops from nn/dense.h */
TVM_REGISTER_GLOBAL("topi.nn.dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Tensor bias_val;
Tensor *bias;
if (args[2].type_code() == kNull) {
bias = nullptr;
} else {
bias_val = args[2];
bias = &bias_val;
}
*rv = nn::dense(args[0], args[1], bias);
*rv = nn::dense(args[0], args[1], args[2]);
});
/* Ops from nn/dilate.h */
......@@ -388,15 +381,7 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
/* ROCm schedules */
TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Tensor bias_val;
Tensor *bias;
if (args[3].type_code() == kNull) {
bias = nullptr;
} else {
bias_val = args[3];
bias = &bias_val;
}
*rv = rocm::dense_rocm(args[0], args[1], args[2], bias);
*rv = rocm::dense_rocm(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
......@@ -407,15 +392,7 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Tensor bias_val;
Tensor *bias;
if (args[3].type_code() == kNull) {
bias = nullptr;
} else {
bias_val = args[3];
bias = &bias_val;
}
*rv = cuda::dense_cuda(args[0], args[1], args[2], bias);
*rv = cuda::dense_cuda(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
......@@ -453,4 +430,106 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
*rv = topi::cuda::schedule_softmax(args[0], args[1]);
});
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::current_target(false);
Array<Tensor> outs;
NodeRef argNodeRef = args[0];
if (argNodeRef->type_index() == outs->type_index()) {
outs = args[0];
} else {
outs = Array<Tensor> { args[0] };
}
*ret = builder(target, outs);
});
}
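// For example, WrapSchedule(topi::cuda::schedule_dense) yields a PackedFunc
// suitable for register_func below; the wrapper also normalizes a single
// Tensor argument into a one-element Array<Tensor>.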
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
.set_default(WrapSchedule(topi::generic::schedule_injective))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective));
TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax));
TVM_REGISTER_GENERIC_FUNC(schedule_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense))
.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense));
TVM_REGISTER_GENERIC_FUNC(schedule_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce));
TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack));
TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));
/*! \brief Builder function for instantiating dense ops. */
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias)>;
/*!
* \brief Helper function for registering dense ops matching the
* FTVMDenseOpBuilder signature. The op builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The op builder to wrap.
*
* \return The wrapped op builder
*/
inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::current_target(false);
Tensor data = args[0];
Tensor weight = args[1];
Tensor bias = args[2];
*ret = builder(target, data, weight, bias);
});
}
TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
return topi::nn::dense(data, weight, bias);
}))
.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda))
.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm));
} // namespace topi