Commit eb761f36 by Tianqi Chen Committed by GitHub

[Refactor] Introduce target generic dispatch system (#556)

* [TVM] Introduce target generic dispatch system

* fix target warning
parent c3cac464
......@@ -8,6 +8,7 @@ Python API
intrin
tensor
schedule
target
build
module
ndarray
......
tvm.target
----------
.. automodule:: tvm.target
.. autofunction:: tvm.target.generic_func
.. autoclass:: tvm.target.Target
:members:
.. autofunction:: tvm.target.cuda
.. autofunction:: tvm.target.rocm
.. autofunction:: tvm.target.rasp
.. autofunction:: tvm.target.create
......@@ -37,13 +37,11 @@ Index
.. autosummary::
topi.cuda.schedule_conv2d_nchw
topi.cuda.schedule_conv2d_hwcn
topi.cuda.schedule_depthwise_conv2d_nchw
topi.cuda.schedule_depthwise_conv2d_nhwc
topi.cuda.schedule_reduce
topi.cuda.schedule_broadcast
topi.cuda.schedule_injective
topi.generic.schedule_conv2d_nchw
topi.generic.schedule_depthwise_conv2d_nchw
topi.generic.schedule_reduce
topi.generic.schedule_broadcast
topi.generic.schedule_injective
topi
~~~~
......@@ -75,14 +73,12 @@ topi.nn
.. autofunction:: topi.nn.depthwise_conv2d_nhwc
topi.cuda
~~~~~~~~~
.. automodule:: topi.cuda
topi.generic
~~~~~~~~~~~~
.. automodule:: topi.generic
.. autofunction:: topi.cuda.schedule_conv2d_nchw
.. autofunction:: topi.cuda.schedule_conv2d_hwcn
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nhwc
.. autofunction:: topi.cuda.schedule_reduce
.. autofunction:: topi.cuda.schedule_broadcast
.. autofunction:: topi.cuda.schedule_injective
.. autofunction:: topi.generic.schedule_conv2d_nchw
.. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.generic.schedule_reduce
.. autofunction:: topi.generic.schedule_broadcast
.. autofunction:: topi.generic.schedule_injective
......@@ -56,11 +56,7 @@ def context(dev_type, dev_id=0):
assert tvm.context("cuda", 0) == tvm.gpu(0)
"""
if isinstance(dev_type, string_types):
if dev_type not in TVMContext.STR2MASK:
if dev_type.find("nvptx") != -1:
dev_type = "cuda"
if dev_type.find("rocm") != -1:
dev_type = "rocm"
dev_type = dev_type.split()[0]
if dev_type not in TVMContext.STR2MASK:
raise ValueError("Unknown device type %s" % dev_type)
dev_type = TVMContext.STR2MASK[dev_type]
......
......@@ -100,9 +100,12 @@ class TVMContext(ctypes.Structure):
12: 'ext_dev',
}
STR2MASK = {
'llvm': 1,
'stackvm': 1,
'cpu': 1,
'gpu': 2,
'cuda': 2,
'nvptx': 2,
'cl': 4,
'opencl': 4,
'metal': 8,
......
......@@ -15,6 +15,7 @@ from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target
class BuildConfig(object):
"""Configuration scope to set a build config option.
......@@ -238,7 +239,7 @@ def lower(sch,
def build(sch,
args=None,
target="llvm",
target=None,
target_host=None,
name="default_function",
binds=None):
......@@ -252,36 +253,10 @@ def build(sch,
args : list of Buffer or Tensor or Var, optional
The argument lists to the function.
target : str, optional
target : str or :any:`tvm.target.Target`, optional
The target and option of the compilation.
When the target is llvm, you can set options like:
- **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross
compilation.
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is infered from the
target triple and autodetected to the current architecture.
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self registered functions in program startup. User can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where dynamic loading api like dlopen is banned.
The system lib will be available as long as the result code is linked by the program.
target_host : str, optional
target_host : str or :any:`tvm.target.Target` optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
......@@ -301,6 +276,10 @@ def build(sch,
-------
f : Function, or pair of functions
The result function.
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(sch, schedule.Schedule):
if args is None:
......@@ -325,6 +304,9 @@ def build(sch,
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
target = _target.current_target() if target is None else target
target = _target.create(target) if target else _target.create("llvm")
fhost = []
fdevice = []
for func in flist:
......@@ -332,7 +314,7 @@ def build(sch,
if BuildConfig.current.detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
warp_size = 32 if target == "cuda" else 1
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)]
fhost.append(fsplits[0])
......@@ -345,29 +327,28 @@ def build(sch,
else:
raise ValueError("unknown function type %d" % func.func_type)
if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target)
device = "cpu" if target.startswith("llvm") or target == "stackvm" else target
device_type = ndarray.context(device, 0).device_type
device_type = ndarray.context(target.target_name, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if not target_host:
if device == "cpu":
if device_type == ndarray.cpu(0).device_type:
target_host = target
assert not fdevice
else:
target_host = "llvm" if module.enabled("llvm") else "stackvm"
target_host = _target.create(target_host)
target_device = target
fdevice = [ir_pass.LowerIntrin(x, target_device) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mhost = codegen.build_module(fhost, target_host)
mhost = codegen.build_module(fhost, str(target_host))
if fdevice:
mdev = codegen.build_module(fdevice, target_device)
mdev = codegen.build_module(fdevice, str(target_device))
mhost.import_module(mdev)
return mhost
"""Target management API of tvm"""
"""Target management API of TVM.
TVM's target string is in fomat ``<target_name> [-option=value]...``.
Note
----
The list of options include:
- **-device=<device name>**
The device name.
- **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross
compilation.
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is infered from the
target triple and autodetected to the current architecture.
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self registered functions in program startup. User can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where dynamic loading api like dlopen is banned.
The system lib will be available as long as the result code is linked by the program.
We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific function in this module to create specific targets.
"""
from __future__ import absolute_import
import warnings
from ._ffi.base import _LIB_NAME
try:
from decorator import decorate
except ImportError as err_msg:
# Allow decorator to be missing in runtime
if _LIB_NAME != "libtvm_runtime.so":
raise err_msg
def _merge_opts(opts, new_opts):
"""Helper function to merge options"""
if isinstance(new_opts, str):
new_opts = new_opts.split()
if new_opts:
return opts + new_opts
return opts
class Target(object):
"""A Target describes the target type on which computation should be carried on"""
default_target = None
str2type = {'x86': 1, 'cuda': 2, 'rasp': 3}
type2str = {1: 'x86', 2: 'cuda', 3: 'rasp'}
def __init__(self, target_type):
"""Constructs a context."""
if isinstance(target_type, Target):
self.target_typeid = target_type.target_typeid
else:
self.target_typeid = Target.str2type[target_type]
"""Target device information, use through TVM API.
@property
def target_type(self):
"""Returns the target type of current target."""
return Target.type2str[self.target_typeid]
Parameters
----------
target_name : {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "ext_dev"}
The major target name.
def __hash__(self):
"""Compute hash value of target for dictionary lookup"""
return hash(self.target_typeid)
options : list of str, optional
Additional arguments appended to the target.
def __eq__(self, other):
"""Compares two targets. Two targets are equal if they
have the same target type.
Note
----
Do not use class constructor, you can create target using the following functions
- :any:`tvm.target.create` create target from string
- :any:`tvm.target.rasp` create raspberry pi target
- :any:`tvm.target.cuda` create CUDA target
- :any:`tvm.target.rocm` create ROCM target
"""
return isinstance(other, Target) and \
self.target_typeid == other.target_typeid
current = None
def __init__(self,
target_name,
options=None):
self.target_name = target_name
self.options = _merge_opts([], options)
self.device_name = ""
# Parse device option
for item in self.options:
if item.startswith("-device="):
self.device_name = item.split("=")[1]
# Target query searchs device name first
if self.device_name:
self.keys = (self.device_name,)
else:
self.keys = ()
# Target configuration handling
self.thread_warp_size = 1
if target_name in ("llvm", ):
self.keys += ("cpu",)
elif target_name in ("cuda", "nvptx"):
self.keys += ("cuda", "gpu")
self.max_num_threads = 512
self.thread_warp_size = 32
elif target_name in ("rocm", "opencl"):
# For now assume rocm schedule for opencl
self.keys += ("rocm", "gpu")
self.max_num_threads = 256
elif target_name in ("metal",):
self.keys += ("gpu",)
self.max_num_threads = 256
elif target_name in ("stackvm", "ext_dev"):
# Do not now class for stacvm or ext_dev
pass
else:
raise ValueError("Unknown target name %s" % target_name)
def __str__(self):
return '%s' % (self.target_type)
return " ".join([self.target_name] + self.options)
def __repr__(self):
return self.__str__()
def __enter__(self):
self._old_target = Target.default_target
Target.default_target = self
self._old_target = Target.current
if self._old_target is not None and str(self) != str(self._old_target):
warnings.warn(
"Override target '%s' with new target scope '%s'" % (
self._old_target, self))
Target.current = self
return self
def __exit__(self, ptype, value, trace):
Target.default_target = self._old_target
Target.current = self._old_target
def generic_func(fdefault):
"""Wrap a target generic function.
Generic function allows registeration of further functions
that can be dispatched on current target context.
If no registered dispatch is matched, the fdefault will be called.
Target.default_target = Target('x86')
Parameters
----------
fdefault : function
The default function.
def x86():
"""Returns a x86 target."""
return Target('x86')
Returns
-------
fgeneric : function
A wrapped generic function.
Example
-------
.. code-block:: python
import tvm
# wrap function as target generic
@tvm.target.generic_func
def my_func(a):
return a + 1
# register specialization of my_func under target cuda
@my_func.register("cuda")
def my_func_cuda(a):
return a + 2
# displays 3, because my_func is called
print(my_func(2))
# displays 4, because my_func_cuda is called
with tvm.target.cuda():
print(my_func(2))
"""
dispatch_dict = {}
func_name = fdefault.__name__
def cuda():
"""Returns a cuda target."""
return Target('cuda')
def register(key, func=None, override=False):
"""Register function to be the dispatch function.
def rasp():
"""Returns a rasp target."""
return Target('rasp')
Parameters
----------
key : str or list of str
The key to be registered.
def current_target():
"""Returns the current target."""
return Target.default_target
func : function
The function to be registered.
override : bool
Whether override existing registeration.
Returns
-------
The register function is necessary.
"""
def _do_reg(myf):
key_list = [key] if isinstance(key, str) else key
for k in key_list:
if k in dispatch_dict and not override:
raise ValueError(
"Key is already registered for %s" % func_name)
dispatch_dict[k] = myf
return myf
if func:
return _do_reg(myf)
return _do_reg
def dispatch_func(func, *args, **kwargs):
"""The wrapped dispath function"""
target = current_target()
if target is None:
return func(*args, **kwargs)
for k in target.keys:
if k in dispatch_dict:
return dispatch_dict[k](*args, **kwargs)
return func(*args, **kwargs)
fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register
return fdecorate
def cuda(options=None):
"""Returns a cuda target.
Parameters
----------
options : list of str
Additional options
"""
return Target("cuda", options)
def rocm(options=None):
"""Returns a ROCM target.
Parameters
----------
options : list of str
Additional options
"""
return Target("rocm", options)
def rasp(options=None):
"""Returns a rasp target.
Parameters
----------
options : list of str
Additional options
"""
opts = ["-device=rasp",
"-mtriple=armv7l-none-linux-gnueabihf",
"-mcpu=cortex-a53",
"-mattr=+neon"]
opts = _merge_opts(opts, options)
return Target("llvm", opts)
def create(target_str):
"""Get a target given target string.
Parameters
----------
target_str : str
The target string.
Returns
-------
target : Target
The target object
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(target_str, Target):
return target_str
if not isinstance(target_str, str):
raise ValueError("target_str has to be string type")
arr = target_str.split()
# Parse device option
device_name = ""
for item in arr[1:]:
if item.startswith("-device="):
device_name = item.split("=")[1]
if device_name == "rasp":
return rasp(arr[1:])
return Target(arr[0], arr[1:])
def current_target(allow_none=True):
"""Returns the current target.
Parameters
----------
allow_none : bool
Whether allow the current target to be none
Raises
------
ValueError if current target is not set.
"""
if Target.current:
return Target.current
if not allow_none:
raise RuntimeError(
"Requires a current target in generic function, but it is not set. "
"Please set it using `with TargetObject:`")
return Target.current
......@@ -82,6 +82,8 @@ GetLLVMTargetMachine(const std::string& target_str,
} else {
LOG(FATAL) << "invalid -mfloat-abi option " << value;
}
} else if (key == "-device") {
// pass
} else {
LOG(FATAL) << "unknown option " << key;
}
......
......@@ -68,7 +68,8 @@ def test_gemm():
print("skip because %s is not enabled.." % device)
return
f = tvm.build(s, [A, B, C], device)
with tvm.target.create(device):
f = tvm.build(s, [A, B, C])
ctx = tvm.context(device, 0)
# launch the kernel.
n = nn
......
import tvm
@tvm.target.generic_func
def mygeneric(data):
# default generic function
return data + 1
@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
return data + 2
@mygeneric.register("rocm")
def rocm_func(data):
return data + 3
@mygeneric.register("cpu")
def rocm_func(data):
return data + 10
def test_target_dispatch():
with tvm.target.cuda():
assert mygeneric(1) == 3
with tvm.target.rocm():
assert mygeneric(1) == 4
with tvm.target.create("cuda"):
assert mygeneric(1) == 3
with tvm.target.rasp():
assert mygeneric(1) == 11
with tvm.target.create("metal"):
assert mygeneric(1) == 3
try:
mygeneric(0)
raise RuntimeError("not reached")
except RuntimeError:
pass
if __name__ == "__main__":
test_target_dispatch()
......@@ -3,6 +3,7 @@
import tvm
from .. import util
from .. import tag
from .. import generic
def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
......@@ -483,6 +484,8 @@ def schedule_conv2d_small_batch(outs):
traverse(outs[0].op)
return s
@generic.schedule_conv2d_nchw.register(["cuda", "gpu"])
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw.
......
......@@ -3,7 +3,9 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
@generic.schedule_dense.register(["cuda", "gpu"])
def schedule_dense(outs):
"""Schedule for dense operator.
......
......@@ -3,7 +3,9 @@
import tvm
from ..util import get_const_tuple
from .. import tag
from .. import generic
@generic.schedule_depthwise_conv2d_nchw.register(["cuda", "gpu"])
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward.
......
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic
def _schedule_injective(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
num_thread = 512
target = tvm.target.current_target()
target = target if target else tvm.target.cuda()
num_thread = target.max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return sch
@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs):
"""Schedule for injective op.
......
......@@ -2,7 +2,9 @@
"""Schedule for pooling operators"""
import tvm
from .. import tag
from .. import generic
@generic.schedule_global_pool.register(["cuda", "gpu"])
def schedule_global_pool(outs):
"""Schedule for global_pool.
......@@ -63,6 +65,7 @@ def schedule_global_pool(outs):
return s
@generic.schedule_pool.register(["cuda", "gpu"])
def schedule_pool(outs):
"""Schedule for pool.
......
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce:
......@@ -62,6 +63,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
return sch
@generic.schedule_reduce.register(["cuda", "gpu"])
def schedule_reduce(outs):
"""Schedule for inject->reduce->bcast ops.
......
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
import tvm
from .. import generic
@generic.schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax(outs):
"""Schedule for softmax op.
......
# pylint: disable=wildcard-import
"""Generic declaration and schedules.
This is a recommended way of using TOPI API.
To use the generic schedule function, user must set
the current target scope using with block. See also :any:`tvm.target`
Example
-------
.. code-block:: python
# create schedule that dispatches to topi.cuda.schedule_injective
with tvm.target.create("cuda"):
s = tvm.generic.schedule_injective(outs)
"""
from __future__ import absolute_import as _abs
from .nn import *
from .injective import *
# pylint: disable=invalid-name
"""generic declaration and schedules."""
from __future__ import absolute_import as _abs
import tvm
@tvm.target.generic_func
def schedule_injective(outs):
"""Schedule for injective op.
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_injective not registered for '%s'" % target)
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm
def _default_schedule(outs, auto_inline):
"""Default schedule for llvm."""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_pool not registered for '%s'" % target)
s = tvm.create_schedule([x.op for x in outs])
if auto_inline:
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d nchow
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for conv2d nchow
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_reduce(outs):
"""Schedule for reduction
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, True)
@tvm.target.generic_func
def schedule_softmax(outs):
"""Schedule for softmax
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_dense(outs):
"""Schedule for dense
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_pool(outs):
"""Schedule for pool
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_global_pool(outs):
"""Schedule for global pool
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
"""Conv2D operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from tvm import target as _target
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
......@@ -51,9 +50,7 @@ _WORKLOADS = [
# platform specific schedule
_CONV_SCHEDULE = {}
# platform specific declaration
_CONV_DECLARATION = {}
@tvm.target.generic_func
def conv2d(data, kernel, stride, padding, layout='NCHW'):
"""Conv2D operator.
......@@ -80,10 +77,6 @@ def conv2d(data, kernel, stride, padding, layout='NCHW'):
4-D with shape [batch, out_channel, out_height, out_width]
"""
# search platform specific declaration first
target = _target.current_target()
if target in _CONV_DECLARATION:
return _CONV_DECLARATION[target](data, kernel, stride, padding, layout)
# default declaration
if layout == 'NCHW':
return conv2d_nchw(data, kernel, stride, padding)
......@@ -105,15 +98,15 @@ def _get_workload(data, kernel, stride, padding):
return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_schedule(wkl, target=None):
@tvm.target.generic_func
def _get_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
if target is None:
target = _target.current_target()
else:
target = _target.Target(target)
assert target in _CONV_SCHEDULE, "no schedule for such target: {}".format(target)
return _CONV_SCHEDULE[target](wkl)
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
def _spatial_pack(data, kernel, stride, padding):
""" Compute convolution with pack on spatial axes. """
......
......@@ -4,11 +4,12 @@ from __future__ import absolute_import as _abs
import tvm
from tvm import target as _target
from .. import tag
from ..nn.conv2d import conv2d, _get_schedule
from ..nn.conv2d import SpatialPack, Im2ColPack
from ..nn.conv2d import _CONV_DECLARATION, _CONV_SCHEDULE
from ..nn.conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC
from ..nn.conv2d import _get_workload, _get_schedule
from ..nn.conv2d import _get_workload
from ..nn.util import infer_pad, infer_stride
from .. import generic
_SCHEDULES = [
SpatialPack(1, 8, 4, 1, 4, True),
......@@ -36,6 +37,7 @@ _SCHEDULES = [
Im2ColPack(7, 4, 1, 4, True),
]
@_get_schedule.register("rasp")
def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
......@@ -43,8 +45,8 @@ def _schedule_conv2d(wkl):
sch = _SCHEDULES[idx]
return sch
_CONV_SCHEDULE[_target.rasp()] = _schedule_conv2d
@conv2d.register("rasp")
def _declaration_conv2d(data, kernel, stride, padding, layout):
assert layout == 'NCHW', "only support NCHW convolution on rasp"
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
......@@ -52,7 +54,6 @@ def _declaration_conv2d(data, kernel, stride, padding, layout):
sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)
_CONV_DECLARATION[_target.rasp()] = _declaration_conv2d
def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec,
......@@ -64,7 +65,9 @@ def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
with tvm.target.rasp():
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
CI, CO = wkl.in_filter, wkl.out_filter
......@@ -170,7 +173,9 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
with _target.rasp():
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
CI = wkl.in_filter
......@@ -275,6 +280,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
return s
@generic.schedule_conv2d_nchw.register(["cpu", "rasp"])
def schedule_conv2d(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
......
......@@ -5,7 +5,7 @@ from collections import namedtuple
import tvm
from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple
from .. import generic
_Workload = namedtuple('Workload',
['height', 'width', 'channel', 'multiplier',
......@@ -145,7 +145,7 @@ def _schedule(s, data, data_pad, kernel, output, last):
return s
@generic.schedule_depthwise_conv2d_nchw.register(["cpu", "rasp"])
def schedule_depthwise_conv2d(outs):
"""Schedule for depthwise_conv2d nchw forward.
......
......@@ -8,16 +8,16 @@ def verify_broadcast_to_ele(in_shape, out_shape):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.broadcast_to(A, out_shape)
s = topi.cuda.schedule_broadcast(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
for _ in range(1):
......@@ -48,11 +48,12 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
C = topi.broadcast_minimum(A, B)
else:
raise NotImplementedError
s = topi.cuda.schedule_broadcast(C)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(C)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
......
......@@ -14,8 +14,8 @@ def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, paddin
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d(A, W, stride, padding)
s = topi.generic.schedule_conv2d_nchw([B])
s = topi.rasp.schedule_conv2d([B])
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
......
......@@ -14,8 +14,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d_nchw(A, W, stride, padding)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_nchw([B])
s2 = topi.cuda.schedule_conv2d_nchw([C])
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
......@@ -35,6 +33,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s1 = topi.generic.schedule_conv2d_nchw([B])
s2 = topi.generic.schedule_conv2d_nchw([C])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
......
......@@ -12,7 +12,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
C = tvm.placeholder((out_dim,), name='C')
D = topi.nn.dense(A, B, C if use_bias else None)
D = topi.nn.relu(D)
s = topi.cuda.schedule_dense(D)
dtype = A.dtype
# use memoize to pickle the test data for next time use
......@@ -33,6 +32,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_dense(D)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
......
......@@ -4,7 +4,7 @@ import numpy as np
from scipy import signal
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc
def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
......@@ -21,15 +21,18 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, stride=[stride_h, stride_w], padding=padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = schedule_depthwise_conv2d_nchw(Relu)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
# schedule
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
ctx = tvm.context(device, 0)
# build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
......
......@@ -12,7 +12,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type)
B = topi.nn.relu(B)
s = topi.cuda.schedule_pool(B)
dtype = A.dtype
a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
......@@ -36,6 +35,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_pool(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
......@@ -57,7 +58,6 @@ def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.global_pool(A, pool_type=pool_type)
B = topi.nn.relu(B)
s = topi.cuda.schedule_global_pool(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
if pool_type == 'avg':
......@@ -70,6 +70,8 @@ def verify_global_pool(n, c, h, w, pool_type):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_global_pool(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......
......@@ -45,11 +45,13 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
out_dtype = "int32"
else:
raise NotImplementedError
s = topi.cuda.schedule_reduce(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_reduce(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="sum")
# Test
......
......@@ -12,8 +12,6 @@ def verify_softmax(m, n):
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
s = topi.cuda.schedule_softmax(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.softmax_python(a_np)
......@@ -21,6 +19,8 @@ def verify_softmax(m, n):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......@@ -43,7 +43,6 @@ def verify_log_softmax(m, n):
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
s = topi.cuda.schedule_softmax(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.log_softmax_python(a_np)
......@@ -52,6 +51,8 @@ def verify_log_softmax(m, n):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......
......@@ -6,11 +6,12 @@ import topi
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.expand_dims(A, axis, num_newaxis)
s = topi.cuda.schedule_broadcast(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="expand_dims")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
......@@ -29,11 +30,12 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
def verify_tranpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes)
s = topi.cuda.schedule_injective(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="tranpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
......@@ -51,11 +53,12 @@ def verify_tranpose(in_shape, axes):
def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.reshape(A, dst_shape)
s = topi.cuda.schedule_injective(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="reshape")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
......@@ -73,11 +76,12 @@ def verify_reshape(src_shape, dst_shape):
def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.squeeze(A, axis=axis)
s = topi.cuda.schedule_injective(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
......@@ -101,11 +105,12 @@ def verify_concatenate(shapes, axis):
for i, shape in enumerate(shapes):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
s = topi.cuda.schedule_injective(out_tensor)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)
ctx = tvm.context(device, 0)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
......@@ -123,11 +128,12 @@ def verify_concatenate(shapes, axis):
def verify_split(src_shape, indices_or_sections, axis):
A = tvm.placeholder(shape=src_shape, name="A")
tensor_l = topi.split(A, indices_or_sections, axis=axis)
s = topi.cuda.schedule_injective(tensor_l)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_injective(tensor_l)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A] + tensor_l, device, name="split")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
......@@ -190,4 +196,3 @@ if __name__ == "__main__":
test_squeeze()
test_concatenate()
test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment