Commit eb761f36 by Tianqi Chen Committed by GitHub

[Refactor] Introduce target generic dispatch system (#556)

* [TVM] Introduce target generic dispatch system

* fix target warning
parent c3cac464
...@@ -8,6 +8,7 @@ Python API ...@@ -8,6 +8,7 @@ Python API
intrin intrin
tensor tensor
schedule schedule
target
build build
module module
ndarray ndarray
......
tvm.target
----------
.. automodule:: tvm.target
.. autofunction:: tvm.target.generic_func
.. autoclass:: tvm.target.Target
:members:
.. autofunction:: tvm.target.cuda
.. autofunction:: tvm.target.rocm
.. autofunction:: tvm.target.rasp
.. autofunction:: tvm.target.create
...@@ -37,13 +37,11 @@ Index ...@@ -37,13 +37,11 @@ Index
.. autosummary:: .. autosummary::
topi.cuda.schedule_conv2d_nchw topi.generic.schedule_conv2d_nchw
topi.cuda.schedule_conv2d_hwcn topi.generic.schedule_depthwise_conv2d_nchw
topi.cuda.schedule_depthwise_conv2d_nchw topi.generic.schedule_reduce
topi.cuda.schedule_depthwise_conv2d_nhwc topi.generic.schedule_broadcast
topi.cuda.schedule_reduce topi.generic.schedule_injective
topi.cuda.schedule_broadcast
topi.cuda.schedule_injective
topi topi
~~~~ ~~~~
...@@ -75,14 +73,12 @@ topi.nn ...@@ -75,14 +73,12 @@ topi.nn
.. autofunction:: topi.nn.depthwise_conv2d_nhwc .. autofunction:: topi.nn.depthwise_conv2d_nhwc
topi.cuda topi.generic
~~~~~~~~~ ~~~~~~~~~~~~
.. automodule:: topi.cuda .. automodule:: topi.generic
.. autofunction:: topi.cuda.schedule_conv2d_nchw .. autofunction:: topi.generic.schedule_conv2d_nchw
.. autofunction:: topi.cuda.schedule_conv2d_hwcn .. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nchw .. autofunction:: topi.generic.schedule_reduce
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nhwc .. autofunction:: topi.generic.schedule_broadcast
.. autofunction:: topi.cuda.schedule_reduce .. autofunction:: topi.generic.schedule_injective
.. autofunction:: topi.cuda.schedule_broadcast
.. autofunction:: topi.cuda.schedule_injective
...@@ -56,11 +56,7 @@ def context(dev_type, dev_id=0): ...@@ -56,11 +56,7 @@ def context(dev_type, dev_id=0):
assert tvm.context("cuda", 0) == tvm.gpu(0) assert tvm.context("cuda", 0) == tvm.gpu(0)
""" """
if isinstance(dev_type, string_types): if isinstance(dev_type, string_types):
if dev_type not in TVMContext.STR2MASK: dev_type = dev_type.split()[0]
if dev_type.find("nvptx") != -1:
dev_type = "cuda"
if dev_type.find("rocm") != -1:
dev_type = "rocm"
if dev_type not in TVMContext.STR2MASK: if dev_type not in TVMContext.STR2MASK:
raise ValueError("Unknown device type %s" % dev_type) raise ValueError("Unknown device type %s" % dev_type)
dev_type = TVMContext.STR2MASK[dev_type] dev_type = TVMContext.STR2MASK[dev_type]
......
...@@ -100,9 +100,12 @@ class TVMContext(ctypes.Structure): ...@@ -100,9 +100,12 @@ class TVMContext(ctypes.Structure):
12: 'ext_dev', 12: 'ext_dev',
} }
STR2MASK = { STR2MASK = {
'llvm': 1,
'stackvm': 1,
'cpu': 1, 'cpu': 1,
'gpu': 2, 'gpu': 2,
'cuda': 2, 'cuda': 2,
'nvptx': 2,
'cl': 4, 'cl': 4,
'opencl': 4, 'opencl': 4,
'metal': 8, 'metal': 8,
......
...@@ -15,6 +15,7 @@ from . import container ...@@ -15,6 +15,7 @@ from . import container
from . import module from . import module
from . import codegen from . import codegen
from . import ndarray from . import ndarray
from . import target as _target
class BuildConfig(object): class BuildConfig(object):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
...@@ -238,7 +239,7 @@ def lower(sch, ...@@ -238,7 +239,7 @@ def lower(sch,
def build(sch, def build(sch,
args=None, args=None,
target="llvm", target=None,
target_host=None, target_host=None,
name="default_function", name="default_function",
binds=None): binds=None):
...@@ -252,36 +253,10 @@ def build(sch, ...@@ -252,36 +253,10 @@ def build(sch,
args : list of Buffer or Tensor or Var, optional args : list of Buffer or Tensor or Var, optional
The argument lists to the function. The argument lists to the function.
target : str, optional target : str or :any:`tvm.target.Target`, optional
The target and option of the compilation. The target and option of the compilation.
When the target is llvm, you can set options like:
- **-mtriple=<target triple>** or **-target** target_host : str or :any:`tvm.target.Target` optional
Specify the target triple, which is useful for cross
compilation.
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is infered from the
target triple and autodetected to the current architecture.
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self registered functions in program startup. User can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where dynamic loading api like dlopen is banned.
The system lib will be available as long as the result code is linked by the program.
target_host : str, optional
Host compilation target, if target is device. Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA, When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver we also need host(CPU) side code to interact with the driver
...@@ -301,6 +276,10 @@ def build(sch, ...@@ -301,6 +276,10 @@ def build(sch,
------- -------
f : Function, or pair of functions f : Function, or pair of functions
The result function. The result function.
Note
----
See the note on :any:`tvm.target` on target string format.
""" """
if isinstance(sch, schedule.Schedule): if isinstance(sch, schedule.Schedule):
if args is None: if args is None:
...@@ -325,6 +304,9 @@ def build(sch, ...@@ -325,6 +304,9 @@ def build(sch,
if x.name in fname_set: if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name) raise ValueError("Duplicate function name %s" % x.name)
target = _target.current_target() if target is None else target
target = _target.create(target) if target else _target.create("llvm")
fhost = [] fhost = []
fdevice = [] fdevice = []
for func in flist: for func in flist:
...@@ -332,7 +314,7 @@ def build(sch, ...@@ -332,7 +314,7 @@ def build(sch,
if BuildConfig.current.detect_global_barrier: if BuildConfig.current.detect_global_barrier:
func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "shared")
warp_size = 32 if target == "cuda" else 1 warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size) func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)] fsplits = [s for s in ir_pass.SplitHostDevice(func)]
fhost.append(fsplits[0]) fhost.append(fsplits[0])
...@@ -345,29 +327,28 @@ def build(sch, ...@@ -345,29 +327,28 @@ def build(sch,
else: else:
raise ValueError("unknown function type %d" % func.func_type) raise ValueError("unknown function type %d" % func.func_type)
if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice: if "gpu" in target.keys and not fdevice:
warnings.warn( warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target) "Specified target %s, but cannot find device code, did you do bind?" % target)
device = "cpu" if target.startswith("llvm") or target == "stackvm" else target device_type = ndarray.context(target.target_name, 0).device_type
device_type = ndarray.context(device, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost] fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if not target_host: if not target_host:
if device == "cpu": if device_type == ndarray.cpu(0).device_type:
target_host = target target_host = target
assert not fdevice assert not fdevice
else: else:
target_host = "llvm" if module.enabled("llvm") else "stackvm" target_host = "llvm" if module.enabled("llvm") else "stackvm"
target_host = _target.create(target_host)
target_device = target target_device = target
fdevice = [ir_pass.LowerIntrin(x, target_device) for x in fdevice] fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host) for x in fhost] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mhost = codegen.build_module(fhost, target_host) mhost = codegen.build_module(fhost, str(target_host))
if fdevice: if fdevice:
mdev = codegen.build_module(fdevice, target_device) mdev = codegen.build_module(fdevice, str(target_device))
mhost.import_module(mdev) mhost.import_module(mdev)
return mhost return mhost
"""Target management API of tvm""" """Target management API of TVM.
TVM's target string is in format ``<target_name> [-option=value]...``.
Note
----
The list of options include:
- **-device=<device name>**
The device name.
- **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross
compilation.
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is inferred from the
target triple and autodetected to the current architecture.
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self registered functions in program startup. User can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where dynamic loading api like dlopen is banned.
The system lib will be available as long as the result code is linked by the program.
We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use the other specific functions in this module to create specific targets.
"""
from __future__ import absolute_import from __future__ import absolute_import
import warnings
from ._ffi.base import _LIB_NAME
try:
from decorator import decorate
except ImportError as err_msg:
# Allow decorator to be missing in runtime
if _LIB_NAME != "libtvm_runtime.so":
raise err_msg
def _merge_opts(opts, new_opts):
"""Helper function to merge options"""
if isinstance(new_opts, str):
new_opts = new_opts.split()
if new_opts:
return opts + new_opts
return opts
class Target(object): class Target(object):
"""A Target describes the target type on which computation should be carried on""" """Target device information, use through TVM API.
default_target = None
str2type = {'x86': 1, 'cuda': 2, 'rasp': 3}
type2str = {1: 'x86', 2: 'cuda', 3: 'rasp'}
def __init__(self, target_type):
"""Constructs a context."""
if isinstance(target_type, Target):
self.target_typeid = target_type.target_typeid
else:
self.target_typeid = Target.str2type[target_type]
@property Parameters
def target_type(self): ----------
"""Returns the target type of current target.""" target_name : {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "ext_dev"}
return Target.type2str[self.target_typeid] The major target name.
def __hash__(self): options : list of str, optional
"""Compute hash value of target for dictionary lookup""" Additional arguments appended to the target.
return hash(self.target_typeid)
def __eq__(self, other): Note
"""Compares two targets. Two targets are equal if they ----
have the same target type. Do not use class constructor, you can create target using the following functions
"""
return isinstance(other, Target) and \ - :any:`tvm.target.create` create target from string
self.target_typeid == other.target_typeid - :any:`tvm.target.rasp` create raspberry pi target
- :any:`tvm.target.cuda` create CUDA target
- :any:`tvm.target.rocm` create ROCM target
"""
current = None
def __init__(self,
             target_name,
             options=None):
    """Build a target from its major name and its option list.

    Parameters
    ----------
    target_name : str
        The major target name, e.g. "llvm" or "cuda".

    options : str or list of str, optional
        Additional options appended to the target. A ``-device=<name>``
        option sets ``device_name`` and becomes the first dispatch key.

    Raises
    ------
    ValueError
        If ``target_name`` is not one of the known target names.
    """
    self.target_name = target_name
    self.options = _merge_opts([], options)
    self.device_name = ""
    # Parse device option; if it is repeated, the last occurrence wins.
    for item in self.options:
        if item.startswith("-device="):
            self.device_name = item.split("=")[1]
    # Target query searches the device name first, so it leads the keys.
    if self.device_name:
        self.keys = (self.device_name,)
    else:
        self.keys = ()
    # Target configuration handling.
    # Both limits default to 1 so that every target exposes them;
    # GPU-like targets override the values below.
    self.thread_warp_size = 1
    self.max_num_threads = 1
    if target_name in ("llvm", ):
        self.keys += ("cpu",)
    elif target_name in ("cuda", "nvptx"):
        self.keys += ("cuda", "gpu")
        self.max_num_threads = 512
        self.thread_warp_size = 32
    elif target_name in ("rocm", "opencl"):
        # For now assume rocm schedule for opencl
        self.keys += ("rocm", "gpu")
        self.max_num_threads = 256
    elif target_name in ("metal",):
        self.keys += ("gpu",)
        self.max_num_threads = 256
    elif target_name in ("stackvm", "ext_dev"):
        # No dispatch keys are assigned for stackvm or ext_dev
        pass
    else:
        raise ValueError("Unknown target name %s" % target_name)
def __str__(self): def __str__(self):
return '%s' % (self.target_type) return " ".join([self.target_name] + self.options)
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
def __enter__(self): def __enter__(self):
self._old_target = Target.default_target self._old_target = Target.current
Target.default_target = self if self._old_target is not None and str(self) != str(self._old_target):
warnings.warn(
"Override target '%s' with new target scope '%s'" % (
self._old_target, self))
Target.current = self
return self return self
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
Target.default_target = self._old_target Target.current = self._old_target
def generic_func(fdefault):
    """Wrap a target generic function.

    A generic function allows registration of further functions
    that are dispatched on the current target context.
    If no registered dispatch matches, ``fdefault`` is called.

    Parameters
    ----------
    fdefault : function
        The default function.

    Returns
    -------
    fgeneric : function
        A wrapped generic function. It carries a ``register`` attribute
        used to add target specific implementations.

    Example
    -------
    .. code-block:: python

      import tvm
      # wrap function as target generic
      @tvm.target.generic_func
      def my_func(a):
          return a + 1
      # register specialization of my_func under target cuda
      @my_func.register("cuda")
      def my_func_cuda(a):
          return a + 2
      # displays 3, because my_func is called
      print(my_func(2))
      # displays 4, because my_func_cuda is called
      with tvm.target.cuda():
          print(my_func(2))
    """
    dispatch_dict = {}
    func_name = fdefault.__name__

    def register(key, func=None, override=False):
        """Register function to be the dispatch function.

        Parameters
        ----------
        key : str or list of str
            The key to be registered.

        func : function
            The function to be registered.

        override : bool
            Whether to override an existing registration.

        Returns
        -------
        The registered function when ``func`` is given, otherwise a
        decorator that performs the registration.
        """
        def _do_reg(myf):
            key_list = [key] if isinstance(key, str) else key
            for k in key_list:
                if k in dispatch_dict and not override:
                    raise ValueError(
                        "Key is already registered for %s" % func_name)
                dispatch_dict[k] = myf
            return myf
        if func is not None:
            # Direct form register(key, func): register the given function.
            return _do_reg(func)
        return _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function"""
        target = current_target()
        if target is None:
            return func(*args, **kwargs)
        # Keys are tried in order, so the first registered match wins.
        for k in target.keys:
            if k in dispatch_dict:
                return dispatch_dict[k](*args, **kwargs)
        return func(*args, **kwargs)

    fdecorate = decorate(fdefault, dispatch_func)
    fdecorate.register = register
    return fdecorate
def cuda(options=None):
    """Create a target whose major name is ``"cuda"``.

    Parameters
    ----------
    options : list of str
        Additional options, passed on to :any:`Target`.

    Returns
    -------
    target : Target
        The cuda target.
    """
    target = Target("cuda", options)
    return target
def rocm(options=None):
    """Create a target whose major name is ``"rocm"``.

    Parameters
    ----------
    options : list of str
        Additional options, passed on to :any:`Target`.

    Returns
    -------
    target : Target
        The ROCM target.
    """
    target = Target("rocm", options)
    return target
def rasp(options=None):
    """Create the rasp target: an ``"llvm"`` target with preset options.

    Parameters
    ----------
    options : list of str
        Additional options, appended after the preset rasp options.

    Returns
    -------
    target : Target
        The rasp target.
    """
    rasp_defaults = [
        "-device=rasp",
        "-mtriple=armv7l-none-linux-gnueabihf",
        "-mcpu=cortex-a53",
        "-mattr=+neon",
    ]
    return Target("llvm", _merge_opts(rasp_defaults, options))
def create(target_str):
"""Get a target given target string.
Parameters
----------
target_str : str
The target string.
Returns
-------
target : Target
The target object
Target.default_target = Target('x86') Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(target_str, Target):
return target_str
if not isinstance(target_str, str):
raise ValueError("target_str has to be string type")
arr = target_str.split()
# Parse device option
device_name = ""
for item in arr[1:]:
if item.startswith("-device="):
device_name = item.split("=")[1]
if device_name == "rasp":
return rasp(arr[1:])
return Target(arr[0], arr[1:])
def x86():
"""Returns a x86 target."""
return Target('x86')
def cuda(): def current_target(allow_none=True):
"""Returns a cuda target.""" """Returns the current target.
return Target('cuda')
def rasp(): Parameters
"""Returns a rasp target.""" ----------
return Target('rasp') allow_none : bool
Whether allow the current target to be none
def current_target(): Raises
"""Returns the current target.""" ------
return Target.default_target RuntimeError if current target is not set and ``allow_none`` is False.
"""
if Target.current:
return Target.current
if not allow_none:
raise RuntimeError(
"Requires a current target in generic function, but it is not set. "
"Please set it using `with TargetObject:`")
return Target.current
...@@ -82,6 +82,8 @@ GetLLVMTargetMachine(const std::string& target_str, ...@@ -82,6 +82,8 @@ GetLLVMTargetMachine(const std::string& target_str,
} else { } else {
LOG(FATAL) << "invalid -mfloat-abi option " << value; LOG(FATAL) << "invalid -mfloat-abi option " << value;
} }
} else if (key == "-device") {
// pass
} else { } else {
LOG(FATAL) << "unknown option " << key; LOG(FATAL) << "unknown option " << key;
} }
......
...@@ -68,7 +68,8 @@ def test_gemm(): ...@@ -68,7 +68,8 @@ def test_gemm():
print("skip because %s is not enabled.." % device) print("skip because %s is not enabled.." % device)
return return
f = tvm.build(s, [A, B, C], device) with tvm.target.create(device):
f = tvm.build(s, [A, B, C])
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
# launch the kernel. # launch the kernel.
n = nn n = nn
......
import tvm
@tvm.target.generic_func
def mygeneric(data):
    # default generic function
    return data + 1


@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
    # specialization for the "cuda" and "gpu" keys
    return data + 2


@mygeneric.register("rocm")
def rocm_func(data):
    # specialization for the "rocm" key
    return data + 3


@mygeneric.register("cpu")
def cpu_func(data):
    # specialization for the "cpu" key; named distinctly so that it
    # does not shadow rocm_func at module level
    return data + 10
def test_target_dispatch():
    """Check that mygeneric dispatches on the keys of the current target."""
    with tvm.target.cuda():
        assert mygeneric(1) == 3

    with tvm.target.rocm():
        assert mygeneric(1) == 4

    with tvm.target.create("cuda"):
        assert mygeneric(1) == 3

    # rasp is an llvm target, so it falls through to the "cpu" key.
    with tvm.target.rasp():
        assert mygeneric(1) == 11

    # metal only carries the "gpu" key, which cuda_func is registered under.
    with tvm.target.create("metal"):
        assert mygeneric(1) == 3

    # Outside any target scope the default implementation is called.
    assert mygeneric(0) == 1
if __name__ == "__main__":
test_target_dispatch()
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import tvm import tvm
from .. import util from .. import util
from .. import tag from .. import tag
from .. import generic
def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern""" """Schedule conv2d for specific feature_in_out_filter pattern"""
...@@ -483,6 +484,8 @@ def schedule_conv2d_small_batch(outs): ...@@ -483,6 +484,8 @@ def schedule_conv2d_small_batch(outs):
traverse(outs[0].op) traverse(outs[0].op)
return s return s
@generic.schedule_conv2d_nchw.register(["cuda", "gpu"])
def schedule_conv2d_nchw(outs): def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw. """Schedule for conv2d_nchw.
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import tag from .. import tag
from .. import generic
@generic.schedule_dense.register(["cuda", "gpu"])
def schedule_dense(outs): def schedule_dense(outs):
"""Schedule for dense operator. """Schedule for dense operator.
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
import tvm import tvm
from ..util import get_const_tuple from ..util import get_const_tuple
from .. import tag from .. import tag
from .. import generic
@generic.schedule_depthwise_conv2d_nchw.register(["cuda", "gpu"])
def schedule_depthwise_conv2d_nchw(outs): def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward. """Schedule for depthwise_conv2d nchw forward.
......
# pylint: disable=invalid-name, unused-variable, # pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator""" """Schedule for composition of injective operator"""
import tvm import tvm
from .. import generic
def _schedule_injective(op, sch): def _schedule_injective(op, sch):
x = op.output(0) x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis) fused = sch[x].fuse(*sch[x].op.axis)
num_thread = 512 target = tvm.target.current_target()
target = target if target else tvm.target.cuda()
num_thread = target.max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread) bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return sch return sch
@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs): def schedule_injective(outs):
"""Schedule for injective op. """Schedule for injective op.
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
"""Schedule for pooling operators""" """Schedule for pooling operators"""
import tvm import tvm
from .. import tag from .. import tag
from .. import generic
@generic.schedule_global_pool.register(["cuda", "gpu"])
def schedule_global_pool(outs): def schedule_global_pool(outs):
"""Schedule for global_pool. """Schedule for global_pool.
...@@ -63,6 +65,7 @@ def schedule_global_pool(outs): ...@@ -63,6 +65,7 @@ def schedule_global_pool(outs):
return s return s
@generic.schedule_pool.register(["cuda", "gpu"])
def schedule_pool(outs): def schedule_pool(outs):
"""Schedule for pool. """Schedule for pool.
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import tag from .. import tag
from .. import generic
def _schedule_reduce(op, sch, is_idx_reduce=False): def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce: if is_idx_reduce:
...@@ -62,6 +63,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): ...@@ -62,6 +63,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
return sch return sch
@generic.schedule_reduce.register(["cuda", "gpu"])
def schedule_reduce(outs): def schedule_reduce(outs):
"""Schedule for inject->reduce->bcast ops. """Schedule for inject->reduce->bcast ops.
......
# pylint: disable=invalid-name, unused-variable, trailing-whitespace # pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator""" """Schedule for softmax operator"""
import tvm import tvm
from .. import generic
@generic.schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax(outs): def schedule_softmax(outs):
"""Schedule for softmax op. """Schedule for softmax op.
......
# pylint: disable=wildcard-import
"""Generic declaration and schedules.
This is a recommended way of using TOPI API.
To use the generic schedule function, user must set
the current target scope using with block. See also :any:`tvm.target`
Example
-------
.. code-block:: python
# create schedule that dispatches to topi.cuda.schedule_injective
with tvm.target.create("cuda"):
s = topi.generic.schedule_injective(outs)
"""
from __future__ import absolute_import as _abs
from .nn import *
from .injective import *
# pylint: disable=invalid-name
"""generic declaration and schedules."""
from __future__ import absolute_import as _abs
import tvm
@tvm.target.generic_func
def schedule_injective(outs):
    """Schedule for injective op.

    This is the default implementation, used when no target specific
    schedule is registered. It only supports the "llvm" target.

    Parameters
    ----------
    outs: Array of Tensor
          The computation graph description of the injective op in the
          format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.

    Raises
    ------
    RuntimeError
        If no current target is set, or if the current target is not "llvm".
    """
    target = tvm.target.current_target(allow_none=False)
    if target.target_name != "llvm":
        raise RuntimeError("schedule_injective not registered for '%s'" % target)
    x = outs[0]
    # The loop variable must not be called x: on Python 2 a list
    # comprehension leaks its variable and would rebind x to outs[-1].
    s = tvm.create_schedule([out.op for out in outs])
    tvm.schedule.AutoInlineInjective(s)
    s[x].fuse(s[x].op.axis)
    return s
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm
def _default_schedule(outs, auto_inline):
    """Default schedule for llvm.

    Parameters
    ----------
    outs: Array of Tensor
        The output tensors of the computation to be scheduled.

    auto_inline: bool
        Whether to inline injective ops and fuse the axes of the first output.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.

    Raises
    ------
    RuntimeError
        If no current target is set, or if the current target is not "llvm".
    """
    target = tvm.target.current_target(allow_none=False)
    if target.target_name != "llvm":
        # This helper backs every generic nn schedule, so the message
        # must not name one particular operator.
        raise RuntimeError("schedule not registered for '%s'" % target)
    s = tvm.create_schedule([out.op for out in outs])
    if auto_inline:
        x = outs[0]
        tvm.schedule.AutoInlineInjective(s)
        s[x].fuse(s[x].op.axis)
    return s
@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
    """Schedule for conv2d nchw

    Parameters
    ----------
    outs: Array of Tensor
          The output tensors of the computation to be scheduled.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    sch = _default_schedule(outs, auto_inline=False)
    return sch
@tvm.target.generic_func
def schedule_depthwise_conv2d_nchw(outs):
    """Schedule for depthwise conv2d nchw

    Parameters
    ----------
    outs: Array of Tensor
          The output tensors of the computation to be scheduled.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    sch = _default_schedule(outs, auto_inline=False)
    return sch
@tvm.target.generic_func
def schedule_reduce(outs):
    """Schedule for reduction

    Parameters
    ----------
    outs: Array of Tensor
          The output tensors of the computation to be scheduled.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    # Reduction is the only nn schedule that enables auto inlining.
    sch = _default_schedule(outs, auto_inline=True)
    return sch
@tvm.target.generic_func
def schedule_softmax(outs):
    """Schedule for softmax

    Parameters
    ----------
    outs: Array of Tensor
          The output tensors of the computation to be scheduled.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    sch = _default_schedule(outs, auto_inline=False)
    return sch
@tvm.target.generic_func
def schedule_dense(outs):
    """Schedule for dense

    Parameters
    ----------
    outs: Array of Tensor
          The output tensors of the computation to be scheduled.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    sch = _default_schedule(outs, auto_inline=False)
    return sch
@tvm.target.generic_func
def schedule_pool(outs):
    """Schedule for pool

    Parameters
    ----------
    outs: Array of Tensor
          The output tensors of the computation to be scheduled.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    sch = _default_schedule(outs, auto_inline=False)
    return sch
@tvm.target.generic_func
def schedule_global_pool(outs):
    """Schedule for global pool

    Parameters
    ----------
    outs: Array of Tensor
          The output tensors of the computation to be scheduled.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    sch = _default_schedule(outs, auto_inline=False)
    return sch
# pylint: disable=invalid-name, unused-variable, too-many-locals # pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
"""Conv2D operators""" """Conv2D operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from collections import namedtuple from collections import namedtuple
import tvm import tvm
from tvm import target as _target
from .pad import pad from .pad import pad
from .util import get_pad_tuple from .util import get_pad_tuple
from ..util import simplify from ..util import simplify
...@@ -51,9 +50,7 @@ _WORKLOADS = [ ...@@ -51,9 +50,7 @@ _WORKLOADS = [
# platform specific schedule # platform specific schedule
_CONV_SCHEDULE = {} _CONV_SCHEDULE = {}
# platform specific declaration @tvm.target.generic_func
_CONV_DECLARATION = {}
def conv2d(data, kernel, stride, padding, layout='NCHW'): def conv2d(data, kernel, stride, padding, layout='NCHW'):
"""Conv2D operator. """Conv2D operator.
...@@ -80,10 +77,6 @@ def conv2d(data, kernel, stride, padding, layout='NCHW'): ...@@ -80,10 +77,6 @@ def conv2d(data, kernel, stride, padding, layout='NCHW'):
4-D with shape [batch, out_channel, out_height, out_width] 4-D with shape [batch, out_channel, out_height, out_width]
""" """
# search platform specific declaration first # search platform specific declaration first
target = _target.current_target()
if target in _CONV_DECLARATION:
return _CONV_DECLARATION[target](data, kernel, stride, padding, layout)
# default declaration # default declaration
if layout == 'NCHW': if layout == 'NCHW':
return conv2d_nchw(data, kernel, stride, padding) return conv2d_nchw(data, kernel, stride, padding)
...@@ -105,15 +98,15 @@ def _get_workload(data, kernel, stride, padding): ...@@ -105,15 +98,15 @@ def _get_workload(data, kernel, stride, padding):
return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_schedule(wkl, target=None): @tvm.target.generic_func
def _get_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule. """ """ Get the platform specific schedule. """
if target is None: target = tvm.target.current_target()
target = _target.current_target() raise RuntimeError(
else: "No schedule for current target:{}".format(target))
target = _target.Target(target) # This return has no use, merely to supress pylint warning
assert target in _CONV_SCHEDULE, "no schedule for such target: {}".format(target) return wkl
return _CONV_SCHEDULE[target](wkl)
def _spatial_pack(data, kernel, stride, padding): def _spatial_pack(data, kernel, stride, padding):
""" Compute convolution with pack on spatial axes. """ """ Compute convolution with pack on spatial axes. """
......
...@@ -4,11 +4,12 @@ from __future__ import absolute_import as _abs ...@@ -4,11 +4,12 @@ from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import target as _target from tvm import target as _target
from .. import tag from .. import tag
from ..nn.conv2d import conv2d, _get_schedule
from ..nn.conv2d import SpatialPack, Im2ColPack from ..nn.conv2d import SpatialPack, Im2ColPack
from ..nn.conv2d import _CONV_DECLARATION, _CONV_SCHEDULE
from ..nn.conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC from ..nn.conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC
from ..nn.conv2d import _get_workload, _get_schedule from ..nn.conv2d import _get_workload
from ..nn.util import infer_pad, infer_stride from ..nn.util import infer_pad, infer_stride
from .. import generic
_SCHEDULES = [ _SCHEDULES = [
SpatialPack(1, 8, 4, 1, 4, True), SpatialPack(1, 8, 4, 1, 4, True),
...@@ -36,6 +37,7 @@ _SCHEDULES = [ ...@@ -36,6 +37,7 @@ _SCHEDULES = [
Im2ColPack(7, 4, 1, 4, True), Im2ColPack(7, 4, 1, 4, True),
] ]
@_get_schedule.register("rasp")
def _schedule_conv2d(wkl): def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS: if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl)) raise ValueError("no schedule for such workload: {}".format(wkl))
...@@ -43,8 +45,8 @@ def _schedule_conv2d(wkl): ...@@ -43,8 +45,8 @@ def _schedule_conv2d(wkl):
sch = _SCHEDULES[idx] sch = _SCHEDULES[idx]
return sch return sch
_CONV_SCHEDULE[_target.rasp()] = _schedule_conv2d
@conv2d.register("rasp")
def _declaration_conv2d(data, kernel, stride, padding, layout): def _declaration_conv2d(data, kernel, stride, padding, layout):
assert layout == 'NCHW', "only support NCHW convolution on rasp" assert layout == 'NCHW', "only support NCHW convolution on rasp"
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
...@@ -52,7 +54,6 @@ def _declaration_conv2d(data, kernel, stride, padding, layout): ...@@ -52,7 +54,6 @@ def _declaration_conv2d(data, kernel, stride, padding, layout):
sch = _get_schedule(wkl) sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding) return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)
_CONV_DECLARATION[_target.rasp()] = _declaration_conv2d
def _schedule_spatial_conv2d(s, data, data_pad, data_vec, def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec, kernel, kernel_vec,
...@@ -64,7 +65,9 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec, ...@@ -64,7 +65,9 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
else: else:
stride = infer_stride(data_pad, kernel, output) stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding) wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
with tvm.target.rasp():
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width H, W = wkl.height, wkl.width
CI, CO = wkl.in_filter, wkl.out_filter CI, CO = wkl.in_filter, wkl.out_filter
...@@ -170,7 +173,9 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, ...@@ -170,7 +173,9 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
else: else:
stride = infer_stride(data_pad, kernel, output) stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding) wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
with _target.rasp():
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width H, W = wkl.height, wkl.width
CI = wkl.in_filter CI = wkl.in_filter
...@@ -275,6 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, ...@@ -275,6 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
return s return s
@generic.schedule_conv2d_nchw.register(["cpu", "rasp"])
def schedule_conv2d(outs): def schedule_conv2d(outs):
"""Create schedule for tensors""" """Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
......
...@@ -5,7 +5,7 @@ from collections import namedtuple ...@@ -5,7 +5,7 @@ from collections import namedtuple
import tvm import tvm
from .. import tag from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple from ..nn.util import infer_pad, infer_stride, get_pad_tuple
from .. import generic
_Workload = namedtuple('Workload', _Workload = namedtuple('Workload',
['height', 'width', 'channel', 'multiplier', ['height', 'width', 'channel', 'multiplier',
...@@ -145,7 +145,7 @@ def _schedule(s, data, data_pad, kernel, output, last): ...@@ -145,7 +145,7 @@ def _schedule(s, data, data_pad, kernel, output, last):
return s return s
@generic.schedule_depthwise_conv2d_nchw.register(["cpu", "rasp"])
def schedule_depthwise_conv2d(outs): def schedule_depthwise_conv2d(outs):
"""Schedule for depthwise_conv2d nchw forward. """Schedule for depthwise_conv2d nchw forward.
......
...@@ -8,16 +8,16 @@ def verify_broadcast_to_ele(in_shape, out_shape): ...@@ -8,16 +8,16 @@ def verify_broadcast_to_ele(in_shape, out_shape):
# Build the logic and compile the function # Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
B = topi.broadcast_to(A, out_shape) B = topi.broadcast_to(A, out_shape)
s = topi.cuda.schedule_broadcast(B)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="broadcast_to") foo = tvm.build(s, [A, B], device, name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype) data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape) out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, ctx) data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx) out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
for _ in range(1): for _ in range(1):
...@@ -48,11 +48,12 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"): ...@@ -48,11 +48,12 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
C = topi.broadcast_minimum(A, B) C = topi.broadcast_minimum(A, B)
else: else:
raise NotImplementedError raise NotImplementedError
s = topi.cuda.schedule_broadcast(C)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(C)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ) foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype) lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
......
...@@ -14,8 +14,8 @@ def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, paddin ...@@ -14,8 +14,8 @@ def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, paddin
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d(A, W, stride, padding) B = topi.nn.conv2d(A, W, stride, padding)
s = topi.generic.schedule_conv2d_nchw([B])
s = topi.rasp.schedule_conv2d([B])
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
dtype = A.dtype dtype = A.dtype
......
...@@ -14,8 +14,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -14,8 +14,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d_nchw(A, W, stride, padding) B = topi.nn.conv2d_nchw(A, W, stride, padding)
C = topi.nn.relu(B) C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_nchw([B])
s2 = topi.cuda.schedule_conv2d_nchw([C])
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
...@@ -35,6 +33,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -35,6 +33,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s1 = topi.generic.schedule_conv2d_nchw([B])
s2 = topi.generic.schedule_conv2d_nchw([C])
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx) w = tvm.nd.array(w_np, ctx)
......
...@@ -12,7 +12,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): ...@@ -12,7 +12,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
C = tvm.placeholder((out_dim,), name='C') C = tvm.placeholder((out_dim,), name='C')
D = topi.nn.dense(A, B, C if use_bias else None) D = topi.nn.dense(A, B, C if use_bias else None)
D = topi.nn.relu(D) D = topi.nn.relu(D)
s = topi.cuda.schedule_dense(D)
dtype = A.dtype dtype = A.dtype
# use memoize to pickle the test data for next time use # use memoize to pickle the test data for next time use
...@@ -33,6 +32,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): ...@@ -33,6 +32,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_dense(D)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx) b = tvm.nd.array(b_np, ctx)
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
from scipy import signal from scipy import signal
from topi.util import get_const_tuple from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc
def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
...@@ -21,15 +21,18 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -21,15 +21,18 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=[stride_h, stride_w], padding=padding) DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift) Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = schedule_depthwise_conv2d_nchw(Relu)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
# schedule
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
# build the kernels # build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
...@@ -88,7 +91,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -88,7 +91,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
check_device("cuda") check_device("cuda")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
in_width = in_height in_width = in_height
filter_channel = in_channel filter_channel = in_channel
......
...@@ -12,7 +12,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type): ...@@ -12,7 +12,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
A = tvm.placeholder((n, ic, ih, iw), name='A') A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type) B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type)
B = topi.nn.relu(B) B = topi.nn.relu(B)
s = topi.cuda.schedule_pool(B)
dtype = A.dtype dtype = A.dtype
a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
...@@ -36,6 +35,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type): ...@@ -36,6 +35,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_pool(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
...@@ -57,7 +58,6 @@ def verify_global_pool(n, c, h, w, pool_type): ...@@ -57,7 +58,6 @@ def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A') A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.global_pool(A, pool_type=pool_type) B = topi.nn.global_pool(A, pool_type=pool_type)
B = topi.nn.relu(B) B = topi.nn.relu(B)
s = topi.cuda.schedule_global_pool(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
if pool_type == 'avg': if pool_type == 'avg':
...@@ -70,6 +70,8 @@ def verify_global_pool(n, c, h, w, pool_type): ...@@ -70,6 +70,8 @@ def verify_global_pool(n, c, h, w, pool_type):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_global_pool(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......
...@@ -45,11 +45,13 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): ...@@ -45,11 +45,13 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_dtype = "int32" out_dtype = "int32"
else: else:
raise NotImplementedError raise NotImplementedError
s = topi.cuda.schedule_reduce(B)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_reduce(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="sum") foo = tvm.build(s, [A, B], device, name="sum")
# Test # Test
......
...@@ -12,8 +12,6 @@ def verify_softmax(m, n): ...@@ -12,8 +12,6 @@ def verify_softmax(m, n):
s = tvm.create_schedule([B.op]) s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True) tvm.lower(s, [A, B], simple_mode=True)
s = topi.cuda.schedule_softmax(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.softmax_python(a_np) b_np = topi.testing.softmax_python(a_np)
...@@ -21,6 +19,8 @@ def verify_softmax(m, n): ...@@ -21,6 +19,8 @@ def verify_softmax(m, n):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
...@@ -43,7 +43,6 @@ def verify_log_softmax(m, n): ...@@ -43,7 +43,6 @@ def verify_log_softmax(m, n):
s = tvm.create_schedule([B.op]) s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True) tvm.lower(s, [A, B], simple_mode=True)
s = topi.cuda.schedule_softmax(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.log_softmax_python(a_np) b_np = topi.testing.log_softmax_python(a_np)
...@@ -52,6 +51,8 @@ def verify_log_softmax(m, n): ...@@ -52,6 +51,8 @@ def verify_log_softmax(m, n):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx) a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......
...@@ -6,11 +6,12 @@ import topi ...@@ -6,11 +6,12 @@ import topi
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
B = topi.expand_dims(A, axis, num_newaxis) B = topi.expand_dims(A, axis, num_newaxis)
s = topi.cuda.schedule_broadcast(B)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="expand_dims") foo = tvm.build(s, [A, B], device, name="expand_dims")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype) data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
...@@ -23,17 +24,18 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): ...@@ -23,17 +24,18 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
check_device("opencl") check_device("opencl")
check_device("cuda") check_device("cuda")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def verify_tranpose(in_shape, axes): def verify_tranpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes) B = topi.transpose(A, axes)
s = topi.cuda.schedule_injective(B)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="tranpose") foo = tvm.build(s, [A, B], device, name="tranpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype) data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
...@@ -46,16 +48,17 @@ def verify_tranpose(in_shape, axes): ...@@ -46,16 +48,17 @@ def verify_tranpose(in_shape, axes):
check_device("cuda") check_device("cuda")
check_device("opencl") check_device("opencl")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def verify_reshape(src_shape, dst_shape): def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A") A = tvm.placeholder(shape=src_shape, name="A")
B = topi.reshape(A, dst_shape) B = topi.reshape(A, dst_shape)
s = topi.cuda.schedule_injective(B)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="reshape") foo = tvm.build(s, [A, B], device, name="reshape")
data_npy = np.random.normal(size=src_shape).astype(A.dtype) data_npy = np.random.normal(size=src_shape).astype(A.dtype)
...@@ -68,16 +71,17 @@ def verify_reshape(src_shape, dst_shape): ...@@ -68,16 +71,17 @@ def verify_reshape(src_shape, dst_shape):
check_device("cuda") check_device("cuda")
check_device("opencl") check_device("opencl")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def verify_squeeze(src_shape, axis): def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A") A = tvm.placeholder(shape=src_shape, name="A")
B = topi.squeeze(A, axis=axis) B = topi.squeeze(A, axis=axis)
s = topi.cuda.schedule_injective(B)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="squeeze") foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype) data_npy = np.random.normal(size=src_shape).astype(A.dtype)
...@@ -94,18 +98,19 @@ def verify_squeeze(src_shape, axis): ...@@ -94,18 +98,19 @@ def verify_squeeze(src_shape, axis):
check_device("cuda") check_device("cuda")
check_device("opencl") check_device("opencl")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def verify_concatenate(shapes, axis): def verify_concatenate(shapes, axis):
tensor_l = [] tensor_l = []
for i, shape in enumerate(shapes): for i, shape in enumerate(shapes):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i))) tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis) out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
s = topi.cuda.schedule_injective(out_tensor)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate") foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes] data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
...@@ -118,16 +123,17 @@ def verify_concatenate(shapes, axis): ...@@ -118,16 +123,17 @@ def verify_concatenate(shapes, axis):
check_device("cuda") check_device("cuda")
check_device("opencl") check_device("opencl")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def verify_split(src_shape, indices_or_sections, axis): def verify_split(src_shape, indices_or_sections, axis):
A = tvm.placeholder(shape=src_shape, name="A") A = tvm.placeholder(shape=src_shape, name="A")
tensor_l = topi.split(A, indices_or_sections, axis=axis) tensor_l = topi.split(A, indices_or_sections, axis=axis)
s = topi.cuda.schedule_injective(tensor_l)
def check_device(device): def check_device(device):
if not tvm.module.enabled(device): if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device) print("Skip because %s is not enabled" % device)
return return
with tvm.target.create(device):
s = topi.generic.schedule_injective(tensor_l)
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
foo = tvm.build(s, [A] + tensor_l, device, name="split") foo = tvm.build(s, [A] + tensor_l, device, name="split")
data_npy = np.random.normal(size=src_shape).astype(A.dtype) data_npy = np.random.normal(size=src_shape).astype(A.dtype)
...@@ -142,7 +148,7 @@ def verify_split(src_shape, indices_or_sections, axis): ...@@ -142,7 +148,7 @@ def verify_split(src_shape, indices_or_sections, axis):
check_device("opencl") check_device("opencl")
check_device("metal") check_device("metal")
check_device("rocm") check_device("rocm")
def test_expand_dims(): def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), (1, 3, 10), -3, 1) verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
...@@ -190,4 +196,3 @@ if __name__ == "__main__": ...@@ -190,4 +196,3 @@ if __name__ == "__main__":
test_squeeze() test_squeeze()
test_concatenate() test_concatenate()
test_split() test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment