Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6292204e
Commit
6292204e
authored
Mar 19, 2018
by
alex-weaver
Committed by
Tianqi Chen
Mar 19, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Implement C++ registry to back Python target.generic_func (#892)
parent
6588662f
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
951 additions
and
276 deletions
+951
-276
include/tvm/build_module.h
+232
-35
python/tvm/target.py
+194
-90
src/codegen/build_module.cc
+365
-76
src/runtime/threading_backend.cc
+1
-0
tests/cpp/build_module_test.cc
+1
-1
tests/python/unittest/test_lang_target.py
+10
-5
topi/include/topi/cuda/dense.h
+9
-10
topi/include/topi/cuda/extern.h
+1
-1
topi/include/topi/cuda/injective.h
+1
-1
topi/include/topi/cuda/pooling.h
+1
-1
topi/include/topi/cuda/reduction.h
+2
-2
topi/include/topi/nn/dense.h
+6
-7
topi/include/topi/rocm/dense.h
+9
-10
topi/python/topi/__init__.py
+4
-1
topi/python/topi/generic/injective.py
+1
-1
topi/python/topi/generic/nn.py
+7
-7
topi/python/topi/nn/dense.py
+1
-1
topi/src/topi.cc
+106
-27
No files found.
include/tvm/build_module.h
View file @
6292204e
...
...
@@ -8,81 +8,146 @@
#include <string>
#include <vector>
#include "./
tvm/
runtime/packed_func.h"
#include "./
tvm/
schedule_pass.h"
#include "./
tvm/
lowered_func.h"
#include "./runtime/packed_func.h"
#include "./schedule_pass.h"
#include "./lowered_func.h"
namespace
tvm
{
using
namespace
tvm
::
runtime
;
/*!
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
struct
Target
{
class
TargetNode
:
public
Node
{
public
:
/*! \brief The name of the target device */
std
::
string
target_name
;
/*! \brief The type of the target device */
DLDeviceType
device_type
;
int
device_type
;
/*! \brief The maximum threads that a schedule should use for this device */
int
max_num_threads
=
1
;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int
thread_warp_size
=
1
;
/*! \brief Keys for this target */
std
::
unordered_set
<
std
::
string
>
keys
;
Array
<
Expr
>
keys_array
;
/*! \brief Options for this target */
std
::
vector
<
std
::
string
>
options
;
/*! \brief Set of imported libs */
std
::
unordered_set
<
std
::
string
>
libs
;
Target
(
const
std
::
string
&
target_name
,
DLDeviceType
device_type
,
int
max_num_threads
,
int
thread_warp_size
,
const
std
::
unordered_set
<
std
::
string
>&
keys
,
const
std
::
vector
<
std
::
string
>&
options
,
const
std
::
unordered_set
<
std
::
string
>&
libs
=
std
::
unordered_set
<
std
::
string
>
())
:
target_name
(
target_name
),
device_type
(
device_type
),
max_num_threads
(
max_num_threads
),
thread_warp_size
(
thread_warp_size
),
keys
(
keys
),
options
(
options
),
libs
(
libs
)
{
}
Array
<
Expr
>
options_array
;
/*! \brief Collection of imported libs */
Array
<
Expr
>
libs_array
;
/*! \return the full device string to pass to codegen::Build */
EXPORT
std
::
string
str
()
const
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"target_name"
,
&
target_name
);
v
->
Visit
(
"device_type"
,
&
device_type
);
v
->
Visit
(
"max_num_threads"
,
&
max_num_threads
);
v
->
Visit
(
"thread_warp_size"
,
&
thread_warp_size
);
v
->
Visit
(
"keys_array"
,
&
keys_array
);
v
->
Visit
(
"options_array"
,
&
options_array
);
v
->
Visit
(
"libs_array"
,
&
libs_array
);
}
/*! \brief Get the keys for this target as a vector of string */
EXPORT
std
::
vector
<
std
::
string
>
keys
()
const
;
/*! \brief Get the options for this target as a vector of string */
EXPORT
std
::
vector
<
std
::
string
>
options
()
const
;
/*! \brief Get the keys for this target as an unordered_set of string */
EXPORT
std
::
unordered_set
<
std
::
string
>
libs
()
const
;
static
constexpr
const
char
*
_type_key
=
"Target"
;
TVM_DECLARE_NODE_TYPE_INFO
(
TargetNode
,
Node
);
};
class
Target
:
public
NodeRef
{
public
:
Target
()
{}
explicit
Target
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
*/
EXPORT
static
Target
create
(
const
std
::
string
&
target_str
);
/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
EXPORT
static
void
EnterTargetScope
(
const
tvm
::
Target
&
target
);
/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
EXPORT
static
void
ExitTargetScope
();
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
EXPORT
static
tvm
::
Target
current_target
(
bool
allow_not_defined
=
true
);
inline
const
TargetNode
*
operator
->
()
const
{
return
static_cast
<
const
TargetNode
*>
(
node_
.
get
());
}
using
ContainerType
=
TargetNode
;
};
/*!
* \brief RAII container to provide a scoped target context. Pushes a target onto the
* context stack when constructed, and pops it when destructed.
*/
struct
TargetContext
{
/*!
* \brief Enter a new target context. The given target becomes the new current context.
* When the TargetContext is destructed, the previous context is restored.
* \param target The target to set as the new current context.
*/
explicit
TargetContext
(
const
tvm
::
Target
&
target
)
{
Target
::
EnterTargetScope
(
target
);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~
TargetContext
()
{
Target
::
ExitTargetScope
();
}
};
/*! \brief This namespace provides functions to construct Target instances */
namespace
target
{
/*! \return A target for LLVM */
EXPORT
Target
llvm
();
EXPORT
Target
llvm
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{}
);
/*! \return A target for CUDA */
EXPORT
Target
cuda
();
EXPORT
Target
cuda
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{}
);
/*! \return A target for ROCm */
EXPORT
Target
rocm
();
EXPORT
Target
rocm
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{});
/*! \return A target for OpenCL */
EXPORT
Target
opencl
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{});
/*! \return A target for Metal */
EXPORT
Target
metal
();
EXPORT
Target
metal
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{}
);
/*! \return A target for rasp */
EXPORT
Target
rasp
();
EXPORT
Target
rasp
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{}
);
/*! \return A target for Mali */
EXPORT
Target
mali
();
EXPORT
Target
mali
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{}
);
/*! \return A target for stackvm */
EXPORT
Target
stackvm
();
EXPORT
Target
stackvm
(
const
std
::
unordered_set
<
std
::
string
>&
options
=
{}
);
}
// namespace target
...
...
@@ -174,15 +239,147 @@ EXPORT Array<LoweredFunc> lower(Schedule sch,
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code.
If null, a suitable default will be used.
* \param target_host The target for building host code.
To use the default, pass Target()
* \param config The build configuration.
* \return The built module.
*/
EXPORT
runtime
::
Module
build
(
const
Array
<
LoweredFunc
>&
funcs
,
const
Target
&
target
,
Target
*
target_host
,
const
Target
&
target_host
,
const
BuildConfig
&
config
);
class
GenericFuncNode
;
/*!
* \brief Generic function that can be specialized on a per-target basis.
*/
class
GenericFunc
:
public
NodeRef
{
public
:
GenericFunc
()
{}
explicit
GenericFunc
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief Set the default function implementaiton.
* \param value The default function
* \param allow_override If true, this call may override a previously registered function. If
* false, an error will be logged if the call would override a previously registered function.
* \return reference to self.
*/
TVM_DLL
GenericFunc
&
set_default
(
const
PackedFunc
value
,
bool
allow_override
=
false
);
/*!
* \brief Register a specialized function
* \param tags The tags for this specialization
* \param value The specialized function
* \param allow_override If true, this call may override previously registered tags. If false,
* an error will be logged if the call would override previously registered tags.
* \return reference to self.
*/
TVM_DLL
GenericFunc
&
register_func
(
const
std
::
vector
<
std
::
string
>&
tags
,
const
PackedFunc
value
,
bool
allow_override
=
false
);
/*!
* \brief Call generic function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
*
* \code
* // Example code on how to call generic function
* void CallGeneirc(GenericFunc f) {
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template
<
typename
...
Args
>
inline
TVMRetValue
operator
()(
Args
&&
...
args
)
const
;
/*!
* \brief Invoke the relevant function for the current target context, set by set_target_context.
* Arguments are passed in packed format.
* \param args The arguments to pass to the function.
* \param ret The return value
*/
TVM_DLL
void
CallPacked
(
TVMArgs
args
,
TVMRetValue
*
ret
)
const
;
/*!
* \brief Find or register the GenericFunc instance corresponding to the give name
* \param name The name of the registered GenericFunc
* \return The GenericFunc instance
*/
TVM_DLL
static
GenericFunc
Get
(
const
std
::
string
&
name
);
/*!
* \brief Add a GenericFunc instance to the registry
* \param func The GenericFunc instance
* \param name The name of the registered GenericFunc
*/
TVM_DLL
static
void
RegisterGenericFunc
(
GenericFunc
func
,
const
std
::
string
&
name
);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
GenericFuncNode
*
operator
->
();
// declare container type
using
ContainerType
=
GenericFuncNode
;
// Internal class.
struct
Manager
;
private
:
friend
struct
Manager
;
};
template
<
typename
...
Args
>
inline
TVMRetValue
GenericFunc
::
operator
()(
Args
&&
...
args
)
const
{
const
int
kNumArgs
=
sizeof
...(
Args
);
const
int
kArraySize
=
kNumArgs
>
0
?
kNumArgs
:
1
;
TVMValue
values
[
kArraySize
];
int
type_codes
[
kArraySize
];
detail
::
for_each
(
TVMArgsSetter
(
values
,
type_codes
),
std
::
forward
<
Args
>
(
args
)...);
TVMRetValue
rv
;
CallPacked
(
TVMArgs
(
values
,
type_codes
,
kNumArgs
),
&
rv
);
return
rv
;
}
/*!
* \brief Represents a generic function that can be specialized on a per-target basis.
*/
class
GenericFuncNode
:
public
Node
{
public
:
/*! \brief name of the function */
std
::
string
name_
;
/* \brief the generic builder */
PackedFunc
generic_func_
;
/* \brief map from keys to registered functions */
std
::
unordered_map
<
std
::
string
,
PackedFunc
>
dispatch_dict_
;
static
constexpr
const
char
*
_type_key
=
"GenericFunc"
;
TVM_DECLARE_NODE_TYPE_INFO
(
GenericFuncNode
,
Node
);
};
inline
GenericFuncNode
*
GenericFunc
::
operator
->
()
{
return
static_cast
<
GenericFuncNode
*>
(
node_
.
get
());
}
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM
/*!
* \def TVM_REGISTER_GENERIC_FUNC
* \brief Register a new generic function, or set a device-specific variant
* of the corresponding function.
*
* \param name The name of the function
*/
#define TVM_REGISTER_GENERIC_FUNC(name) \
TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::GenericFunc::Get(#name)
}
// namespace tvm
#endif // TVM_BUILD_MODULE_H_
python/tvm/target.py
View file @
6292204e
...
...
@@ -40,8 +40,9 @@ We can also use other specific function in this module to create specific target
"""
from
__future__
import
absolute_import
import
warnings
from
._ffi.base
import
_LIB_NAME
from
._ffi.node
import
NodeBase
,
register_node
from
.
import
_api_internal
try
:
from
decorator
import
decorate
...
...
@@ -62,17 +63,10 @@ def _merge_opts(opts, new_opts):
return
opts
class
Target
(
object
):
@register_node
class
Target
(
NodeBase
):
"""Target device information, use through TVM API.
Parameters
----------
target_name : {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "opengl", "ext_dev"}
The major target name.
options : list of str, optional
Additional arguments appended to the target.
Note
----
Do not use class constructor, you can create target using the following functions
...
...
@@ -83,68 +77,190 @@ class Target(object):
- :any:`tvm.target.rocm` create ROCM target
- :any:`tvm.target.mali` create Mali target
"""
current
=
None
def
__init__
(
self
,
target_name
,
options
=
None
):
self
.
target_name
=
target_name
self
.
options
=
_merge_opts
([],
options
)
self
.
device_name
=
""
self
.
libs
=
[]
# Parse device option
for
item
in
self
.
options
:
if
item
.
startswith
(
"-libs="
):
libs
=
item
.
split
(
"="
)[
1
]
self
.
libs
+=
libs
.
split
(
","
)
elif
item
.
startswith
(
"-device="
):
self
.
device_name
=
item
.
split
(
"="
)[
1
]
# Target query searches device name first
if
self
.
device_name
:
self
.
keys
=
(
self
.
device_name
,)
else
:
self
.
keys
=
()
# Target configuration handling
self
.
thread_warp_size
=
1
if
target_name
in
(
"llvm"
,
):
self
.
keys
+=
(
"cpu"
,)
elif
target_name
in
(
"cuda"
,
"nvptx"
):
self
.
keys
+=
(
"cuda"
,
"gpu"
)
self
.
max_num_threads
=
512
self
.
thread_warp_size
=
32
elif
target_name
in
(
"rocm"
,
"opencl"
):
# For now assume rocm schedule for opencl
self
.
keys
+=
(
"rocm"
,
"gpu"
)
self
.
max_num_threads
=
256
elif
target_name
in
(
"metal"
,
"vulkan"
):
self
.
keys
+=
(
target_name
,
"gpu"
,)
self
.
max_num_threads
=
256
elif
target_name
in
(
"opengl"
,):
self
.
keys
+=
(
"opengl"
,)
elif
target_name
in
(
"stackvm"
,
"ext_dev"
):
# Do not now class for stackvm or ext_dev
pass
else
:
raise
ValueError
(
"Unknown target name
%
s"
%
target_name
)
def
__str__
(
self
):
return
" "
.
join
([
self
.
target_name
]
+
self
.
options
)
def
__repr__
(
self
):
return
self
.
__str__
()
def
__init__
(
self
,
handle
):
super
(
Target
,
self
)
.
__init__
(
handle
)
self
.
_keys
=
None
self
.
_options
=
None
self
.
_libs
=
None
@property
def
keys
(
self
):
if
not
self
.
_keys
:
self
.
_keys
=
[
k
.
value
for
k
in
self
.
keys_array
]
return
self
.
_keys
@property
def
options
(
self
):
if
not
self
.
_options
:
self
.
_options
=
[
o
.
value
for
o
in
self
.
options_array
]
return
self
.
_options
@property
def
libs
(
self
):
if
not
self
.
_libs
:
self
.
_libs
=
[
l
.
value
for
l
in
self
.
libs_array
]
return
self
.
_libs
def
__enter__
(
self
):
self
.
_old_target
=
Target
.
current
if
self
.
_old_target
is
not
None
and
str
(
self
)
!=
str
(
self
.
_old_target
):
warnings
.
warn
(
"Override target '
%
s' with new target scope '
%
s'"
%
(
self
.
_old_target
,
self
))
Target
.
current
=
self
_api_internal
.
_EnterTargetScope
(
self
)
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
Target
.
current
=
self
.
_old_target
_api_internal
.
_ExitTargetScope
()
@register_node
class
GenericFunc
(
NodeBase
):
"""GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is
called, a specialization is chosen based on the current target.
Note
----
Do not construct an instance of this object, it should only ever be
used as a return value from calling into C++.
"""
def
__call__
(
self
,
*
args
):
return
_api_internal
.
_GenericFuncCallFunc
(
self
,
*
args
)
def
set_default
(
self
,
func
,
allow_override
=
False
):
"""Set the default function to be used if no specializations match
the current target.
Parameters
----------
func : function
The default function
allow_override : bool
Whether to allow the current default to be overridden
"""
_api_internal
.
_GenericFuncSetDefault
(
self
,
func
,
allow_override
)
def
register
(
self
,
func
,
key_list
,
allow_override
=
False
):
"""Register a specialization for this GenericFunc.
Parameters
----------
func : function
The function to be registered.
key : str or list of str
The key to be registered.
allow_override : bool, optional
Whether to allow existing keys to be overridden.
"""
key_list
=
[
key_list
]
if
isinstance
(
key_list
,
str
)
else
key_list
_api_internal
.
_GenericFuncRegisterFunc
(
self
,
func
,
key_list
,
allow_override
)
def
get_native_generic_func
(
name
):
"""Get a generic function from the global registry. If no
function is registered under the given name, a new generic
function is created.
Parameters
----------
name : string
The name of the generic function to get
Returns
-------
func : GenericFunc
The generic function for the given name
"""
return
_api_internal
.
_GenericFuncGetGlobal
(
name
)
def
override_native_generic_func
(
func_name
):
"""Override a generic function defined in C++
Generic function allows registration of further functions
that can be dispatched on current target context.
If no registered dispatch is matched, the fdefault will be called.
Parameters
----------
func_name : string
The name of the generic func to be overridden
Returns
-------
fgeneric : function
A wrapped generic function.
Example
-------
.. code-block:: python
import tvm
# wrap function as target generic
@tvm.target.override_native_generic_func("my_func")
def my_func(a):
return a + 1
# register specialization of my_func under target cuda
@my_func.register("cuda")
def my_func_cuda(a):
return a + 2
# displays 3, because my_func is called
print(my_func(2))
# displays 4, because my_func_cuda is called
with tvm.target.cuda():
print(my_func(2))
"""
generic_func_node
=
get_native_generic_func
(
func_name
)
def
fdecorate
(
fdefault
):
"""Wrap a target generic function, overriding the previous
default that was set for the generic function.
Parameters
----------
fdefault : function
The default function.
Returns
-------
fgeneric : function
A wrapped generic function.
"""
generic_func_node
.
set_default
(
fdefault
,
allow_override
=
True
)
def
register
(
key
,
func
=
None
,
override
=
True
):
"""Register function to be the dispatch function.
Parameters
----------
key : str or list of str
The key to be registered.
func : function
The function to be registered.
override : bool, optional
Whether override existing registration.
Returns
-------
The register function is necessary.
"""
def
_do_reg
(
myf
):
generic_func_node
.
register
(
myf
,
key
,
override
)
return
myf
if
func
:
return
_do_reg
(
func
)
return
_do_reg
def
dispatch_func
(
func
,
*
args
,
**
kwargs
):
#pylint: disable=unused-argument
"""The wrapped dispath function"""
if
kwargs
:
raise
RuntimeError
(
"Keyword arguments cannot be used when invoking generic_func
%
s"
%
func_name
)
return
generic_func_node
(
*
args
)
fresult
=
decorate
(
fdefault
,
dispatch_func
)
fresult
.
register
=
register
return
fresult
return
fdecorate
def
generic_func
(
fdefault
):
"""Wrap a target generic function.
...
...
@@ -228,7 +344,6 @@ def generic_func(fdefault):
fdecorate
.
register
=
register
return
fdecorate
def
cuda
(
options
=
None
):
"""Returns a cuda target.
...
...
@@ -237,7 +352,8 @@ def cuda(options=None):
options : list of str
Additional options
"""
return
Target
(
"cuda"
,
options
)
options
=
options
if
options
else
[]
return
_api_internal
.
_TargetCreate
(
"cuda"
,
*
options
)
def
rocm
(
options
=
None
):
...
...
@@ -248,7 +364,8 @@ def rocm(options=None):
options : list of str
Additional options
"""
return
Target
(
"rocm"
,
options
)
options
=
options
if
options
else
[]
return
_api_internal
.
_TargetCreate
(
"rocm"
,
*
options
)
def
rasp
(
options
=
None
):
...
...
@@ -264,7 +381,7 @@ def rasp(options=None):
"-mcpu=cortex-a53"
,
"-mattr=+neon"
]
opts
=
_merge_opts
(
opts
,
options
)
return
Target
(
"llvm"
,
opts
)
return
_api_internal
.
_TargetCreate
(
"llvm"
,
*
opts
)
def
mali
(
options
=
None
):
...
...
@@ -277,7 +394,7 @@ def mali(options=None):
"""
opts
=
[
"-device=mali"
]
opts
=
_merge_opts
(
opts
,
options
)
return
Target
(
"opencl"
,
opts
)
return
_api_internal
.
_TargetCreate
(
"opencl"
,
*
opts
)
def
opengl
(
options
=
None
):
...
...
@@ -288,7 +405,8 @@ def opengl(options=None):
options : list of str
Additional options
"""
return
Target
(
"opengl"
,
options
)
options
=
options
if
options
else
[]
return
_api_internal
.
_TargetCreate
(
"opengl"
,
*
options
)
def
create
(
target_str
):
...
...
@@ -312,17 +430,8 @@ def create(target_str):
return
target_str
if
not
isinstance
(
target_str
,
str
):
raise
ValueError
(
"target_str has to be string type"
)
arr
=
target_str
.
split
()
# Parse device option
device_name
=
""
for
item
in
arr
[
1
:]:
if
item
.
startswith
(
"-device="
):
device_name
=
item
.
split
(
"="
)[
1
]
if
device_name
==
"rasp"
:
return
rasp
(
arr
[
1
:])
if
device_name
==
"mali"
:
return
mali
(
arr
[
1
:])
return
Target
(
arr
[
0
],
arr
[
1
:])
return
_api_internal
.
_TargetFromString
(
target_str
)
def
current_target
(
allow_none
=
True
):
...
...
@@ -337,10 +446,5 @@ def current_target(allow_none=True):
------
ValueError if current target is not set.
"""
if
Target
.
current
:
return
Target
.
current
if
not
allow_none
:
raise
RuntimeError
(
"Requires a current target in generic function, but it is not set. "
"Please set it using `with TargetObject:`"
)
return
Target
.
current
target_str
=
_api_internal
.
_GetCurrentTarget
(
allow_none
)
return
create
(
target_str
)
if
target_str
is
not
None
else
None
src/codegen/build_module.cc
View file @
6292204e
...
...
@@ -3,40 +3,147 @@
* Compile executable modules.
* \file build_module.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/build_module.h>
#include <tvm/operation.h>
#include <tvm/ir_pass.h>
#include <tvm/codegen.h>
#include <algorithm>
#include <mutex>
#include <stack>
namespace
tvm
{
std
::
string
Target
::
str
()
const
{
TVM_REGISTER_NODE_TYPE
(
TargetNode
);
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TargetNode
>
([](
const
TargetNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
op
->
str
();
});
/*!
* \brief Construct a Target node from the given name and options.
* \param target_name The major target name. Should be one of
* {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "opengl", "ext_dev"}
* \param options Additional options appended to the target
* \return The constructed Target
*/
Target
CreateTarget
(
const
std
::
string
&
target_name
,
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
auto
target
=
Target
(
std
::
make_shared
<
TargetNode
>
());
auto
t
=
static_cast
<
TargetNode
*>
(
target
.
node_
.
get
());
t
->
target_name
=
target_name
;
std
::
string
device_name
=
""
;
std
::
string
libs_flag
=
"-libs="
;
std
::
string
device_flag
=
"-device="
;
for
(
auto
&
item
:
options
)
{
t
->
options_array
.
push_back
(
ir
::
StringImm
::
make
(
item
));
if
(
item
.
find
(
libs_flag
)
==
0
)
{
std
::
stringstream
ss
(
item
.
substr
(
libs_flag
.
length
()));
std
::
string
lib_item
;
while
(
std
::
getline
(
ss
,
lib_item
,
','
))
{
t
->
libs_array
.
push_back
(
ir
::
StringImm
::
make
(
lib_item
));
}
}
else
if
(
item
.
find
(
device_flag
)
==
0
)
{
device_name
=
item
.
substr
(
device_flag
.
length
());
}
}
if
(
device_name
.
length
()
>
0
)
{
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
device_name
));
}
t
->
device_type
=
kDLCPU
;
t
->
thread_warp_size
=
1
;
if
(
target_name
==
"llvm"
)
{
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"cpu"
));
}
else
if
(
target_name
==
"cuda"
||
target_name
==
"nvptx"
)
{
t
->
device_type
=
kDLGPU
;
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"cuda"
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"gpu"
));
t
->
max_num_threads
=
512
;
t
->
thread_warp_size
=
32
;
}
else
if
(
target_name
==
"rocm"
||
target_name
==
"opencl"
)
{
// For now assume rocm schedule for opencl
t
->
device_type
=
static_cast
<
int
>
(
target_name
==
"rocm"
?
kDLROCM
:
kDLOpenCL
);
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"rocm"
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"gpu"
));
t
->
max_num_threads
=
256
;
}
else
if
(
target_name
==
"metal"
||
target_name
==
"vulkan"
)
{
t
->
device_type
=
static_cast
<
int
>
(
target_name
==
"metal"
?
kDLMetal
:
kDLVulkan
);
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
target_name
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"gpu"
));
t
->
max_num_threads
=
256
;
}
else
if
(
target_name
==
"opengl"
)
{
t
->
device_type
=
kDLGPU
;
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"opengl"
));
}
else
if
(
target_name
==
"stackvm"
||
target_name
==
"ext_dev"
)
{
}
else
{
LOG
(
ERROR
)
<<
"Unknown target name "
<<
target_name
;
return
target
::
stackvm
();
}
return
target
;
}
TVM_REGISTER_API
(
"_TargetCreate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
string
target_name
=
args
[
0
];
std
::
unordered_set
<
std
::
string
>
options
;
for
(
int
i
=
1
;
i
<
args
.
num_args
;
++
i
)
{
std
::
string
arg
=
args
[
i
];
options
.
insert
(
arg
);
}
*
ret
=
CreateTarget
(
target_name
,
options
);
});
TVM_REGISTER_API
(
"_TargetFromString"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
string
target_str
=
args
[
0
];
*
ret
=
Target
::
create
(
target_str
);
});
std
::
vector
<
std
::
string
>
TargetNode
::
keys
()
const
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
expr
:
keys_array
)
{
result
.
push_back
(
expr
.
as
<
ir
::
StringImm
>
()
->
value
);
}
return
result
;
}
std
::
vector
<
std
::
string
>
TargetNode
::
options
()
const
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
expr
:
options_array
)
{
result
.
push_back
(
expr
.
as
<
ir
::
StringImm
>
()
->
value
);
}
return
result
;
}
std
::
unordered_set
<
std
::
string
>
TargetNode
::
libs
()
const
{
std
::
unordered_set
<
std
::
string
>
result
;
for
(
auto
&
expr
:
libs_array
)
{
result
.
insert
(
expr
.
as
<
ir
::
StringImm
>
()
->
value
);
}
return
result
;
}
std
::
string
TargetNode
::
str
()
const
{
std
::
ostringstream
result
;
result
<<
target_name
;
for
(
const
auto
&
x
:
options
)
{
for
(
const
auto
&
x
:
options
()
)
{
result
<<
" "
<<
x
;
}
return
result
.
str
();
}
Target
TargetFromName
(
const
std
::
string
&
name
)
{
if
(
name
==
"llvm"
)
{
return
target
::
llvm
();
}
else
if
(
name
==
"cuda"
||
name
==
"nvptx"
)
{
return
target
::
cuda
();
}
else
if
(
name
==
"rocm"
||
name
==
"opencl"
)
{
/* For now, assume rocm schedule for opencl */
return
target
::
rocm
();
}
else
if
(
name
==
"metal"
)
{
return
target
::
metal
();
}
else
if
(
name
==
"stackvm"
||
name
==
"ext_dev"
)
{
return
target
::
stackvm
();
}
else
{
LOG
(
ERROR
)
<<
"Unknown target name "
<<
name
;
return
target
::
stackvm
();
}
}
bool
StartsWith
(
const
std
::
string
&
str
,
const
std
::
string
&
pattern
)
{
return
str
.
compare
(
0
,
pattern
.
length
(),
pattern
)
==
0
;
...
...
@@ -68,74 +175,99 @@ Target Target::create(const std::string& target_str) {
ss
>>
target_name
;
auto
device_name
=
GetDeviceName
(
target_str
);
auto
result
=
device_name
==
"rasp"
?
target
::
rasp
()
:
(
device_name
==
"mali"
?
target
::
mali
()
:
TargetFromName
(
target_name
));
std
::
unordered_set
<
std
::
string
>
options
;
std
::
string
item
;
while
(
ss
>>
item
)
{
result
.
options
.
push_back
(
item
);
options
.
insert
(
item
);
}
return
result
;
if
(
device_name
==
"rasp"
)
{
return
target
::
rasp
(
options
);
}
else
if
(
device_name
==
"mail"
)
{
return
target
::
mali
(
options
);
}
else
{
return
CreateTarget
(
target_name
,
options
);
}
}
/*! \brief Entry to hold the Target context stack. */
struct
TVMTargetThreadLocalEntry
{
/*! \brief The current target context */
std
::
stack
<
tvm
::
Target
>
context_stack
;
TVMTargetThreadLocalEntry
()
{
}
};
/*! \brief Thread local store to hold the Target context stack. */
typedef
dmlc
::
ThreadLocalStore
<
TVMTargetThreadLocalEntry
>
TVMTargetThreadLocalStore
;
void
Target
::
EnterTargetScope
(
const
tvm
::
Target
&
target
)
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
entry
->
context_stack
.
push
(
target
);
}
void
Target
::
ExitTargetScope
()
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
entry
->
context_stack
.
pop
();
}
tvm
::
Target
Target
::
current_target
(
bool
allow_not_defined
)
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
if
(
entry
->
context_stack
.
size
()
>
0
)
{
return
entry
->
context_stack
.
top
();
}
CHECK
(
allow_not_defined
)
<<
"Target context required. Please set it by constructing a TargetContext"
;
return
Target
();
}
namespace
target
{
Target
llvm
()
{
std
::
unordered_set
<
std
::
string
>
keys
({
"llvm"
,
"cpu"
});
std
::
vector
<
std
::
string
>
options
;
return
Target
(
"llvm"
,
kDLCPU
,
512
,
1
,
keys
,
options
,
std
::
unordered_set
<
std
::
string
>
());
std
::
unordered_set
<
std
::
string
>
MergeOptions
(
std
::
unordered_set
<
std
::
string
>
opts
,
const
std
::
unordered_set
<
std
::
string
>&
new_opts
)
{
opts
.
insert
(
new_opts
.
begin
(),
new_opts
.
end
());
return
opts
;
}
Target
llvm
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"llvm"
,
options
);
}
Target
cuda
()
{
std
::
unordered_set
<
std
::
string
>
keys
({
"cuda"
,
"gpu"
});
std
::
vector
<
std
::
string
>
options
;
return
Target
(
"cuda"
,
kDLGPU
,
512
,
32
,
keys
,
options
,
std
::
unordered_set
<
std
::
string
>
());
Target
cuda
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"cuda"
,
options
);
}
Target
rocm
()
{
std
::
unordered_set
<
std
::
string
>
keys
({
"rocm"
,
"gpu"
});
std
::
vector
<
std
::
string
>
options
;
return
Target
(
"rocm"
,
kDLROCM
,
256
,
1
,
keys
,
options
,
std
::
unordered_set
<
std
::
string
>
());
Target
rocm
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"rocm"
,
options
);
}
Target
metal
()
{
std
::
unordered_set
<
std
::
string
>
keys
({
"gpu"
});
std
::
vector
<
std
::
string
>
options
;
return
Target
(
"metal"
,
kDLMetal
,
256
,
1
,
keys
,
options
,
std
::
unordered_set
<
std
::
string
>
());
Target
opencl
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"opencl"
,
options
);
}
Target
rasp
()
{
std
::
unordered_set
<
std
::
string
>
keys
({
"llvm"
,
"cpu"
});
std
::
vector
<
std
::
string
>
options
({
Target
metal
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"metal"
,
options
);
}
Target
rasp
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"llvm"
,
MergeOptions
(
options
,
{
"-device=rasp"
,
"-mtriple=armv7l-none-linux-gnueabihf"
,
"-mcpu=cortex-a53"
,
"-mattr=+neon"
});
return
Target
(
"llvm"
,
kDLCPU
,
512
,
1
,
keys
,
options
,
std
::
unordered_set
<
std
::
string
>
());
}));
}
Target
mali
()
{
std
::
unordered_set
<
std
::
string
>
keys
({
"rocm"
,
"gpu"
});
std
::
vector
<
std
::
string
>
options
({
Target
mali
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"opencl"
,
MergeOptions
(
options
,
{
"-device=mali"
});
return
Target
(
"opencl"
,
kDLOpenCL
,
256
,
1
,
keys
,
options
);
}));
}
Target
stackvm
()
{
std
::
unordered_set
<
std
::
string
>
keys
({
"stackvm"
,
"cpu"
});
std
::
vector
<
std
::
string
>
options
;
return
Target
(
"stackvm"
,
kDLCPU
,
512
,
1
,
keys
,
options
,
std
::
unordered_set
<
std
::
string
>
());
Target
stackvm
(
const
std
::
unordered_set
<
std
::
string
>&
options
)
{
return
CreateTarget
(
"stackvm"
,
options
);
}
}
// namespace target
...
...
@@ -146,7 +278,7 @@ bool LLVMEnabled() {
/*! \return The default host target for a given device target */
Target
DefaultTargetHost
(
Target
target
)
{
if
(
target
.
device_type
==
kDLCPU
)
{
if
(
target
->
device_type
==
kDLCPU
)
{
return
target
;
}
else
{
if
(
LLVMEnabled
())
{
...
...
@@ -254,7 +386,7 @@ Array<LoweredFunc> lower(Schedule sch,
runtime
::
Module
build
(
const
Array
<
LoweredFunc
>&
funcs
,
const
Target
&
target
,
Target
*
target_host
,
const
Target
&
target_host
,
const
BuildConfig
&
config
)
{
std
::
unordered_set
<
std
::
string
>
all_names
;
for
(
const
auto
&
x
:
funcs
)
{
...
...
@@ -262,15 +394,13 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
all_names
.
insert
(
x
->
name
);
}
Target
target_host_val
=
target_host
==
nullptr
?
DefaultTargetHost
(
target
)
:
*
target_host
;
auto
target_host_val
=
target_host
.
defined
()
?
target_host
:
DefaultTargetHost
(
target
);
Array
<
LoweredFunc
>
fhost
;
Array
<
LoweredFunc
>
fdevice
;
for
(
const
auto
&
x
:
funcs
)
{
CHECK
(
ir
::
VerifyMemory
(
x
,
target
.
device_type
))
CHECK
(
ir
::
VerifyMemory
(
x
,
target
->
device_type
))
<<
"Direct host side access to device memory is detected in "
<<
x
->
func_name
()
<<
". Did you forget to bind?"
;
...
...
@@ -281,7 +411,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
func
=
ir
::
ThreadSync
(
func
,
"shared"
);
func
=
ir
::
LowerThreadAllreduce
(
func
,
target
.
thread_warp_size
);
func
=
ir
::
LowerThreadAllreduce
(
func
,
target
->
thread_warp_size
);
auto
fsplits
=
ir
::
SplitHostDevice
(
func
);
fhost
.
push_back
(
fsplits
[
0
]);
for
(
auto
f
=
fsplits
.
begin
()
+
1
;
f
!=
fsplits
.
end
();
++
f
)
{
...
...
@@ -296,14 +426,17 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
}
if
(
target
.
keys
.
count
(
"gpu"
)
>
0
&&
fdevice
.
size
()
==
0
)
{
LOG
(
WARNING
)
<<
"Specified target "
+
target
.
str
()
+
auto
keys
=
target
->
keys
();
bool
target_is_gpu
=
std
::
find
(
keys
.
begin
(),
keys
.
end
(),
"gpu"
)
!=
keys
.
end
();
if
(
target_is_gpu
&&
fdevice
.
size
()
==
0
)
{
LOG
(
WARNING
)
<<
"Specified target "
+
target
->
str
()
+
" but cannot find device code. Did you forget to bind?"
;
}
for
(
size_t
i
=
0
;
i
<
fhost
.
size
();
++
i
)
{
auto
func
=
fhost
[
i
];
func
=
ir
::
BindDeviceType
(
func
,
target
.
device_type
);
func
=
ir
::
BindDeviceType
(
func
,
target
->
device_type
);
func
=
ir
::
LowerTVMBuiltin
(
func
);
fhost
.
Set
(
i
,
func
);
}
...
...
@@ -311,21 +444,21 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
for
(
size_t
i
=
0
;
i
<
fdevice
.
size
();
++
i
)
{
auto
func
=
fdevice
[
i
];
func
=
ir
::
LowerIntrin
(
func
,
target
.
target_name
);
func
=
ir
::
LowerIntrin
(
func
,
target
->
target_name
);
fdevice
.
Set
(
i
,
func
);
}
for
(
size_t
i
=
0
;
i
<
fhost
.
size
();
++
i
)
{
auto
func
=
fhost
[
i
];
func
=
ir
::
LowerIntrin
(
func
,
target_host_val
.
target_name
);
func
=
ir
::
LowerIntrin
(
func
,
target_host_val
->
target_name
);
func
=
ir
::
CombineContextCall
(
func
);
fhost
.
Set
(
i
,
func
);
}
auto
mhost
=
codegen
::
Build
(
fhost
,
target_host_val
.
str
());
auto
mhost
=
codegen
::
Build
(
fhost
,
target_host_val
->
str
());
if
(
fdevice
.
size
()
>
0
)
{
auto
mdev
=
codegen
::
Build
(
fdevice
,
target
.
str
());
auto
mdev
=
codegen
::
Build
(
fdevice
,
target
->
str
());
mhost
.
Import
(
mdev
);
}
...
...
@@ -354,4 +487,160 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p
->
stream
<<
")"
;
});
struct
GenericFunc
::
Manager
{
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Node
>
>
fmap
;
// mutex
std
::
mutex
mutex
;
Manager
()
{
}
static
Manager
*
Global
()
{
static
Manager
inst
;
return
&
inst
;
}
};
GenericFunc
GenericFunc
::
Get
(
const
std
::
string
&
name
)
{
Manager
*
m
=
Manager
::
Global
();
std
::
lock_guard
<
std
::
mutex
>
(
m
->
mutex
);
auto
it
=
m
->
fmap
.
find
(
name
);
if
(
it
==
m
->
fmap
.
end
())
{
auto
f
=
std
::
make_shared
<
GenericFuncNode
>
();
f
->
name_
=
name
;
m
->
fmap
[
name
]
=
f
;
return
GenericFunc
(
f
);
}
else
{
return
GenericFunc
(
it
->
second
);
}
}
void
GenericFunc
::
RegisterGenericFunc
(
GenericFunc
func
,
const
std
::
string
&
name
)
{
Manager
*
m
=
Manager
::
Global
();
std
::
lock_guard
<
std
::
mutex
>
(
m
->
mutex
);
auto
it
=
m
->
fmap
.
find
(
name
);
CHECK
(
it
==
m
->
fmap
.
end
())
<<
"GenericFunc already registered "
<<
name
;
func
->
name_
=
name
;
m
->
fmap
[
name
]
=
func
.
node_
;
}
GenericFunc
&
GenericFunc
::
set_default
(
const
PackedFunc
value
,
bool
allow_override
)
{
auto
node
=
static_cast
<
GenericFuncNode
*>
(
node_
.
get
());
if
(
!
allow_override
)
{
CHECK
(
node
->
generic_func_
==
nullptr
)
<<
"Generic function already registered for "
<<
node
->
name_
;
}
node
->
generic_func_
=
value
;
return
*
this
;
}
GenericFunc
&
GenericFunc
::
register_func
(
const
std
::
vector
<
std
::
string
>&
tags
,
const
PackedFunc
value
,
bool
allow_override
)
{
for
(
auto
&
t
:
tags
)
{
if
(
!
allow_override
)
{
auto
iter
=
(
*
this
)
->
dispatch_dict_
.
find
(
t
);
CHECK
(
iter
==
(
*
this
)
->
dispatch_dict_
.
end
())
<<
"Tag "
<<
t
<<
" already registered for schedule factory "
<<
(
*
this
)
->
name_
;
}
(
*
this
)
->
dispatch_dict_
[
t
]
=
value
;
}
return
*
this
;
}
void
GenericFunc
::
CallPacked
(
TVMArgs
args
,
TVMRetValue
*
ret
)
const
{
auto
node
=
static_cast
<
GenericFuncNode
*>
(
node_
.
get
());
auto
target
=
Target
::
current_target
(
true
);
PackedFunc
func
;
if
(
target
.
defined
())
{
for
(
auto
&
k
:
target
->
keys
())
{
auto
iter
=
node
->
dispatch_dict_
.
find
(
k
);
if
(
iter
!=
node
->
dispatch_dict_
.
end
())
{
func
=
iter
->
second
;
break
;
}
}
}
if
(
func
==
nullptr
)
{
CHECK
(
node
->
generic_func_
!=
nullptr
)
<<
"No generic function registered for "
<<
node
->
name_
;
func
=
node
->
generic_func_
;
}
func
.
CallPacked
(
args
,
ret
);
}
TVM_REGISTER_API
(
"_GenericFuncCreate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
GenericFunc
(
std
::
make_shared
<
GenericFuncNode
>
());
});
TVM_REGISTER_API
(
"_GenericFuncGetGlobal"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
string
func_name
=
args
[
0
];
*
ret
=
GenericFunc
::
Get
(
func_name
);
});
TVM_REGISTER_API
(
"_GenericFuncSetDefault"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
GenericFunc
generic_func
=
args
[
0
];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc
*
func
=
new
PackedFunc
(
args
[
1
].
operator
PackedFunc
());
bool
allow_override
=
args
[
2
];
generic_func
.
set_default
(
*
func
,
allow_override
);
});
TVM_REGISTER_API
(
"_GenericFuncRegisterFunc"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
GenericFunc
generic_func
=
args
[
0
];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc
*
func
=
new
PackedFunc
(
args
[
1
].
operator
PackedFunc
());
Array
<
Expr
>
tags
=
args
[
2
];
bool
allow_override
=
args
[
3
];
std
::
vector
<
std
::
string
>
tags_vector
;
for
(
auto
&
tag
:
tags
)
{
tags_vector
.
push_back
(
tag
.
as
<
tvm
::
ir
::
StringImm
>
()
->
value
);
}
generic_func
.
register_func
(
tags_vector
,
*
func
,
allow_override
);
});
TVM_REGISTER_API
(
"_GenericFuncCallFunc"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
GenericFunc
generic_func
=
args
[
0
];
TVMArgs
func_args
(
&
args
.
values
[
1
],
&
args
.
type_codes
[
1
],
args
.
num_args
-
1
);
generic_func
.
CallPacked
(
func_args
,
ret
);
});
TVM_REGISTER_API
(
"_GetCurrentTarget"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
bool
allow_not_defined
=
args
[
0
];
*
ret
=
Target
::
current_target
(
allow_not_defined
);
});
TVM_REGISTER_API
(
"_EnterTargetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Target
target
=
args
[
0
];
auto
current
=
Target
::
current_target
();
if
(
current
.
defined
()
&&
target
->
str
()
!=
current
->
str
())
{
LOG
(
WARNING
)
<<
"Overriding target "
<<
current
->
str
()
<<
" with new target scope "
<<
target
->
str
();
}
Target
::
EnterTargetScope
(
target
);
});
TVM_REGISTER_API
(
"_ExitTargetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Target
::
ExitTargetScope
();
});
}
// namespace tvm
src/runtime/threading_backend.cc
View file @
6292204e
...
...
@@ -6,6 +6,7 @@
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <thread>
#include <algorithm>
#if defined(__linux__)
#include <sched.h>
#endif
...
...
tests/cpp/build_module_test.cc
View file @
6292204e
...
...
@@ -31,7 +31,7 @@ TEST(BuildModule, Basic) {
auto
target
=
target
::
llvm
();
auto
lowered
=
lower
(
s
,
args
,
"func"
,
binds
,
config
);
auto
module
=
build
(
lowered
,
target
,
nullptr
,
config
);
auto
module
=
build
(
lowered
,
target
,
Target
()
,
config
);
}
...
...
tests/python/unittest/test_lang_target.py
View file @
6292204e
...
...
@@ -34,11 +34,16 @@ def test_target_dispatch():
with
tvm
.
target
.
create
(
"metal"
):
assert
mygeneric
(
1
)
==
3
try
:
mygeneric
(
0
)
raise
RuntimeError
(
"not reached"
)
except
RuntimeError
:
pass
assert
tvm
.
target
.
current_target
()
==
None
def
test_target_string_parse
():
target
=
tvm
.
target
.
create
(
"cuda -libs=cublas,cudnn"
)
assert
target
.
target_name
==
"cuda"
assert
target
.
options
==
[
'-libs=cublas,cudnn'
]
assert
target
.
keys
==
[
'cuda'
,
'gpu'
]
assert
target
.
libs
==
[
'cublas'
,
'cudnn'
]
if
__name__
==
"__main__"
:
test_target_dispatch
()
test_target_string_parse
()
topi/include/topi/cuda/dense.h
View file @
6292204e
...
...
@@ -24,31 +24,30 @@ namespace cuda {
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]
(optional
)
* \param bias Tensor with shape [out_dim]
. Optional; to omit bias, pass Tensor(
)
*
* \return Tensor with shape [batch, out_dim]
*/
inline
tvm
::
Tensor
dense_cuda
(
const
Target
&
target
,
const
tvm
::
Tensor
&
data
,
const
tvm
::
Tensor
&
weight
,
tvm
::
Tensor
*
bias
)
{
const
tvm
::
Tensor
&
bias
)
{
CHECK_EQ
(
data
->
shape
.
size
(),
2
)
<<
"dense requires 2-D data"
;
CHECK_EQ
(
weight
->
shape
.
size
(),
2
)
<<
"dense requires 2-D weight"
;
if
(
bias
!=
nullptr
)
{
CHECK_EQ
(
(
*
bias
)
->
shape
.
size
(),
1
)
<<
"dense requires 1-D bias"
;
if
(
bias
.
defined
()
)
{
CHECK_EQ
(
bias
->
shape
.
size
(),
1
)
<<
"dense requires 1-D bias"
;
}
auto
batch
=
data
->
shape
[
0
];
auto
in_dim
=
data
->
shape
[
1
];
auto
out_dim
=
weight
->
shape
[
0
];
if
(
target
.
libs
.
count
(
"cublas"
)
>
0
)
{
if
(
target
->
libs
().
count
(
"cublas"
)
)
{
auto
mm
=
topi
::
contrib
::
cublas_matmul
(
data
,
weight
,
false
,
true
);
if
(
bias
!=
nullptr
)
{
auto
bias_val
=
*
bias
;
if
(
bias
.
defined
())
{
mm
=
tvm
::
compute
({
batch
,
out_dim
},
[
&
](
Var
i
,
Var
j
)
{
return
mm
(
i
,
j
)
+
bias
_val
(
j
);
return
mm
(
i
,
j
)
+
bias
(
j
);
},
"tensor"
,
kBroadcast
);
}
...
...
@@ -67,8 +66,8 @@ inline tvm::Tensor dense_cuda(const Target& target,
* \return A schedule for the given ops.
*/
inline
Schedule
schedule_dense
(
const
Target
&
target
,
const
Array
<
Tensor
>&
outs
)
{
if
(
target
.
target_name
==
"cuda"
&&
target
.
libs
.
count
(
"cublas"
)
>
0
)
{
if
(
target
->
target_name
==
"cuda"
&&
target
->
libs
().
count
(
"cublas"
)
)
{
return
topi
::
generic
::
schedule_extern
(
target
,
outs
);
}
...
...
topi/include/topi/cuda/extern.h
View file @
6292204e
...
...
@@ -28,7 +28,7 @@ namespace cuda {
inline
Schedule
ScheduleOutputForExtern
(
Target
target
,
Operation
op
,
Schedule
sch
)
{
auto
x
=
op
.
output
(
0
);
auto
fused
=
detail
::
Fuse
(
sch
[
x
],
sch
[
x
]
->
op
.
as
<
ComputeOpNode
>
()
->
axis
);
auto
num_thread
=
target
.
max_num_threads
;
auto
num_thread
=
target
->
max_num_threads
;
IterVar
bx
,
tx
;
sch
[
x
].
split
(
fused
,
num_thread
,
&
bx
,
&
tx
);
sch
[
x
].
bind
(
bx
,
tvm
::
thread_axis
(
Range
(),
"blockIdx.x"
));
...
...
topi/include/topi/cuda/injective.h
View file @
6292204e
...
...
@@ -25,7 +25,7 @@ namespace cuda {
inline
void
ScheduleInjectiveOp
(
const
Target
&
target
,
Operation
op
,
Schedule
s
)
{
auto
x
=
op
.
output
(
0
);
auto
fused
=
detail
::
Fuse
(
s
[
x
],
s
[
x
]
->
op
.
as
<
ComputeOpNode
>
()
->
axis
);
auto
num_thread
=
target
.
max_num_threads
;
auto
num_thread
=
target
->
max_num_threads
;
IterVar
bx
,
tx
;
s
[
x
].
split
(
fused
,
num_thread
,
&
bx
,
&
tx
);
s
[
x
].
bind
(
bx
,
thread_axis
(
Range
(),
"blockIdx.x"
));
...
...
topi/include/topi/cuda/pooling.h
View file @
6292204e
...
...
@@ -34,7 +34,7 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
auto
_schedule
=
[
&
](
const
Tensor
&
padded_input
,
const
Tensor
&
pool
)
{
s
[
padded_input
].
compute_inline
();
auto
num_thread
=
target
.
max_num_threads
;
auto
num_thread
=
target
->
max_num_threads
;
Tensor
out
;
Tensor
OL
;
if
(
detail
::
contains
(
s
->
outputs
,
pool
->
op
))
{
...
...
topi/include/topi/cuda/reduction.h
View file @
6292204e
...
...
@@ -51,7 +51,7 @@ Schedule ScheduleReduce(const Target& target,
if
(
out_stage
->
op
.
as
<
ComputeOpNode
>
()
->
axis
.
size
()
>
0
)
{
all_reduce
=
false
;
num_thread
=
32
;
if
(
target
.
target_name
==
"opencl"
)
{
if
(
target
->
target_name
==
"opencl"
)
{
// Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
// Don't know why.
num_thread
=
16
;
...
...
@@ -61,7 +61,7 @@ Schedule ScheduleReduce(const Target& target,
thread_y
=
tvm
::
thread_axis
(
Range
(
0
,
num_thread
),
"threadIdx.y"
);
}
else
{
all_reduce
=
true
;
num_thread
=
target
.
max_num_threads
;
num_thread
=
target
->
max_num_threads
;
thread_x
=
tvm
::
thread_axis
(
Range
(
0
,
num_thread
),
"threadIdx.x"
);
}
...
...
topi/include/topi/nn/dense.h
View file @
6292204e
...
...
@@ -20,17 +20,17 @@ using namespace tvm;
*
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]
(optional
)
* \param bias Tensor with shape [out_dim]
. Optional; to omit bias, pass Tensor(
)
*
* \return Tensor with shape [batch, out_dim]
*/
inline
tvm
::
Tensor
dense
(
const
tvm
::
Tensor
&
data
,
const
tvm
::
Tensor
&
weight
,
tvm
::
Tensor
*
bias
)
{
const
tvm
::
Tensor
&
bias
)
{
CHECK_EQ
(
data
->
shape
.
size
(),
2
)
<<
"dense requires 2-D data"
;
CHECK_EQ
(
weight
->
shape
.
size
(),
2
)
<<
"dense requires 2-D weight"
;
if
(
bias
!=
nullptr
)
{
CHECK_EQ
(
(
*
bias
)
->
shape
.
size
(),
1
)
<<
"dense requires 1-D bias"
;
if
(
bias
.
defined
()
)
{
CHECK_EQ
(
bias
->
shape
.
size
(),
1
)
<<
"dense requires 1-D bias"
;
}
auto
batch
=
data
->
shape
[
0
];
...
...
@@ -44,12 +44,11 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
return
tvm
::
sum
(
data
(
i
,
k
)
*
weight
(
j
,
k
),
{
k
});
},
"tensor"
,
"dense"
);
if
(
bias
!=
nullptr
)
{
auto
bias_val
=
*
bias
;
if
(
bias
.
defined
())
{
matmul
=
tvm
::
compute
(
{
batch
,
out_dim
},
[
&
](
Var
i
,
Var
j
)
{
return
matmul
(
i
,
j
)
+
bias
_val
(
j
);
return
matmul
(
i
,
j
)
+
bias
(
j
);
},
"tensor"
,
kBroadcast
);
}
...
...
topi/include/topi/rocm/dense.h
View file @
6292204e
...
...
@@ -25,31 +25,30 @@ namespace rocm {
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]
(optional
)
* \param bias Tensor with shape [out_dim]
. Optional; to omit bias, pass Tensor(
)
*
* \return Tensor with shape [batch, out_dim]
*/
inline
tvm
::
Tensor
dense_rocm
(
const
Target
&
target
,
const
tvm
::
Tensor
&
data
,
const
tvm
::
Tensor
&
weight
,
tvm
::
Tensor
*
bias
)
{
const
tvm
::
Tensor
&
bias
)
{
CHECK_EQ
(
data
->
shape
.
size
(),
2
)
<<
"dense requires 2-D data"
;
CHECK_EQ
(
weight
->
shape
.
size
(),
2
)
<<
"dense requires 2-D weight"
;
if
(
bias
!=
nullptr
)
{
CHECK_EQ
(
(
*
bias
)
->
shape
.
size
(),
1
)
<<
"dense requires 1-D bias"
;
if
(
bias
.
defined
()
)
{
CHECK_EQ
(
bias
->
shape
.
size
(),
1
)
<<
"dense requires 1-D bias"
;
}
auto
batch
=
data
->
shape
[
0
];
auto
in_dim
=
data
->
shape
[
1
];
auto
out_dim
=
weight
->
shape
[
0
];
if
(
target
.
libs
.
count
(
"rocblas"
)
>
0
)
{
if
(
target
->
libs
().
count
(
"rocblas"
)
)
{
auto
mm
=
topi
::
contrib
::
rocblas_matmul
(
data
,
weight
,
false
,
true
);
if
(
bias
!=
nullptr
)
{
auto
bias_val
=
*
bias
;
if
(
bias
.
defined
())
{
mm
=
tvm
::
compute
({
batch
,
out_dim
},
[
&
](
Var
i
,
Var
j
)
{
return
mm
(
i
,
j
)
+
bias
_val
(
j
);
return
mm
(
i
,
j
)
+
bias
(
j
);
},
"tensor"
,
kBroadcast
);
}
...
...
@@ -68,8 +67,8 @@ inline tvm::Tensor dense_rocm(const Target& target,
* \return A schedule for the given ops.
*/
inline
Schedule
schedule_dense
(
const
Target
&
target
,
const
Array
<
Tensor
>&
outs
)
{
if
(
target
.
target_name
==
"rocm"
&&
target
.
libs
.
count
(
"rocblas"
)
>
0
)
{
if
(
target
->
target_name
==
"rocm"
&&
target
->
libs
().
count
(
"rocblas"
)
)
{
return
topi
::
generic
::
schedule_extern
(
target
,
outs
);
}
...
...
topi/python/topi/__init__.py
View file @
6292204e
...
...
@@ -11,6 +11,10 @@ from __future__ import absolute_import as _abs
from
tvm._ffi.libinfo
import
__version__
# Ensure C++ schedules get registered first, so python schedules can
# override them.
from
.
import
cpp
from
.math
import
*
from
.tensor
import
*
from
.reduction
import
*
...
...
@@ -24,7 +28,6 @@ from . import mali
from
.
import
opengl
from
.
import
util
from
.
import
rocm
from
.
import
cpp
from
.
import
vision
# not import testing by default
# because testing can have extra deps that are not necessary
...
...
topi/python/topi/generic/injective.py
View file @
6292204e
...
...
@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import
tvm
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_injective"
)
def
schedule_injective
(
outs
):
"""Schedule for injective op.
...
...
topi/python/topi/generic/nn.py
View file @
6292204e
...
...
@@ -106,7 +106,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
return
_default_schedule
(
outs
,
False
)
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_reduce"
)
def
schedule_reduce
(
outs
):
"""Schedule for reduction
...
...
@@ -124,7 +124,7 @@ def schedule_reduce(outs):
return
_default_schedule
(
outs
,
True
)
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_softmax"
)
def
schedule_softmax
(
outs
):
"""Schedule for softmax
...
...
@@ -142,7 +142,7 @@ def schedule_softmax(outs):
return
_default_schedule
(
outs
,
False
)
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_dense"
)
def
schedule_dense
(
outs
):
"""Schedule for dense
...
...
@@ -160,7 +160,7 @@ def schedule_dense(outs):
return
_default_schedule
(
outs
,
False
)
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_pool"
)
def
schedule_pool
(
outs
):
"""Schedule for pool
...
...
@@ -178,7 +178,7 @@ def schedule_pool(outs):
return
_default_schedule
(
outs
,
False
)
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_global_pool"
)
def
schedule_global_pool
(
outs
):
"""Schedule for global pool
...
...
@@ -195,7 +195,7 @@ def schedule_global_pool(outs):
"""
return
_default_schedule
(
outs
,
False
)
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_binarize_pack"
)
def
schedule_binarize_pack
(
outs
):
"""Schedule for binarize_pack
...
...
@@ -213,7 +213,7 @@ def schedule_binarize_pack(outs):
return
_default_schedule
(
outs
,
False
)
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"schedule_binary_dense"
)
def
schedule_binary_dense
(
outs
):
"""Schedule for binary_dense
...
...
topi/python/topi/nn/dense.py
View file @
6292204e
...
...
@@ -39,7 +39,7 @@ def dense_default(data, weight, bias=None):
return
matmul
@tvm.target.
generic_func
@tvm.target.
override_native_generic_func
(
"dense"
)
def
dense
(
data
,
weight
,
bias
=
None
):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
...
...
topi/src/topi.cc
View file @
6292204e
...
...
@@ -51,6 +51,7 @@ struct extension_class_info<tvm::Target> {
}
// namespace runtime
namespace
topi
{
using
namespace
tvm
;
using
namespace
tvm
::
runtime
;
...
...
@@ -281,15 +282,7 @@ TVM_REGISTER_GLOBAL("topi.nn.binary_dense")
/* Ops from nn/dense.h */
TVM_REGISTER_GLOBAL
(
"topi.nn.dense"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Tensor
bias_val
;
Tensor
*
bias
;
if
(
args
[
2
].
type_code
()
==
kNull
)
{
bias
=
nullptr
;
}
else
{
bias_val
=
args
[
2
];
bias
=
&
bias_val
;
}
*
rv
=
nn
::
dense
(
args
[
0
],
args
[
1
],
bias
);
*
rv
=
nn
::
dense
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
/* Ops from nn/dilate.h */
...
...
@@ -388,15 +381,7 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
/* ROCm schedules */
TVM_REGISTER_GLOBAL
(
"topi.rocm.dense_cuda"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Tensor
bias_val
;
Tensor
*
bias
;
if
(
args
[
3
].
type_code
()
==
kNull
)
{
bias
=
nullptr
;
}
else
{
bias_val
=
args
[
3
];
bias
=
&
bias_val
;
}
*
rv
=
rocm
::
dense_rocm
(
args
[
0
],
args
[
1
],
args
[
2
],
bias
);
*
rv
=
rocm
::
dense_rocm
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
TVM_REGISTER_GLOBAL
(
"topi.rocm.schedule_dense"
)
...
...
@@ -407,15 +392,7 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
/* CUDA schedules */
TVM_REGISTER_GLOBAL
(
"topi.cuda.dense_cuda"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Tensor
bias_val
;
Tensor
*
bias
;
if
(
args
[
3
].
type_code
()
==
kNull
)
{
bias
=
nullptr
;
}
else
{
bias_val
=
args
[
3
];
bias
=
&
bias_val
;
}
*
rv
=
cuda
::
dense_cuda
(
args
[
0
],
args
[
1
],
args
[
2
],
bias
);
*
rv
=
cuda
::
dense_cuda
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
TVM_REGISTER_GLOBAL
(
"topi.cuda.schedule_dense"
)
...
...
@@ -453,4 +430,106 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
*
rv
=
topi
::
cuda
::
schedule_softmax
(
args
[
0
],
args
[
1
]);
});
/*! \brief Builder function for instantiating schedules. */
using
FTVMScheduleBuilder
=
std
::
function
<
tvm
::
Schedule
(
const
tvm
::
Target
&
target
,
const
tvm
::
Array
<
tvm
::
Tensor
>&
outs
)
>
;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline
PackedFunc
WrapSchedule
(
FTVMScheduleBuilder
builder
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
target
=
Target
::
current_target
(
false
);
Array
<
Tensor
>
outs
;
NodeRef
argNodeRef
=
args
[
0
];
if
(
argNodeRef
->
type_index
()
==
outs
->
type_index
())
{
outs
=
args
[
0
];
}
else
{
outs
=
Array
<
Tensor
>
{
args
[
0
]
};
}
*
ret
=
builder
(
target
,
outs
);
});
}
TVM_REGISTER_GENERIC_FUNC
(
schedule_injective
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
schedule_injective
))
.
register_func
({
"cpu"
},
WrapSchedule
(
topi
::
x86
::
schedule_injective
))
.
register_func
({
"cuda"
,
"gpu"
},
WrapSchedule
(
topi
::
cuda
::
schedule_injective
));
TVM_REGISTER_GENERIC_FUNC
(
schedule_softmax
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
default_schedule
))
.
register_func
({
"cpu"
},
WrapSchedule
(
topi
::
x86
::
default_schedule
))
.
register_func
({
"cuda"
,
"gpu"
},
WrapSchedule
(
topi
::
cuda
::
schedule_softmax
));
TVM_REGISTER_GENERIC_FUNC
(
schedule_dense
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
default_schedule
))
.
register_func
({
"cuda"
,
"gpu"
},
WrapSchedule
(
topi
::
cuda
::
schedule_dense
))
.
register_func
({
"rocm"
},
WrapSchedule
(
topi
::
rocm
::
schedule_dense
));
TVM_REGISTER_GENERIC_FUNC
(
schedule_pool
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
default_schedule
))
.
register_func
({
"cpu"
},
WrapSchedule
(
topi
::
x86
::
default_schedule
))
.
register_func
({
"cuda"
,
"gpu"
},
WrapSchedule
(
topi
::
cuda
::
schedule_pool
));
TVM_REGISTER_GENERIC_FUNC
(
schedule_global_pool
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
default_schedule
))
.
register_func
({
"cpu"
},
WrapSchedule
(
topi
::
x86
::
default_schedule
))
.
register_func
({
"cuda"
,
"gpu"
},
WrapSchedule
(
topi
::
cuda
::
schedule_global_pool
));
TVM_REGISTER_GENERIC_FUNC
(
schedule_reduce
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
default_schedule_auto_inline
))
.
register_func
({
"cpu"
},
WrapSchedule
(
topi
::
x86
::
default_schedule_auto_inline
))
.
register_func
({
"cuda"
,
"gpu"
},
WrapSchedule
(
topi
::
cuda
::
schedule_reduce
));
TVM_REGISTER_GENERIC_FUNC
(
schedule_binarize_pack
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
default_schedule
))
.
register_func
({
"cpu"
},
WrapSchedule
(
topi
::
x86
::
schedule_binarize_pack
));
TVM_REGISTER_GENERIC_FUNC
(
schedule_binary_dense
)
.
set_default
(
WrapSchedule
(
topi
::
generic
::
default_schedule
))
.
register_func
({
"cpu"
},
WrapSchedule
(
topi
::
x86
::
schedule_binary_dense
));
/*! \brief Builder function for instantiating dense ops. */
using
FTVMDenseOpBuilder
=
std
::
function
<
tvm
::
Tensor
(
const
Target
&
target
,
const
tvm
::
Tensor
&
data
,
const
tvm
::
Tensor
&
weight
,
const
tvm
::
Tensor
&
bias
)
>
;
/*!
* \brief Helper function for registering dense ops matching the
* FTVMDenseOpBuilder signature. The op builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The op builder to wrap.
*
* \return The wrapped op builder
*/
inline
PackedFunc
WrapDenseOp
(
FTVMDenseOpBuilder
builder
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
target
=
Target
::
current_target
(
false
);
Tensor
data
=
args
[
0
];
Tensor
weight
=
args
[
1
];
Tensor
bias
=
args
[
2
];
*
ret
=
builder
(
target
,
data
,
weight
,
bias
);
});
}
TVM_REGISTER_GENERIC_FUNC
(
dense
)
.
set_default
(
WrapDenseOp
([](
const
Target
&
target
,
const
tvm
::
Tensor
&
data
,
const
tvm
::
Tensor
&
weight
,
const
tvm
::
Tensor
&
bias
)
{
return
topi
::
nn
::
dense
(
data
,
weight
,
bias
);
}))
.
register_func
({
"cuda"
,
"gpu"
},
WrapDenseOp
(
topi
::
cuda
::
dense_cuda
))
.
register_func
({
"rocm"
},
WrapDenseOp
(
topi
::
rocm
::
dense_rocm
));
}
// namespace topi
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment