Commit 9d583cf5 by Wuwei Lin, committed by Tianqi Chen

[TOPI] Fix traverse function not inlining zero-input ops (#3623)

* Fix traverse_inline not inlining zero-input ops properly

* Add where to python and set tag to broadcast

* Fix inline

* test

* fix test target

* fix
parent d4a51751
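
Why the fix works: every schedule in TOPI walks the compute graph with a traverse helper that decides whether to recurse into (and inline) a producer op. The old guard tested the truthiness of tensor.op.input_tensors, but a ComputeOp with zero inputs (such as the output of topi.full) has an empty input_tensors list, so it was silently skipped and never compute_inline()'d, which broke fusion. A minimal sketch of the difference, assuming a TVM build of this era:

    import tvm
    import topi

    # topi.full yields a ComputeOp with no inputs.
    zeros = topi.full((2, 3), 'int32', tvm.const(0, dtype='int32'))
    assert isinstance(zeros.op, tvm.tensor.ComputeOp)
    assert len(zeros.op.input_tensors) == 0

    # Old guard: an empty list is falsy, so zero-input ops were skipped.
    if zeros.op.input_tensors:
        pass  # never reached for zeros
    # New guard: recurses into any ComputeOp, regardless of arity.
    if isinstance(zeros.op, tvm.tensor.ComputeOp):
        pass  # reached; the op is traversed and inlined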
@@ -807,7 +807,7 @@ inline Tensor where(const Tensor& condition,
                     const Tensor& x,
                     const Tensor& y,
                     std::string name = "T_where",
-                    std::string tag = kInjective) {
+                    std::string tag = kBroadcast) {
   CHECK_EQ(x->shape.size(), y->shape.size())
       << "x and y must have the same shape.Got different number of dimension: "
       << x->shape.size() << " vs " << y->shape.size();
...
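
Retagging where from kInjective to kBroadcast lets the generic broadcast schedule pick it up, which the new Python test relies on. A minimal usage sketch mirroring that test (placeholder names are illustrative):

    import tvm
    import topi

    cond = tvm.placeholder((4,), name="cond")
    A = tvm.placeholder((4,), name="A")
    B = tvm.placeholder((4,), name="B")
    C = topi.where(cond, A, B)  # now tagged as broadcast
    with tvm.target.create("llvm"):
        s = topi.generic.schedule_broadcast(C)  # broadcast scheduler applies
    f = tvm.build(s, [cond, A, B, C], "llvm", name="where")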
@@ -327,7 +327,7 @@ def schedule_bitserial_conv2d_nhwc(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'spatial_bitserial_conv_nhwc' in op.tag:
...
@@ -164,7 +164,7 @@ def schedule_bitserial_dense(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
...
@@ -123,7 +123,7 @@ def schedule_conv2d_hwcn(outs):
             if operator not in sch.outputs:
                 sch[operator].compute_inline()
             for tensor in operator.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         elif operator.tag == 'conv2d_hwcn':
            Apad = operator.input_tensors[0]
...
@@ -120,7 +120,7 @@ def schedule_dense(cfg, outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule dense
         elif OP.tag == 'dense':
...
@@ -198,7 +198,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule depthwise_conv2d
         if OP.tag == 'depthwise_conv2d_nhwc':
...
@@ -74,7 +74,7 @@ def schedule_adaptive_pool(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule global_pool
         elif OP.tag.startswith('adaptive_pool'):
@@ -137,7 +137,7 @@ def schedule_pool(outs, layout):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('pool'):
...
@@ -34,7 +34,7 @@ def _schedule_conv2d(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         # schedule conv2d
         elif OP.tag.find("conv2d") >= 0:
@@ -220,7 +220,7 @@ def schedule_reduce(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         elif OP.tag in ["comm_reduce", "comm_reduce_idx"]:
             if OP.tag == "comm_reduce":
@@ -298,7 +298,7 @@ def schedule_dense(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         # schedule dense
         elif OP.tag == 'dense':
@@ -342,7 +342,7 @@ def schedule_pool(outs, layout):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('pool'):
@@ -386,7 +386,7 @@ def schedule_adaptive_pool(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         # schedule global_pool
         elif OP.tag.startswith('adaptive_pool'):
...
@@ -149,7 +149,7 @@ def schedule_conv2d_NCHWc(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d' in op.tag:
             _schedule_cl_spatialpack_NCHWc(s, op)
@@ -378,7 +378,7 @@ def schedule_conv2d_nchw(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d' in op.tag:
             _schedule_cl_spatialpack(s, op)
...
@@ -55,7 +55,7 @@ def schedule_conv2d_nchw(outs):
             if OP not in s.outputs:
                 s[OP].opengl()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule conv2d_nchw
         elif OP.tag.startswith('conv2d_nchw'):
...
@@ -55,7 +55,7 @@ def schedule_dense(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule dense
         elif OP.tag == 'dense':
...
@@ -54,7 +54,7 @@ def schedule_adaptive_pool(outs):
             if OP not in s.outputs:
                 s[OP].opengl()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule global_pool
         elif OP.tag.startswith('adaptive_pool'):
@@ -108,7 +108,7 @@ def schedule_pool(outs, layout):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op not in scheduled_ops and tensor.op.input_tensors:
+                if tensor.op not in scheduled_ops and isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('pool'):
...
@@ -496,3 +496,25 @@ def ndarray_size(array, dtype="int32"):
         The resulting tensor.
     """
     return cpp.ndarray_size(array, dtype)
+
+
+def where(condition, x, y):
+    """Get the elements, either from x or y, depending on the condition.
+
+    Parameters
+    ----------
+    condition : tvm.Tensor
+        The condition array.
+
+    x : tvm.Tensor
+        First array to be selected.
+
+    y : tvm.Tensor
+        Second array to be selected.
+
+    Returns
+    -------
+    result : tvm.Tensor
+        A Tensor selected from x or y depending on condition.
+    """
+    return cpp.where(condition, x, y)
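
The new Python wrapper simply dispatches to the C++ where above; its elementwise semantics match numpy.where. A short usage sketch (placeholder names are illustrative):

    import tvm
    import topi

    cond = tvm.placeholder((3,), name="cond")
    x = tvm.placeholder((3,), name="x")
    y = tvm.placeholder((3,), name="y")
    # out[i] = x[i] if cond[i] != 0 else y[i], as in np.where(cond, x, y)
    out = topi.where(cond, x, y)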
@@ -49,7 +49,7 @@ def traverse_inline(s, final_op, callback):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                    _traverse(tensor.op)
        callback(op)
...
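
This is the central fix: traverse_inline is the helper that most schedules share. For context, the whole function after the change looks roughly like this (a sketch reconstructed around the hunk above; details such as the injective-tag guard follow the topi.util source of this era):

    def traverse_inline(s, final_op, callback):
        """Traverse computation graph and do auto inline (sketch)."""
        visited = set()

        def _traverse(op):
            if op in visited:
                return
            visited.add(op)
            if tag.is_injective(op.tag):
                if op not in s.outputs:
                    s[op].compute_inline()
                for tensor in op.input_tensors:
                    # recurse into any ComputeOp, even one with zero inputs
                    if isinstance(tensor.op, tvm.tensor.ComputeOp):
                        _traverse(tensor.op)
            callback(op)

        _traverse(final_op)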
@@ -58,7 +58,7 @@ def schedule_binary_dense(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule binary_dense
         elif OP.tag == 'binary_dense':
...
@@ -36,7 +36,7 @@ def schedule_bitserial_conv2d(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         elif 'spatial_bitserial_conv_nchw' in op.tag or 'spatial_bitserial_conv_nhwc' in op.tag:
...
@@ -75,7 +75,7 @@ def schedule_bitserial_dense(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                     traverse(tensor.op)
         elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
...
@@ -233,7 +233,7 @@ def schedule_conv2d(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d_nchw' in op.tag:
@@ -284,7 +284,7 @@ def schedule_conv2d_nhwc_pack(cfg, outs):
                 s[op].parallel(fused)
                 s[op].vectorize(c)
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d_nhwc_pack_int8' in op.tag:
@@ -335,7 +335,7 @@ def schedule_conv2d_nhwc(outs):
                 s[op].parallel(fused)
                 s[op].vectorize(c)
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d_nhwc' in op.tag:
@@ -648,7 +648,7 @@ def _schedule_conv2d_NCHWc(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d_NCHWc' in op.tag:
...
@@ -41,7 +41,7 @@ def schedule_conv2d_transpose(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d_transpose_nchw' in op.tag:
...
@@ -144,7 +144,7 @@ def schedule_depthwise_conv2d_NCHWc(cfg, outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'depthwise_conv2d_NCHWc' in op.tag:
             conv_out = op.output(0)
...
@@ -94,7 +94,7 @@ def schedule_pool(outs, layout):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('pool'):
@@ -136,7 +136,7 @@ def schedule_adaptive_pool(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('adaptive_pool'):
...
@@ -444,6 +444,35 @@ def verify_tile(in_shape, reps):
     for device in get_all_backend():
         check_device(device)
 
+
+def verify_where(in_shape):
+    Cond = tvm.placeholder(shape=in_shape, name="cond")
+    dtype = Cond.dtype
+    A = tvm.placeholder(shape=in_shape, name="A")
+    B = tvm.placeholder(shape=in_shape, name="B")
+    C = topi.where(Cond, A, B)
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_broadcast(C)
+        f = tvm.build(s, [Cond, A, B, C], device, name="where")
+        cond_npy = np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype)
+        x_npy = np.random.uniform(size=in_shape).astype(dtype)
+        y_npy = np.random.uniform(size=in_shape).astype(dtype)
+        out_npy = np.where(cond_npy, x_npy, y_npy)
+        cond_nd = tvm.nd.array(cond_npy, ctx)
+        x_nd = tvm.nd.array(x_npy, ctx)
+        y_nd = tvm.nd.array(y_npy, ctx)
+        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
+        f(cond_nd, x_nd, y_nd, out_nd)
+        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+    for device in get_all_backend():
+        check_device(device)
+
 def test_strided_slice():
     verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
     verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
@@ -483,6 +512,10 @@ def test_reshape():
     verify_reshape((16, ), (2, 2, 2, 2))
 
+
+def test_where():
+    verify_where((1, 2, 3, 4))
+
 def test_squeeze():
     verify_squeeze((1, 2, 3, 4), 0)
     verify_squeeze((1, 2, 1, 4), None)
...
@@ -712,6 +745,32 @@ def test_ndarray_size():
         check_device(backend)
 
+
+def test_where_fusion():
+    """integration test that where and zeros should be properly inlined"""
+    def check_device(device):
+        with tvm.target.create(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
+            w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
+            conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')
+            zeros = topi.full((2, 3, 1, 3), 'int32', tvm.const(0, dtype='int32'))
+            gt = topi.greater_equal(conv1, zeros)
+            one = topi.full((2, 3, 1, 3), 'int32', tvm.const(1, dtype='int32'))
+            two = topi.full((2, 3, 1, 3), 'int32', tvm.const(2, dtype='int32'))
+            where = topi.where(gt, one, two)
+            add = topi.add(conv1, where)
+            outs = [add]
+            s = topi.generic.schedule_conv2d_nchw(outs)
+            tvm.build(s, [data, w, add], target=backend)
+
+    for backend in get_all_backend():
+        check_device(backend)
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -719,6 +778,7 @@ if __name__ == "__main__":
     test_transpose()
     test_expand_dims()
     test_reshape()
+    test_where()
     test_squeeze()
     test_split()
     test_flip()
@@ -732,3 +792,4 @@ if __name__ == "__main__":
     test_shape()
     test_sequence_mask()
     test_ndarray_size()
+    test_where_fusion()