Commit 31fb14e4 by Tianqi Chen Committed by GitHub

[TOPI] Formalize the tag system (#473)

parent 0a6c36ce
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
namespace topi { namespace topi {
constexpr auto kElementWise = "ewise"; constexpr auto kElementWise = "elemwise";
constexpr auto kBroadcast = "bcast"; constexpr auto kBroadcast = "broadcast";
constexpr auto kMatMult = "matmult"; constexpr auto kMatMult = "matmult";
constexpr auto kConv2dNCHW = "conv2d_nchw"; constexpr auto kConv2dNCHW = "conv2d_nchw";
constexpr auto kConv2dHWCN = "conv2d_hwcn"; constexpr auto kConv2dHWCN = "conv2d_hwcn";
......
...@@ -11,6 +11,7 @@ from __future__ import absolute_import as _abs ...@@ -11,6 +11,7 @@ from __future__ import absolute_import as _abs
from .math import * from .math import *
from .reduction import * from .reduction import *
from .transform import *
from .broadcast import * from .broadcast import *
from . import nn from . import nn
from . import cuda from . import cuda
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""Broadcast operators""" """Broadcast operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .import tag
from .util import get_const_tuple, equal_const_int from .util import get_const_tuple, equal_const_int
def _get_bcast_info(original_shape, target_shape): def _get_bcast_info(original_shape, target_shape):
...@@ -113,7 +114,7 @@ def broadcast_to(data, shape): ...@@ -113,7 +114,7 @@ def broadcast_to(data, shape):
return ret return ret
@tvm.tag_scope(tag="broadcast_binary_op") @tvm.tag_scope(tag=tag.BROADCAST)
def broadcast_binary_op(lhs, rhs, func, name="bop"): def broadcast_binary_op(lhs, rhs, func, name="bop"):
"""Binary operands that will automatically broadcast the inputs """Binary operands that will automatically broadcast the inputs
......
...@@ -8,6 +8,5 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise ...@@ -8,6 +8,5 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .reduction import schedule_reduce from .reduction import schedule_reduce
from .broadcast import schedule_broadcast
from .softmax import schedule_softmax from .softmax import schedule_softmax
from .elemwise import schedule_elemwise from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
# pylint: disable=invalid-name,unused-variable
"""Schedule for broadcast operators"""
from __future__ import absolute_import as _abs
import tvm
from .elemwise import _schedule_elemwise
def schedule_broadcast(outs):
    """Build a schedule for broadcasting ops followed by element-wise ops.

    Starting from the first output, operators tagged 'ewise' or
    'scale_shift' are inlined (unless they are schedule outputs) and their
    non-placeholder inputs are visited in turn.  An operator tagged
    'broadcast_to' or 'broadcast_binary_op' is handed to
    ``_schedule_elemwise``; traversal does not continue past it.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of broadcast_to in the format
        of an array of tensors.  A single Tensor is also accepted.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.

    Raises
    ------
    RuntimeError
        If a visited operator carries any other tag.
    """
    if isinstance(outs, tvm.tensor.Tensor):
        outs = [outs]
    sch = tvm.create_schedule([out.op for out in outs])

    inlinable_tags = ('ewise', 'scale_shift')
    broadcast_tags = ('broadcast_to', 'broadcast_binary_op')

    def traverse(operator):
        if operator.tag in inlinable_tags:
            # Only intermediate stages may be inlined; outputs must stay.
            if operator not in sch.outputs:
                sch[operator].compute_inline()
            for producer in operator.input_tensors:
                # Tensors whose op has no inputs are not visited.
                if producer.op.input_tensors:
                    traverse(producer.op)
            return
        if operator.tag in broadcast_tags:
            _schedule_elemwise(operator, sch)
            return
        raise RuntimeError("Unsupported operator: %s" % operator.tag)

    traverse(outs[0].op)
    return sch
# pylint: disable=invalid-name, too-many-locals, too-many-statements # pylint: disable=invalid-name, too-many-locals, too-many-statements
"""Schedule for conv2d_hwcn with auto fusion""" """Schedule for conv2d_hwcn with auto fusion"""
import tvm import tvm
from .. import tag
def schedule_conv2d_hwcn(outs): def schedule_conv2d_hwcn(outs):
"""Schedule for conv2d_hwcn and any element-wise operations. """Schedule for conv2d_hwcn and any element-wise operations.
...@@ -101,7 +101,7 @@ def schedule_conv2d_hwcn(outs): ...@@ -101,7 +101,7 @@ def schedule_conv2d_hwcn(outs):
def traverse(operator): def traverse(operator):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
if operator.tag == 'ewise' or operator.tag == 'scale_shift': if tag.is_broadcast(operator.tag):
if operator not in sch.outputs: if operator not in sch.outputs:
sch[operator].compute_inline() sch[operator].compute_inline()
for tensor in operator.input_tensors: for tensor in operator.input_tensors:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""Schedule for conv2d_nchw with auto fusion""" """Schedule for conv2d_nchw with auto fusion"""
import tvm import tvm
from .. import util from .. import util
from .. import tag
def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L): def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern""" """Schedule conv2d for specific feature_in_out_filter pattern"""
...@@ -389,7 +390,7 @@ def schedule_conv2d_small_batch(outs): ...@@ -389,7 +390,7 @@ def schedule_conv2d_small_batch(outs):
def traverse(OP): def traverse(OP):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in OP.tag or 'bcast' in OP.tag: if tag.is_broadcast(OP.tag):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""Schedule for depthwise_conv2d with auto fusion""" """Schedule for depthwise_conv2d with auto fusion"""
import tvm import tvm
from ..util import get_const_tuple from ..util import get_const_tuple
from .. import tag
def schedule_depthwise_conv2d_nchw(outs): def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward. """Schedule for depthwise_conv2d nchw forward.
...@@ -100,7 +101,7 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -100,7 +101,7 @@ def schedule_depthwise_conv2d_nchw(outs):
def traverse(OP): def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in OP.tag or 'bcast' in OP.tag: if tag.is_broadcast(OP.tag):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
...@@ -171,7 +172,7 @@ def schedule_depthwise_conv2d_nhwc(outs): ...@@ -171,7 +172,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
def traverse(OP): def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in OP.tag or 'bcast' in OP.tag: if tag.is_broadcast(OP.tag):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
......
# pylint: disable=invalid-name, unused-variable, trailing-whitespace, no-member # pylint: disable=invalid-name, unused-variable,
"""Schedule for element wise operator""" """Schedule for composition of injective operator"""
import tvm import tvm
def _schedule_injective(op, sch):
def _schedule_elemwise(op, sch):
x = op.output(0) x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis) fused = sch[x].fuse(*sch[x].op.axis)
num_thread = 512 num_thread = 512
...@@ -13,8 +12,8 @@ def _schedule_elemwise(op, sch): ...@@ -13,8 +12,8 @@ def _schedule_elemwise(op, sch):
return sch return sch
def schedule_elemwise(outs): def schedule_injective(outs):
"""Schedule for element wise op. """Schedule for injective op.
Parameters Parameters
---------- ----------
...@@ -31,4 +30,7 @@ def schedule_elemwise(outs): ...@@ -31,4 +30,7 @@ def schedule_elemwise(outs):
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s) tvm.schedule.AutoInlineInjective(s)
return _schedule_elemwise(outs[0].op, s) return _schedule_injective(outs[0].op, s)
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""Schedule for reduce operators""" """Schedule for reduce operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import tag
def _schedule_reduce(op, sch): def _schedule_reduce(op, sch):
data_in = op.input_tensors[0] data_in = op.input_tensors[0]
...@@ -42,7 +42,7 @@ def _schedule_reduce(op, sch): ...@@ -42,7 +42,7 @@ def _schedule_reduce(op, sch):
def schedule_reduce(outs): def schedule_reduce(outs):
"""Schedule for reduce ops + ewise + scale_shift ops. """Schedule for inject->reduce->bcast ops.
Parameters Parameters
---------- ----------
...@@ -58,7 +58,7 @@ def schedule_reduce(outs): ...@@ -58,7 +58,7 @@ def schedule_reduce(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs]) sch = tvm.create_schedule([x.op for x in outs])
def traverse(operator): def traverse(operator):
if operator.tag == 'ewise' or operator.tag == 'scale_shift': if tag.is_injective(operator.tag):
if operator not in sch.outputs: if operator not in sch.outputs:
sch[operator].compute_inline() sch[operator].compute_inline()
for tensor in operator.input_tensors: for tensor in operator.input_tensors:
......
"""Elementwise operators""" """Elementwise operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import tag
@tvm.tag_scope(tag='ewise') @tvm.tag_scope(tag=tag.ELEMWISE)
def identity(x): def identity(x):
"""Take identity of input x. """Take identity of input x.
...@@ -20,9 +21,9 @@ def identity(x): ...@@ -20,9 +21,9 @@ def identity(x):
return tvm.compute(x.shape, lambda *i: x(*i)) return tvm.compute(x.shape, lambda *i: x(*i))
@tvm.tag_scope(tag='ewise') @tvm.tag_scope(tag=tag.ELEMWISE)
def negative(x): def negative(x):
"""Take negative of input x. """Take negation of input x.
Parameters Parameters
---------- ----------
...@@ -38,7 +39,7 @@ def negative(x): ...@@ -38,7 +39,7 @@ def negative(x):
return tvm.compute(x.shape, lambda *i: -x(*i)) return tvm.compute(x.shape, lambda *i: -x(*i))
@tvm.tag_scope(tag="ewise") @tvm.tag_scope(tag=tag.ELEMWISE)
def exp(x): def exp(x):
"""Take exponential of input x. """Take exponential of input x.
...@@ -55,7 +56,7 @@ def exp(x): ...@@ -55,7 +56,7 @@ def exp(x):
return tvm.compute(x.shape, lambda *i: tvm.exp(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.exp(x(*i)))
@tvm.tag_scope(tag="ewise") @tvm.tag_scope(tag=tag.ELEMWISE)
def tanh(x): def tanh(x):
"""Take hyperbolic tanh of input x. """Take hyperbolic tanh of input x.
...@@ -72,7 +73,7 @@ def tanh(x): ...@@ -72,7 +73,7 @@ def tanh(x):
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))
@tvm.tag_scope(tag="ewise") @tvm.tag_scope(tag=tag.ELEMWISE)
def log(x): def log(x):
"""Take logarithm of input x. """Take logarithm of input x.
...@@ -89,7 +90,7 @@ def log(x): ...@@ -89,7 +90,7 @@ def log(x):
return tvm.compute(x.shape, lambda *i: tvm.log(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))
@tvm.tag_scope(tag="ewise") @tvm.tag_scope(tag=tag.ELEMWISE)
def sqrt(x): def sqrt(x):
"""Take square root of input x. """Take square root of input x.
...@@ -106,7 +107,7 @@ def sqrt(x): ...@@ -106,7 +107,7 @@ def sqrt(x):
return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i)))
@tvm.tag_scope(tag="ewise") @tvm.tag_scope(tag=tag.ELEMWISE)
def sigmoid(x): def sigmoid(x):
"""Take sigmoid tanh of input x. """Take sigmoid tanh of input x.
......
"""TVM operator batch normalization compute.""" """TVM operator batch normalization compute."""
from __future__ import absolute_import from __future__ import absolute_import
import tvm import tvm
from .. import tag
@tvm.tag_scope(tag='batch_norm') @tvm.tag_scope(tag=tag.BROADCAST)
def batch_norm(data, gamma, beta, moving_mean, moving_var, eps, fix_gamma): def batch_norm_inference(data, gamma, beta, moving_mean, moving_var, eps, fix_gamma):
"""Batch normalization operator in NCHW layout. """Batch normalization inference operator in NCHW layout.
Parameters Parameters
---------- ----------
......
...@@ -93,6 +93,7 @@ def _get_workload(data, kernel, stride, padding): ...@@ -93,6 +93,7 @@ def _get_workload(data, kernel, stride, padding):
HSTR, WSTR = stride, stride HSTR, WSTR = stride, stride
return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_schedule(wkl, target=None): def _get_schedule(wkl, target=None):
""" Get the platform specific schedule. """ """ Get the platform specific schedule. """
if target is None: if target is None:
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import util from .. import util
from .. import tag
@tvm.tag_scope(tag=tag.INJECTIVE)
@tvm.tag_scope(tag="dilation")
def dilate(data, strides, name="DilatedInput"): def dilate(data, strides, name="DilatedInput"):
"""Dilate data with zeros. """Dilate data with zeros.
......
"""Elementwise operators""" """Elementwise operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import tag
@tvm.tag_scope(tag="ewise") @tvm.tag_scope(tag=tag.ELEMWISE)
def relu(x): def relu(x):
"""Take relu of input x. """Take relu of input x.
......
"""TVM operator flatten compute.""" """TVM operator flatten compute."""
from __future__ import absolute_import from __future__ import absolute_import
import tvm import tvm
from .. import tag
@tvm.tag_scope(tag='flatten') @tvm.tag_scope(tag=tag.INJECTIVE)
def flatten(data): def flatten(data):
"""Flattens the input array into a 2-D array by collapsing the higher dimensions. """Flattens the input array into a 2-D array by collapsing the higher dimensions.
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
"""Operators of one-to-one-mapping on the first input""" """Operators of one-to-one-mapping on the first input"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import tag
@tvm.tag_scope(tag="bcast_scale_shift_nchw") @tvm.tag_scope(tag=tag.BROADCAST)
def scale_shift_nchw(Input, Scale, Shift): def scale_shift_nchw(Input, Scale, Shift):
"""Batch normalization operator in inference. """Batch normalization operator in inference.
...@@ -25,7 +26,8 @@ def scale_shift_nchw(Input, Scale, Shift): ...@@ -25,7 +26,8 @@ def scale_shift_nchw(Input, Scale, Shift):
""" """
return tvm.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift') return tvm.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift')
@tvm.tag_scope(tag="bcast_scale_shift_nhwc")
@tvm.tag_scope(tag=tag.BROADCAST)
def scale_shift_nhwc(Input, Scale, Shift): def scale_shift_nhwc(Input, Scale, Shift):
"""Batch normalization operator in inference. """Batch normalization operator in inference.
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from ..util import equal_const_int from ..util import equal_const_int
from .. import tag
@tvm.tag_scope(tag="pad") @tvm.tag_scope(tag=tag.INJECTIVE+",pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"): def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
"""Dilate Input with zeros. """Dilate Input with zeros.
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import target as _target from .. import target as _target
from .. import tag
from ..nn.convolution import SpatialPack, Im2ColPack from ..nn.convolution import SpatialPack, Im2ColPack
from ..nn.convolution import _CONV_DECLARATION, _CONV_SCHEDULE from ..nn.convolution import _CONV_DECLARATION, _CONV_SCHEDULE
from ..nn.convolution import _WORKLOADS, _SCH_TO_DECL_FUNC from ..nn.convolution import _WORKLOADS, _SCH_TO_DECL_FUNC
...@@ -270,7 +271,7 @@ def schedule_convolution(outs): ...@@ -270,7 +271,7 @@ def schedule_convolution(outs):
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in op.tag or 'bcast' in op.tag: if tag.is_broadcast(op.tag):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""Reduce operators""" """Reduce operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import tag
def _get_real_axis(ndim, axis): def _get_real_axis(ndim, axis):
if axis is None: if axis is None:
...@@ -37,7 +37,7 @@ def get_reduce_out_shape(src_shape, axis=None, keepdims=False): ...@@ -37,7 +37,7 @@ def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
return dst_shape return dst_shape
@tvm.tag_scope(tag="comm_reduce") @tvm.tag_scope(tag=tag.COMM_REDUCE)
def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum): def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
"""Reducing the data """Reducing the data
......
"""Namespace of all tag system in tvm
Each operator can be tagged by a tag, which indicates its type.
Generic categories
- tag.ELEMWISE="elemwise":
Elementwise operator, for example :code:`out[i, j] = input[i, j]`
- tag.BROADCAST="broadcast":
Broadcasting operator, can always map output axis to the input in order.
for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
Note that the axis need to be in order so transpose is not a bcast operator.
If an input of broadcast operator has same shape as output,
we can ensure that it is elementwise relation.
- tag.INJECTIVE="injective":
Injective operator, can always injectively map output axis to a single input axis.
All injective operators can still be safely fused into a reduction, similar to elemwise operators.
- tag.COMM_REDUCE="comm_reduce":
Commutative reduction operator
- If an op does not belong to these generic categories, it should have a special tag.
Note
----
When we add a new topi operator, the op needs to be tagged as generically as possible.
We can also compose tags like "injective,pad" to give generic and specific information.
When we use composed tags, we must always put generic tag in the first location.
"""
# Generic category tags. In a composed tag such as "injective,pad" the
# generic tag comes first, which is why the predicates in this module
# match on the leading part of the tag string.
ELEMWISE = "elemwise"
BROADCAST = "broadcast"
INJECTIVE = "injective"
COMM_REDUCE = "comm_reduce"
def is_broadcast(tag):
    """Tell whether a tag marks a broadcast operator.

    Elementwise tags count as broadcast as well.  Only the beginning of
    the tag is examined, so a composed tag whose first part is one of the
    generic tags is accepted too.

    Parameters
    ----------
    tag : str
        The input tag

    Returns
    -------
    ret : bool
        True when the tag starts with the elemwise or broadcast tag.
    """
    # An exact match is also a prefix match, so one check covers both.
    generic_prefixes = (ELEMWISE, BROADCAST)
    return tag.startswith(generic_prefixes)
def is_injective(tag):
    """Tell whether a tag marks an injective operator.

    Elementwise and broadcast tags count as injective as well.  Only the
    beginning of the tag is examined, so a composed tag such as
    "injective,pad" is accepted too.

    Parameters
    ----------
    tag : str
        The input tag

    Returns
    -------
    ret : bool
        True when the tag starts with the elemwise, broadcast or
        injective tag.
    """
    # An exact match is also a prefix match, so one check covers both.
    generic_prefixes = (ELEMWISE, BROADCAST, INJECTIVE)
    return tag.startswith(generic_prefixes)
"""Injective transformation operators"""
from __future__ import absolute_import as _abs
import tvm
from . import tag
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1):
    """Insert new axes of extent 1 into the shape of a tensor.

    Parameters
    ----------
    a : tvm.Tensor
        The tensor to be expanded.

    axis : int
        Position in the output shape where the new axes begin.  A negative
        value is shifted by ``len(a.shape) + 1`` before use.

    num_newaxis: int, optional
        Number of new axes to be inserted at ``axis``

    Returns
    -------
    ret : tvm.Tensor
        Tensor computed over the shape of ``a`` with ``num_newaxis``
        entries of 1 inserted at ``axis``.
    """
    if axis < 0:
        axis = len(a.shape) + axis + 1
    leading = a.shape[:axis]
    trailing = a.shape[axis:]
    out_shape = leading + ([1] * num_newaxis) + trailing

    def _read_source(*out_index):
        # Drop the indices of the inserted axes to address the input.
        kept_before = out_index[:axis]
        kept_after = out_index[axis + num_newaxis:]
        return a(*(kept_before + kept_after))

    return tvm.compute(out_shape, _read_source)
...@@ -16,7 +16,6 @@ def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, p ...@@ -16,7 +16,6 @@ def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, p
B = topi.nn.convolution(A, W, stride, padding) B = topi.nn.convolution(A, W, stride, padding)
s = topi.rasp.schedule_convolution([B]) s = topi.rasp.schedule_convolution([B])
a_shape = get_const_tuple(A.shape) a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape) w_shape = get_const_tuple(W.shape)
dtype = A.dtype dtype = A.dtype
......
"""Test code for broadcasting operators."""
import numpy as np
import tvm
import topi
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    """Build ``topi.expand_dims`` and compare its output with numpy.

    The operator is scheduled with ``topi.cuda.schedule_broadcast`` and run
    on every enabled device among opencl, cuda and metal.  The reference
    result is the random input reshaped to ``out_shape``.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.

    out_shape : tuple of int
        Shape expected after the expansion.

    axis : int
        Axis argument forwarded to ``topi.expand_dims``.

    num_newaxis : int
        Number of new axes forwarded to ``topi.expand_dims``.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.expand_dims(A, axis, num_newaxis)
    s = topi.cuda.schedule_broadcast(B)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        # Resolve the context from the device name.  The previous
        # cuda-or-OpenCL choice gave the metal target an OpenCL context.
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="expand_dims")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = data_npy.reshape(out_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ("opencl", "cuda", "metal"):
        check_device(device)
def test_expand_dims():
    """Check expand_dims by appending two unit axes to a (3, 10) tensor."""
    verify_expand_dims(
        in_shape=(3, 10),
        out_shape=(3, 10, 1, 1),
        axis=2,
        num_newaxis=2,
    )
# Run the test when this file is executed directly as a script.
if __name__ == "__main__":
    test_expand_dims()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment