Commit 9d583cf5 by Wuwei Lin, committed by Tianqi Chen

[TOPI] Fix traverse function not inline zero-input op (#3623)

* Fix traverse_inline not inlining zero-input ops properly

* Add where to Python and set its tag to broadcast

* Fix inline

* Add test

* Fix test target

* Fix

Parent: d4a51751
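
The root cause is visible in the hunks below: every schedule's traverse helper decided whether to recurse into a producer with the truthiness test tensor.op.input_tensors, which is empty, hence falsy, for a ComputeOp that reads no other tensors (e.g. one created by topi.full). Such zero-input ops were never visited, so they were left neither inlined nor scheduled. The fix switches to an explicit isinstance check, which still skips placeholders but keeps zero-input compute ops. A minimal sketch of the difference, using topi.full as the new tests do:

    import tvm
    import topi

    # A zero-input ComputeOp: full() reads no other tensors.
    zeros = topi.full((2, 3), 'int32', tvm.const(0, dtype='int32'))

    # Old check: an empty input_tensors list is falsy, so the traversal never
    # recursed into zeros and it stayed un-inlined and un-scheduled.
    assert not zeros.op.input_tensors

    # New check: zeros is still a ComputeOp (placeholders are not), so the
    # traversal now visits it and can compute_inline it.
    assert isinstance(zeros.op, tvm.tensor.ComputeOp)
    assert not isinstance(tvm.placeholder((2, 3)).op, tvm.tensor.ComputeOp)
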
......@@ -807,7 +807,7 @@ inline Tensor where(const Tensor& condition,
const Tensor& x,
const Tensor& y,
std::string name = "T_where",
-std::string tag = kInjective) {
+std::string tag = kBroadcast) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: "
<< x->shape.size() << " vs " << y->shape.size();
......
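
Retagging where from kInjective to kBroadcast is what lets the traverse helpers below fuse it: they only auto-inline ops whose tag passes topi.tag.is_broadcast, which an injective tag does not. A small illustration, assuming the stock topi.tag helpers:

    from topi import tag

    print(tag.is_broadcast("injective"))   # False: the old tag kept where out of inlining
    print(tag.is_broadcast("broadcast"))   # True: the new tag opts where in
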
......@@ -327,7 +327,7 @@ def schedule_bitserial_conv2d_nhwc(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'spatial_bitserial_conv_nhwc' in op.tag:
......
......@@ -164,7 +164,7 @@ def schedule_bitserial_dense(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
......
......@@ -123,7 +123,7 @@ def schedule_conv2d_hwcn(outs):
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
elif operator.tag == 'conv2d_hwcn':
Apad = operator.input_tensors[0]
......
......@@ -120,7 +120,7 @@ def schedule_dense(cfg, outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
......
......@@ -198,7 +198,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nhwc':
......
......@@ -74,7 +74,7 @@ def schedule_adaptive_pool(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('adaptive_pool'):
......@@ -137,7 +137,7 @@ def schedule_pool(outs, layout):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
......
......@@ -34,7 +34,7 @@ def _schedule_conv2d(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule conv2d
elif OP.tag.find("conv2d") >= 0:
......@@ -220,7 +220,7 @@ def schedule_reduce(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
elif OP.tag in ["comm_reduce", "comm_reduce_idx"]:
if OP.tag == "comm_reduce":
......@@ -298,7 +298,7 @@ def schedule_dense(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
......@@ -342,7 +342,7 @@ def schedule_pool(outs, layout):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
......@@ -386,7 +386,7 @@ def schedule_adaptive_pool(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('adaptive_pool'):
......
......@@ -149,7 +149,7 @@ def schedule_conv2d_NCHWc(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d' in op.tag:
_schedule_cl_spatialpack_NCHWc(s, op)
......@@ -378,7 +378,7 @@ def schedule_conv2d_nchw(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d' in op.tag:
_schedule_cl_spatialpack(s, op)
......
......@@ -55,7 +55,7 @@ def schedule_conv2d_nchw(outs):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule conv2d_nchw
elif OP.tag.startswith('conv2d_nchw'):
......
......@@ -55,7 +55,7 @@ def schedule_dense(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
......
......@@ -54,7 +54,7 @@ def schedule_adaptive_pool(outs):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('adaptive_pool'):
......@@ -108,7 +108,7 @@ def schedule_pool(outs, layout):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op not in scheduled_ops and tensor.op.input_tensors:
+if tensor.op not in scheduled_ops and isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
......
......@@ -496,3 +496,25 @@ def ndarray_size(array, dtype="int32"):
        The resulting tensor.
    """
    return cpp.ndarray_size(array, dtype)


def where(condition, x, y):
    """Get the elements, either from x or y, depending on the condition.

    Parameters
    ----------
    condition : tvm.Tensor
        The condition array.

    x : tvm.Tensor
        First array to be selected.

    y : tvm.Tensor
        Second array to be selected.

    Returns
    -------
    result : tvm.Tensor
        A Tensor selected from x or y depending on condition.
    """
    return cpp.where(condition, x, y)
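
For reference, a usage sketch of the new wrapper; the shapes and names here are illustrative, and the full numerical check lives in the verify_where test near the end of this diff:

    import tvm
    import topi

    cond = tvm.placeholder((2, 3), name="cond")
    x = tvm.placeholder((2, 3), name="x")
    y = tvm.placeholder((2, 3), name="y")

    # Elementwise select, like np.where: pick from x where cond is nonzero, else from y.
    out = topi.where(cond, x, y)
    assert out.op.tag == "broadcast"   # the kBroadcast tag set by the C++ change above
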
......@@ -49,7 +49,7 @@ def traverse_inline(s, final_op, callback):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
_traverse(tensor.op)
callback(op)
......
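
This traverse_inline hunk is the fix the commit title refers to; the same pattern repeats in every schedule above. Condensed, and omitting the visited/scheduled_ops bookkeeping some variants carry (s, callback, and tag come from the enclosing definition), the fixed traversal looks like this sketch:

    def _traverse(op):
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                # isinstance keeps zero-input ComputeOps (e.g. topi.full) in the
                # walk while still skipping PlaceholderOps; the old truthiness
                # test on input_tensors dropped both.
                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                    _traverse(tensor.op)
        callback(op)
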
......@@ -58,7 +58,7 @@ def schedule_binary_dense(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule binary_dense
elif OP.tag == 'binary_dense':
......
......@@ -36,7 +36,7 @@ def schedule_bitserial_conv2d(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
elif 'spatial_bitserial_conv_nchw' in op.tag or 'spatial_bitserial_conv_nhwc' in op.tag:
......
......@@ -75,7 +75,7 @@ def schedule_bitserial_dense(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors:
+if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
......
......@@ -233,7 +233,7 @@ def schedule_conv2d(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_nchw' in op.tag:
......@@ -284,7 +284,7 @@ def schedule_conv2d_nhwc_pack(cfg, outs):
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_nhwc_pack_int8' in op.tag:
......@@ -335,7 +335,7 @@ def schedule_conv2d_nhwc(outs):
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_nhwc' in op.tag:
......@@ -648,7 +648,7 @@ def _schedule_conv2d_NCHWc(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_NCHWc' in op.tag:
......
......@@ -41,7 +41,7 @@ def schedule_conv2d_transpose(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_transpose_nchw' in op.tag:
......
......@@ -144,7 +144,7 @@ def schedule_depthwise_conv2d_NCHWc(cfg, outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'depthwise_conv2d_NCHWc' in op.tag:
conv_out = op.output(0)
......
......@@ -94,7 +94,7 @@ def schedule_pool(outs, layout):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
......@@ -136,7 +136,7 @@ def schedule_adaptive_pool(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
-if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('adaptive_pool'):
......
......@@ -444,6 +444,35 @@ def verify_tile(in_shape, reps):
    for device in get_all_backend():
        check_device(device)


def verify_where(in_shape):
    Cond = tvm.placeholder(shape=in_shape, name="cond")
    dtype = Cond.dtype
    A = tvm.placeholder(shape=in_shape, name="A")
    B = tvm.placeholder(shape=in_shape, name="B")
    C = topi.where(Cond, A, B)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(C)
        f = tvm.build(s, [Cond, A, B, C], device, name="where")
        cond_npy = np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype)
        x_npy = np.random.uniform(size=in_shape).astype(dtype)
        y_npy = np.random.uniform(size=in_shape).astype(dtype)
        out_npy = np.where(cond_npy, x_npy, y_npy)
        cond_nd = tvm.nd.array(cond_npy, ctx)
        x_nd = tvm.nd.array(x_npy, ctx)
        y_nd = tvm.nd.array(y_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
        f(cond_nd, x_nd, y_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in get_all_backend():
        check_device(device)


def test_strided_slice():
    verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
......@@ -483,6 +512,10 @@ def test_reshape():
    verify_reshape((16, ), (2, 2, 2, 2))


def test_where():
    verify_where((1, 2, 3, 4))


def test_squeeze():
    verify_squeeze((1, 2, 3, 4), 0)
    verify_squeeze((1, 2, 1, 4), None)
......@@ -712,6 +745,32 @@ def test_ndarray_size():
        check_device(backend)


def test_where_fusion():
    """integration test that where and zeros should be properly inlined"""
    def check_device(device):
        with tvm.target.create(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)
            data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
            w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
            conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')
            zeros = topi.full((2, 3, 1, 3), 'int32', tvm.const(0, dtype='int32'))
            gt = topi.greater_equal(conv1, zeros)
            one = topi.full((2, 3, 1, 3), 'int32', tvm.const(1, dtype='int32'))
            two = topi.full((2, 3, 1, 3), 'int32', tvm.const(2, dtype='int32'))
            where = topi.where(gt, one, two)
            add = topi.add(conv1, where)
            outs = [add]
            s = topi.generic.schedule_conv2d_nchw(outs)
            tvm.build(s, [data, w, add], target=device)

    for backend in get_all_backend():
        check_device(backend)
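
Before this commit, the zero-input full ops feeding where in this graph were never reached by the conv2d schedule's traverse, leaving them neither inlined nor scheduled, which could make tvm.build fail (on GPU backends the un-bound stages are rejected). With the isinstance check and the broadcast tag, the whole full / greater_equal / where / add chain is inlined into the conv2d schedule, which is what this test locks in.
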
if __name__ == "__main__":
    test_strided_slice()
    test_concatenate()
......@@ -719,6 +778,7 @@ if __name__ == "__main__":
    test_transpose()
    test_expand_dims()
    test_reshape()
    test_where()
    test_squeeze()
    test_split()
    test_flip()
......@@ -732,3 +792,4 @@ if __name__ == "__main__":
    test_shape()
    test_sequence_mask()
    test_ndarray_size()
    test_where_fusion()