Commit eb761f36 by Tianqi Chen, committed by GitHub

[Refactor] Introduce target generic dispatch system (#556)

* [TVM] Introduce target generic dispatch system

* fix target warning
parent c3cac464
@@ -8,6 +8,7 @@ Python API
    intrin
    tensor
    schedule
+   target
    build
    module
    ndarray
...
tvm.target
----------
.. automodule:: tvm.target
.. autofunction:: tvm.target.generic_func
.. autoclass:: tvm.target.Target
    :members:
.. autofunction:: tvm.target.cuda
.. autofunction:: tvm.target.rocm
.. autofunction:: tvm.target.rasp
.. autofunction:: tvm.target.create
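
These docs describe the new tvm.target module: a Target is built from an option string (or one of the convenience constructors) and can be entered as a scope that generic functions dispatch on. A minimal usage sketch of the documented functions, assuming the attribute names used elsewhere in this commit (target_name, keys):

    import tvm

    # Build Target objects from strings or via the convenience constructors.
    cuda_tgt = tvm.target.create("cuda")
    rasp_tgt = tvm.target.rasp()   # LLVM target preconfigured for Raspberry Pi

    # Entering a target makes it the "current target"; functions declared
    # with tvm.target.generic_func dispatch on it.
    with cuda_tgt:
        assert tvm.target.current_target().target_name == "cuda"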
@@ -37,13 +37,11 @@ Index
   .. autosummary::

-      topi.cuda.schedule_conv2d_nchw
-      topi.cuda.schedule_conv2d_hwcn
-      topi.cuda.schedule_depthwise_conv2d_nchw
-      topi.cuda.schedule_depthwise_conv2d_nhwc
-      topi.cuda.schedule_reduce
-      topi.cuda.schedule_broadcast
-      topi.cuda.schedule_injective
+      topi.generic.schedule_conv2d_nchw
+      topi.generic.schedule_depthwise_conv2d_nchw
+      topi.generic.schedule_reduce
+      topi.generic.schedule_broadcast
+      topi.generic.schedule_injective

   topi
   ~~~~
@@ -75,14 +73,12 @@ topi.nn
   .. autofunction:: topi.nn.depthwise_conv2d_nhwc

-topi.cuda
-~~~~~~~~~
-.. automodule:: topi.cuda
-.. autofunction:: topi.cuda.schedule_conv2d_nchw
-.. autofunction:: topi.cuda.schedule_conv2d_hwcn
-.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nchw
-.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nhwc
-.. autofunction:: topi.cuda.schedule_reduce
-.. autofunction:: topi.cuda.schedule_broadcast
-.. autofunction:: topi.cuda.schedule_injective
+topi.generic
+~~~~~~~~~~~~
+.. automodule:: topi.generic
+.. autofunction:: topi.generic.schedule_conv2d_nchw
+.. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw
+.. autofunction:: topi.generic.schedule_reduce
+.. autofunction:: topi.generic.schedule_broadcast
+.. autofunction:: topi.generic.schedule_injective
@@ -56,11 +56,7 @@ def context(dev_type, dev_id=0):
      assert tvm.context("cuda", 0) == tvm.gpu(0)
    """
    if isinstance(dev_type, string_types):
-        if dev_type not in TVMContext.STR2MASK:
-            if dev_type.find("nvptx") != -1:
-                dev_type = "cuda"
-            if dev_type.find("rocm") != -1:
-                dev_type = "rocm"
+        dev_type = dev_type.split()[0]
        if dev_type not in TVMContext.STR2MASK:
            raise ValueError("Unknown device type %s" % dev_type)
        dev_type = TVMContext.STR2MASK[dev_type]
...
@@ -100,9 +100,12 @@ class TVMContext(ctypes.Structure):
        12: 'ext_dev',
    }
    STR2MASK = {
+        'llvm': 1,
+        'stackvm': 1,
        'cpu': 1,
        'gpu': 2,
        'cuda': 2,
+        'nvptx': 2,
        'cl': 4,
        'opencl': 4,
        'metal': 8,
...
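
Taken together with the context() change above, the device table now accepts full target strings: only the leading token is consulted, and llvm, stackvm, and nvptx resolve without special casing. A small illustrative sketch of what this enables:

    import tvm

    # A full target string is accepted; only its first token is consulted.
    assert tvm.context(str(tvm.target.rasp()), 0).device_type == tvm.cpu(0).device_type
    # llvm/stackvm now map to the CPU mask, nvptx to the CUDA device type.
    assert tvm.context("llvm", 0).device_type == tvm.cpu(0).device_type
    assert tvm.context("nvptx", 0).device_type == tvm.gpu(0).device_type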
@@ -15,6 +15,7 @@ from . import container
from . import module
from . import codegen
from . import ndarray
+from . import target as _target

class BuildConfig(object):
    """Configuration scope to set a build config option.
@@ -238,7 +239,7 @@ def lower(sch,
def build(sch,
          args=None,
-          target="llvm",
+          target=None,
          target_host=None,
          name="default_function",
          binds=None):
@@ -252,36 +253,10 @@ def build(sch,
    args : list of Buffer or Tensor or Var, optional
        The argument lists to the function.

-    target : str, optional
+    target : str or :any:`tvm.target.Target`, optional
        The target and option of the compilation.
-        When the target is llvm, you can set options like:
-
-          - **-mtriple=<target triple>** or **-target**
-
-            Specify the target triple, which is useful for cross
-            compilation.
-
-          - **-mcpu=<cpuname>**
-
-            Specify a specific chip in the current architecture to
-            generate code for. By default this is infered from the
-            target triple and autodetected to the current architecture.
-
-          - **-mattr=a1,+a2,-a3,...**
-
-            Override or control specific attributes of the target,
-            such as whether SIMD operations are enabled or not. The
-            default set of attributes is set by the current CPU.
-
-          - **-system-lib**
-
-            Build TVM system library module. System lib is a global module that contains
-            self registered functions in program startup. User can get the module using
-            :any:`tvm.module.system_lib`.
-            It is useful in environments where dynamic loading api like dlopen is banned.
-            The system lib will be available as long as the result code is linked by the program.
-
-    target_host : str, optional
+    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
@@ -301,6 +276,10 @@ def build(sch,
    -------
    f : Function, or pair of functions
        The result function.
+
+    Note
+    ----
+    See the note on :any:`tvm.target` on target string format.
    """
    if isinstance(sch, schedule.Schedule):
        if args is None:
@@ -325,6 +304,9 @@ def build(sch,
        if x.name in fname_set:
            raise ValueError("Duplicate function name %s" % x.name)

+    target = _target.current_target() if target is None else target
+    target = _target.create(target) if target else _target.create("llvm")
+
    fhost = []
    fdevice = []
    for func in flist:
@@ -332,7 +314,7 @@ def build(sch,
            if BuildConfig.current.detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
-            warp_size = 32 if target == "cuda" else 1
+            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            fsplits = [s for s in ir_pass.SplitHostDevice(func)]
            fhost.append(fsplits[0])
@@ -345,29 +327,28 @@ def build(sch,
        else:
            raise ValueError("unknown function type %d" % func.func_type)

-    if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
+    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do bind?" % target)

-    device = "cpu" if target.startswith("llvm") or target == "stackvm" else target
-    device_type = ndarray.context(device, 0).device_type
+    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
    fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

    if not target_host:
-        if device == "cpu":
+        if device_type == ndarray.cpu(0).device_type:
            target_host = target
            assert not fdevice
        else:
            target_host = "llvm" if module.enabled("llvm") else "stackvm"
+    target_host = _target.create(target_host)
    target_device = target
-    fdevice = [ir_pass.LowerIntrin(x, target_device) for x in fdevice]
-    fhost = [ir_pass.LowerIntrin(x, target_host) for x in fhost]
+    fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
+    fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
    fhost = [ir_pass.CombineContextCall(x) for x in fhost]
-    mhost = codegen.build_module(fhost, target_host)
+    mhost = codegen.build_module(fhost, str(target_host))
    if fdevice:
-        mdev = codegen.build_module(fdevice, target_device)
+        mdev = codegen.build_module(fdevice, str(target_device))
        mhost.import_module(mdev)
    return mhost
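
The net effect for users of tvm.build is that the target can be given explicitly or picked up from the enclosing target scope, and device-specific decisions (warp size, the missing-device-code warning) come from the Target object instead of string matching. A hedged sketch of the two call styles, assuming an LLVM-enabled build (the tensors here are only placeholders for illustration):

    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    # Style 1: pass a target string; build() converts it via tvm.target.create.
    f1 = tvm.build(s, [A, B], target="llvm")

    # Style 2: rely on the current target scope; target=None falls back to
    # tvm.target.current_target(), and then to "llvm".
    with tvm.target.create("llvm"):
        f2 = tvm.build(s, [A, B])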
@@ -82,6 +82,8 @@ GetLLVMTargetMachine(const std::string& target_str,
      } else {
        LOG(FATAL) << "invalid -mfloat-abi option " << value;
      }
+    } else if (key == "-device") {
+      // pass
    } else {
      LOG(FATAL) << "unknown option " << key;
    }
...
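
The extra -device key lets LLVM-based target strings carry a device hint (used only for dispatch) without tripping the option parser. An illustrative sketch; the exact option string produced by tvm.target.rasp() may differ:

    import tvm

    tgt = tvm.target.rasp()
    # The string form is what eventually reaches the LLVM backend, where
    # a "-device=..." option is now accepted and ignored.
    print(str(tgt))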
@@ -68,7 +68,8 @@ def test_gemm():
            print("skip because %s is not enabled.." % device)
            return

-        f = tvm.build(s, [A, B, C], device)
+        with tvm.target.create(device):
+            f = tvm.build(s, [A, B, C])
        ctx = tvm.context(device, 0)
        # launch the kernel.
        n = nn
...
import tvm

@tvm.target.generic_func
def mygeneric(data):
    # default generic function
    return data + 1

@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
    return data + 2

@mygeneric.register("rocm")
def rocm_func(data):
    return data + 3

@mygeneric.register("cpu")
def cpu_func(data):
    return data + 10

def test_target_dispatch():
    with tvm.target.cuda():
        assert mygeneric(1) == 3
    with tvm.target.rocm():
        assert mygeneric(1) == 4
    with tvm.target.create("cuda"):
        assert mygeneric(1) == 3
    with tvm.target.rasp():
        assert mygeneric(1) == 11
    with tvm.target.create("metal"):
        assert mygeneric(1) == 3
    try:
        mygeneric(0)
        raise RuntimeError("not reached")
    except RuntimeError:
        pass

if __name__ == "__main__":
    test_target_dispatch()
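
The test shows the dispatch rule: an implementation is chosen when one of its registered keys appears in the current target's keys (a cuda target carries both "cuda" and "gpu", rasp carries "cpu", metal carries "gpu"), and the decorated default runs when nothing matches or no target scope is active. A hedged restatement of that rule in plain Python, purely for illustration and not TVM's implementation:

    def dispatch(registry, default, target, *args):
        """Pick the first implementation whose key is in target.keys."""
        if target is not None:
            for key in target.keys:
                if key in registry:
                    return registry[key](*args)
        return default(*args)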
@@ -3,6 +3,7 @@
import tvm
from .. import util
from .. import tag
+from .. import generic

def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
    """Schedule conv2d for specific feature_in_out_filter pattern"""
@@ -483,6 +484,8 @@ def schedule_conv2d_small_batch(outs):
    traverse(outs[0].op)
    return s

+@generic.schedule_conv2d_nchw.register(["cuda", "gpu"])
def schedule_conv2d_nchw(outs):
    """Schedule for conv2d_nchw.
...
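
Each CUDA schedule now announces itself to the generic entry point instead of being imported directly by callers. The same pattern, sketched on a hypothetical operator (my_schedule is illustrative, not part of TOPI):

    import tvm

    @tvm.target.generic_func
    def my_schedule(outs):
        # fallback used when no key of the current target is registered
        return tvm.create_schedule([x.op for x in outs])

    @my_schedule.register(["cuda", "gpu"])
    def _my_schedule_cuda(outs):
        s = tvm.create_schedule([x.op for x in outs])
        # a real implementation would bind blockIdx/threadIdx axes here
        return s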
@@ -3,7 +3,9 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
+from .. import generic

+@generic.schedule_dense.register(["cuda", "gpu"])
def schedule_dense(outs):
    """Schedule for dense operator.
...
@@ -3,7 +3,9 @@
import tvm
from ..util import get_const_tuple
from .. import tag
+from .. import generic

+@generic.schedule_depthwise_conv2d_nchw.register(["cuda", "gpu"])
def schedule_depthwise_conv2d_nchw(outs):
    """Schedule for depthwise_conv2d nchw forward.
...
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
+from .. import generic

def _schedule_injective(op, sch):
    x = op.output(0)
    fused = sch[x].fuse(*sch[x].op.axis)
-    num_thread = 512
+    target = tvm.target.current_target()
+    target = target if target else tvm.target.cuda()
+    num_thread = target.max_num_threads
    bx, tx = sch[x].split(fused, factor=num_thread)
    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch

+@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs):
    """Schedule for injective op.
...
@@ -2,7 +2,9 @@
"""Schedule for pooling operators"""
import tvm
from .. import tag
+from .. import generic

+@generic.schedule_global_pool.register(["cuda", "gpu"])
def schedule_global_pool(outs):
    """Schedule for global_pool.
@@ -63,6 +65,7 @@ def schedule_global_pool(outs):
    return s

+@generic.schedule_pool.register(["cuda", "gpu"])
def schedule_pool(outs):
    """Schedule for pool.
...
@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
+from .. import generic

def _schedule_reduce(op, sch, is_idx_reduce=False):
    if is_idx_reduce:
@@ -62,6 +63,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
    return sch

+@generic.schedule_reduce.register(["cuda", "gpu"])
def schedule_reduce(outs):
    """Schedule for inject->reduce->bcast ops.
...
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
import tvm
+from .. import generic

+@generic.schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax(outs):
    """Schedule for softmax op.
...
# pylint: disable=wildcard-import
"""Generic declaration and schedules.

This is the recommended way of using the TOPI API.
To use a generic schedule function, the user must set
the current target scope using a with block. See also :any:`tvm.target`.

Example
-------
.. code-block:: python

  # create schedule that dispatches to topi.cuda.schedule_injective
  with tvm.target.create("cuda"):
    s = topi.generic.schedule_injective(outs)
"""
from __future__ import absolute_import as _abs

from .nn import *
from .injective import *
# pylint: disable=invalid-name
"""Generic declaration and schedules."""
from __future__ import absolute_import as _abs

import tvm

@tvm.target.generic_func
def schedule_injective(outs):
    """Schedule for injective op.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of injective ops in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    target = tvm.target.current_target(allow_none=False)
    if target.target_name != "llvm":
        raise RuntimeError("schedule_injective not registered for '%s'" % target)
    x = outs[0]
    s = tvm.create_schedule([x.op for x in outs])
    tvm.schedule.AutoInlineInjective(s)
    s[x].fuse(s[x].op.axis)
    return s

schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
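
The generic entry therefore acts as an llvm-only fallback plus a dispatch point for registered backends. A hedged usage sketch following the same pattern as the updated tests (shapes are arbitrary; devices are skipped when not enabled):

    import tvm
    import topi

    A = tvm.placeholder((16, 16), name="A")
    B = topi.broadcast_to(A, (4, 16, 16))

    def try_device(device):
        if not tvm.module.enabled(device):
            return
        # dispatches to the implementation registered for this target's keys
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        return tvm.build(s, [A, B], device)

    try_device("cuda")
    try_device("opencl")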
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm
def _default_schedule(outs, auto_inline):
"""Default schedule for llvm."""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_pool not registered for '%s'" % target)
s = tvm.create_schedule([x.op for x in outs])
if auto_inline:
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d nchow
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for conv2d nchow
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_reduce(outs):
"""Schedule for reduction
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, True)
@tvm.target.generic_func
def schedule_softmax(outs):
"""Schedule for softmax
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_dense(outs):
"""Schedule for dense
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_pool(outs):
"""Schedule for pool
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_global_pool(outs):
"""Schedule for global pool
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
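
Every generic nn schedule above follows the same recipe: an llvm-only default built by _default_schedule, overridden per target elsewhere in TOPI. A hedged sketch of consuming one of them, mirroring the updated tests (shapes are arbitrary):

    import tvm
    import topi

    A = tvm.placeholder((8, 16), name="A")
    W = tvm.placeholder((4, 16), name="W")
    D = topi.nn.dense(A, W, None)

    def try_device(device):
        if not tvm.module.enabled(device):
            return
        with tvm.target.create(device):
            # picks the backend registered for this target, e.g. topi.cuda.schedule_dense
            s = topi.generic.schedule_dense(D)
        return tvm.build(s, [A, W, D], device)

    try_device("cuda")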
-# pylint: disable=invalid-name, unused-variable, too-many-locals
+# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
"""Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
-from tvm import target as _target

from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
@@ -51,9 +50,7 @@ _WORKLOADS = [
# platform specific schedule
_CONV_SCHEDULE = {}

-# platform specific declaration
-_CONV_DECLARATION = {}
-
+@tvm.target.generic_func
def conv2d(data, kernel, stride, padding, layout='NCHW'):
    """Conv2D operator.
@@ -80,10 +77,6 @@ def conv2d(data, kernel, stride, padding, layout='NCHW'):
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    # search platform specific declaration first
-    target = _target.current_target()
-    if target in _CONV_DECLARATION:
-        return _CONV_DECLARATION[target](data, kernel, stride, padding, layout)
-
    # default declaration
    if layout == 'NCHW':
        return conv2d_nchw(data, kernel, stride, padding)
@@ -105,15 +98,15 @@ def _get_workload(data, kernel, stride, padding):
    return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)

-def _get_schedule(wkl, target=None):
+@tvm.target.generic_func
+def _get_schedule(wkl):
+    # pylint: disable=unreachable
    """ Get the platform specific schedule. """
-    if target is None:
-        target = _target.current_target()
-    else:
-        target = _target.Target(target)
-    assert target in _CONV_SCHEDULE, "no schedule for such target: {}".format(target)
-    return _CONV_SCHEDULE[target](wkl)
+    target = tvm.target.current_target()
+    raise RuntimeError(
+        "No schedule for current target: {}".format(target))
+    # This return has no use, merely to suppress a pylint warning
+    return wkl

def _spatial_pack(data, kernel, stride, padding):
    """ Compute convolution with pack on spatial axes. """
...
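
This drops the dictionary-based dispatch (_CONV_DECLARATION) in favour of generic_func: the operator declaration itself is overridable per target, and _get_schedule raises unless a backend registers a workload lookup. A self-contained sketch of the same mechanism on a hypothetical operator (not TOPI code):

    import tvm

    @tvm.target.generic_func
    def my_conv2d(data, kernel, stride, padding, layout='NCHW'):
        # default declaration, analogous to the generic conv2d above
        raise RuntimeError("no default my_conv2d for layout %s" % layout)

    # Old style:  _CONV_DECLARATION[_target.rasp()] = _decl_rasp
    # New style:  register against the target key.
    @my_conv2d.register("rasp")
    def _my_conv2d_rasp(data, kernel, stride, padding, layout='NCHW'):
        # a hypothetical Raspberry-Pi-specific declaration would go here
        return None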
@@ -4,11 +4,12 @@ from __future__ import absolute_import as _abs
import tvm
from tvm import target as _target
from .. import tag
+from ..nn.conv2d import conv2d, _get_schedule
from ..nn.conv2d import SpatialPack, Im2ColPack
-from ..nn.conv2d import _CONV_DECLARATION, _CONV_SCHEDULE
from ..nn.conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC
-from ..nn.conv2d import _get_workload, _get_schedule
+from ..nn.conv2d import _get_workload
from ..nn.util import infer_pad, infer_stride
+from .. import generic

_SCHEDULES = [
    SpatialPack(1, 8, 4, 1, 4, True),
@@ -36,6 +37,7 @@ _SCHEDULES = [
    Im2ColPack(7, 4, 1, 4, True),
]

+@_get_schedule.register("rasp")
def _schedule_conv2d(wkl):
    if wkl not in _WORKLOADS:
        raise ValueError("no schedule for such workload: {}".format(wkl))
@@ -43,8 +45,8 @@ def _schedule_conv2d(wkl):
    sch = _SCHEDULES[idx]
    return sch

-_CONV_SCHEDULE[_target.rasp()] = _schedule_conv2d

+@conv2d.register("rasp")
def _declaration_conv2d(data, kernel, stride, padding, layout):
    assert layout == 'NCHW', "only support NCHW convolution on rasp"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
@@ -52,7 +54,6 @@ def _declaration_conv2d(data, kernel, stride, padding, layout):
    sch = _get_schedule(wkl)
    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)

-_CONV_DECLARATION[_target.rasp()] = _declaration_conv2d

def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
                             kernel, kernel_vec,
@@ -64,7 +65,9 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding)
-    sch = _get_schedule(wkl, 'rasp')
+
+    with tvm.target.rasp():
+        sch = _get_schedule(wkl)

    H, W = wkl.height, wkl.width
    CI, CO = wkl.in_filter, wkl.out_filter
@@ -170,7 +173,9 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding)
-    sch = _get_schedule(wkl, 'rasp')
+
+    with _target.rasp():
+        sch = _get_schedule(wkl)

    H, W = wkl.height, wkl.width
    CI = wkl.in_filter
@@ -275,6 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
    return s

+@generic.schedule_conv2d_nchw.register(["cpu", "rasp"])
def schedule_conv2d(outs):
    """Create schedule for tensors"""
    s = tvm.create_schedule([x.op for x in outs])
...
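
Forcing the rasp schedule no longer needs a target argument: entering tvm.target.rasp() around the call selects the "rasp" registration of the generic _get_schedule. The idiom, restated on a self-contained toy dispatcher (illustrative only):

    import tvm

    @tvm.target.generic_func
    def pick_schedule(workload):
        raise RuntimeError("no schedule for %s" % str(tvm.target.current_target()))

    @pick_schedule.register("rasp")
    def _pick_schedule_rasp(workload):
        return ("rasp-tuned", workload)

    # Equivalent of the old _get_schedule(wkl, 'rasp'):
    with tvm.target.rasp():
        sch = pick_schedule((224, 224, 3, 64))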
@@ -5,7 +5,7 @@ from collections import namedtuple
import tvm
from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple
+from .. import generic

_Workload = namedtuple('Workload',
                       ['height', 'width', 'channel', 'multiplier',
@@ -145,7 +145,7 @@ def _schedule(s, data, data_pad, kernel, output, last):
    return s

+@generic.schedule_depthwise_conv2d_nchw.register(["cpu", "rasp"])
def schedule_depthwise_conv2d(outs):
    """Schedule for depthwise_conv2d nchw forward.
...
@@ -8,16 +8,16 @@ def verify_broadcast_to_ele(in_shape, out_shape):
    # Build the logic and compile the function
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.broadcast_to(A, out_shape)
-    s = topi.cuda.schedule_broadcast(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_broadcast(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="broadcast_to")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.broadcast_to(data_npy, out_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
        for _ in range(1):
@@ -48,11 +48,12 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
        C = topi.broadcast_minimum(A, B)
    else:
        raise NotImplementedError
-    s = topi.cuda.schedule_broadcast(C)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_broadcast(C)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
        lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
...
@@ -14,8 +14,8 @@ def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, padding):
    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
    B = topi.nn.conv2d(A, W, stride, padding)
-    s = topi.rasp.schedule_conv2d([B])
+    s = topi.generic.schedule_conv2d_nchw([B])
    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    dtype = A.dtype
...
@@ -14,8 +14,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
    B = topi.nn.conv2d_nchw(A, W, stride, padding)
    C = topi.nn.relu(B)
-    s1 = topi.cuda.schedule_conv2d_nchw([B])
-    s2 = topi.cuda.schedule_conv2d_nchw([C])
    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
@@ -35,6 +33,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s1 = topi.generic.schedule_conv2d_nchw([B])
+            s2 = topi.generic.schedule_conv2d_nchw([C])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
...
@@ -12,7 +12,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
    C = tvm.placeholder((out_dim,), name='C')
    D = topi.nn.dense(A, B, C if use_bias else None)
    D = topi.nn.relu(D)
-    s = topi.cuda.schedule_dense(D)
    dtype = A.dtype

    # use memoize to pickle the test data for next time use
@@ -33,6 +32,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_dense(D)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
...
@@ -4,7 +4,7 @@ import numpy as np
from scipy import signal
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
-from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
+from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc

def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
@@ -21,15 +21,18 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=[stride_h, stride_w], padding=padding)
    ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
    Relu = topi.nn.relu(ScaleShift)
-    # schedule
-    s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
-    s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
-    s3 = schedule_depthwise_conv2d_nchw(Relu)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            # schedule
+            s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
+            s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
+            s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
        ctx = tvm.context(device, 0)
        # build the kernels
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
@@ -88,7 +91,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
    check_device("cuda")
    check_device("metal")
    check_device("rocm")

def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
    in_width = in_height
    filter_channel = in_channel
...
@@ -12,7 +12,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type)
    B = topi.nn.relu(B)
-    s = topi.cuda.schedule_pool(B)
    dtype = A.dtype

    a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
@@ -36,6 +35,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_pool(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
@@ -57,7 +58,6 @@ def verify_global_pool(n, c, h, w, pool_type):
    A = tvm.placeholder((n, c, h, w), name='A')
    B = topi.nn.global_pool(A, pool_type=pool_type)
    B = topi.nn.relu(B)
-    s = topi.cuda.schedule_global_pool(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    if pool_type == 'avg':
@@ -70,6 +70,8 @@ def verify_global_pool(n, c, h, w, pool_type):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_global_pool(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
...
@@ -45,11 +45,13 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
        out_dtype = "int32"
    else:
        raise NotImplementedError
-    s = topi.cuda.schedule_reduce(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_reduce(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="sum")
        # Test
...
@@ -12,8 +12,6 @@ def verify_softmax(m, n):
    s = tvm.create_schedule([B.op])
    tvm.lower(s, [A, B], simple_mode=True)
-    s = topi.cuda.schedule_softmax(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.softmax_python(a_np)
@@ -21,6 +19,8 @@ def verify_softmax(m, n):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_softmax(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
@@ -43,7 +43,6 @@ def verify_log_softmax(m, n):
    s = tvm.create_schedule([B.op])
    tvm.lower(s, [A, B], simple_mode=True)
-    s = topi.cuda.schedule_softmax(B)
    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.log_softmax_python(a_np)
@@ -52,6 +51,8 @@ def verify_log_softmax(m, n):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_softmax(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
...
@@ -6,11 +6,12 @@ import topi
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.expand_dims(A, axis, num_newaxis)
-    s = topi.cuda.schedule_broadcast(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_broadcast(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="expand_dims")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
@@ -23,17 +24,18 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    check_device("opencl")
    check_device("cuda")
    check_device("metal")
    check_device("rocm")

def verify_tranpose(in_shape, axes):
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.transpose(A, axes)
-    s = topi.cuda.schedule_injective(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="tranpose")
        data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
@@ -46,16 +48,17 @@ def verify_tranpose(in_shape, axes):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")

def verify_reshape(src_shape, dst_shape):
    A = tvm.placeholder(shape=src_shape, name="A")
    B = topi.reshape(A, dst_shape)
-    s = topi.cuda.schedule_injective(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="reshape")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
@@ -68,16 +71,17 @@ def verify_reshape(src_shape, dst_shape):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")

def verify_squeeze(src_shape, axis):
    A = tvm.placeholder(shape=src_shape, name="A")
    B = topi.squeeze(A, axis=axis)
-    s = topi.cuda.schedule_injective(B)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(B)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="squeeze")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
@@ -94,18 +98,19 @@ def verify_squeeze(src_shape, axis):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")

def verify_concatenate(shapes, axis):
    tensor_l = []
    for i, shape in enumerate(shapes):
        tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
    out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
-    s = topi.cuda.schedule_injective(out_tensor)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(out_tensor)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
        data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
@@ -118,16 +123,17 @@ def verify_concatenate(shapes, axis):
    check_device("cuda")
    check_device("opencl")
    check_device("metal")
    check_device("rocm")

def verify_split(src_shape, indices_or_sections, axis):
    A = tvm.placeholder(shape=src_shape, name="A")
    tensor_l = topi.split(A, indices_or_sections, axis=axis)
-    s = topi.cuda.schedule_injective(tensor_l)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(tensor_l)
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A] + tensor_l, device, name="split")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
@@ -142,7 +148,7 @@ def verify_split(src_shape, indices_or_sections, axis):
    check_device("opencl")
    check_device("metal")
    check_device("rocm")

def test_expand_dims():
    verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
@@ -190,4 +196,3 @@ if __name__ == "__main__":
    test_squeeze()
    test_concatenate()
    test_split()
-