Commit bdfcec0e by masahi Committed by Tianqi Chen

update topi schedules (#1556)

parent 7b59b8ef
......@@ -327,6 +327,8 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
def schedule_bitserial_conv2d_nhwc(outs):
"""Raspverry pi schedule for bitserial conv2d"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -334,7 +336,7 @@ def schedule_bitserial_conv2d_nhwc(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'spatial_bitserial_conv_nhwc' in op.tag:
......@@ -360,6 +362,7 @@ def schedule_bitserial_conv2d_nhwc(outs):
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, conv_out, output, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -39,10 +39,11 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
def schedule_conv2d_nchw_arm_cpu(cfg, outs):
"""TOPI schedule callback"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _callback(op):
# schedule conv2d
if 'spatial_conv_output' in op.tag:
if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
output = op.output(0)
conv = op.input_tensors[0]
......@@ -64,6 +65,8 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
output = op.output(0)
_schedule_winograd(cfg, s, output, outs[0])
scheduled_ops.append(op)
traverse_inline(s, outs[0].op, _callback)
return s
......
......@@ -79,8 +79,10 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
return s
scheduled_ops = []
def _callback(op):
if op.tag == 'depthwise_conv2d_nchw':
if op.tag == 'depthwise_conv2d_nchw' and op not in scheduled_ops:
output = op.output(0)
kernel = op.input_tensors[1]
data = op.input_tensors[0]
......@@ -90,5 +92,7 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
data = data_pad.op.input_tensors[0]
_schedule(cfg, s, data, data_pad, kernel, output)
scheduled_ops.append(op)
traverse_inline(s, outs[0].op, _callback)
return s
......@@ -99,13 +99,15 @@ def schedule_conv2d_hwcn(outs):
sch[WW].bind(tx, thread_x)
sch[WW].vectorize(fi)
scheduled_ops = []
def traverse(operator):
"""Traverse operators from computation graph"""
if tag.is_broadcast(operator.tag):
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
elif operator.tag == 'conv2d_hwcn':
Apad = operator.input_tensors[0]
......@@ -117,5 +119,7 @@ def schedule_conv2d_hwcn(outs):
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
scheduled_ops.append(operator)
traverse(outs[0].op)
return sch
......@@ -492,6 +492,8 @@ def schedule_conv2d_small_batch(outs):
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
scheduled_ops = []
def traverse(OP):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -499,7 +501,7 @@ def schedule_conv2d_small_batch(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule conv2d
if 'conv2d_nchw' in OP.tag:
......@@ -510,6 +512,8 @@ def schedule_conv2d_small_batch(outs):
Output = OP.output(0)
schedule(temp, Filter, Output)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......
......@@ -73,6 +73,8 @@ def schedule_conv2d_transpose_small_batch(outs):
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -80,7 +82,7 @@ def schedule_conv2d_transpose_small_batch(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule conv2d_transpose_nchw
if 'conv2d_transpose_nchw' in OP.tag:
......@@ -91,6 +93,8 @@ def schedule_conv2d_transpose_small_batch(outs):
Output = OP.output(0)
schedule(temp, Filter, Output)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......
......@@ -86,6 +86,8 @@ def schedule_dense(outs):
s[Dense].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -93,7 +95,7 @@ def schedule_dense(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
......@@ -102,5 +104,7 @@ def schedule_dense(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -101,6 +101,8 @@ def schedule_depthwise_conv2d_nchw(outs):
s[FS].bind(ty, thread_y)
s[FS].bind(tx, thread_x)
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -108,7 +110,7 @@ def schedule_depthwise_conv2d_nchw(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nchw':
......@@ -119,6 +121,8 @@ def schedule_depthwise_conv2d_nchw(outs):
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -180,6 +184,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
fused = s[FS].fuse(fi, ci)
s[FS].bind(fused, thread_x)
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -187,7 +193,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nhwc':
......@@ -198,6 +204,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......
......@@ -45,6 +45,8 @@ def schedule_global_pool(outs):
else:
s[Pool].compute_at(s[Out], tx)
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -52,7 +54,7 @@ def schedule_global_pool(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('global_pool'):
......@@ -61,6 +63,8 @@ def schedule_global_pool(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -101,6 +105,8 @@ def schedule_pool(outs):
else:
s[Pool].compute_at(s[Out], tx)
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -108,7 +114,7 @@ def schedule_pool(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
......@@ -118,5 +124,7 @@ def schedule_pool(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -88,6 +88,7 @@ def schedule_reduce(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse_before_reduce(operator):
"""Internal travserse function"""
......@@ -96,10 +97,13 @@ def schedule_reduce(outs):
elif tag.is_injective(operator.tag):
sch[operator].compute_inline()
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
scheduled_ops.append(operator)
def traverse_after_reduce(operator):
"""Internal travserse function"""
if tag.is_broadcast(operator.tag):
......@@ -107,13 +111,18 @@ def schedule_reduce(outs):
elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch, is_idx_reduce=False)
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
elif operator.tag == 'comm_reduce_idx':
_schedule_reduce(operator, sch, is_idx_reduce=True)
for tensor in operator.input_tensors[0].op.input_tensors:
traverse_before_reduce(tensor.op)
input_tensors = operator.input_tensors[0].op.input_tensors
for tensor in input_tensors:
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
scheduled_ops.append(operator)
traverse_after_reduce(outs[0].op)
return sch
......@@ -11,6 +11,8 @@ def _default_schedule(outs):
target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if "nms" in op.tag:
......@@ -32,9 +34,11 @@ def _default_schedule(outs):
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......
......@@ -113,6 +113,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
......@@ -120,12 +121,14 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
or "1_16" in op.tag:
_schedule_cl_spatialpack_NCHWc(s, op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -360,6 +363,7 @@ def schedule_conv2d_nchw(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
......@@ -367,12 +371,14 @@ def schedule_conv2d_nchw(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
or "1_16" in op.tag:
_schedule_cl_spatialpack(s, op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......
......@@ -144,6 +144,7 @@ def schedule_conv2d_nchw(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
......@@ -151,7 +152,7 @@ def schedule_conv2d_nchw(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'im2col_conv_output' in op.tag:
......@@ -163,6 +164,8 @@ def schedule_conv2d_nchw(outs):
if 'winograd_conv_output' in op.tag:
_schedule_winograd(s, op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......
......@@ -81,6 +81,8 @@ def schedule_dense(outs):
# bias = s[outs[0]].op.input_tensors[1]
# print(tvm.lower(s, [data, weight, bias, outs[0]], simple_mode=True))
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -88,7 +90,7 @@ def schedule_dense(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
......@@ -97,5 +99,7 @@ def schedule_dense(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -86,6 +86,8 @@ def schedule_depthwise_conv2d_nchw(outs):
s[conv].vectorize(xi)
s[conv].compute_at(s[output], ji)
scheduled_ops = []
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -93,7 +95,7 @@ def schedule_depthwise_conv2d_nchw(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule depthwise_conv2d
......@@ -105,5 +107,7 @@ def schedule_depthwise_conv2d_nchw(outs):
conv = op.output(0)
_schedule(pad_data, kernel, conv)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -21,6 +21,8 @@ def schedule_conv2d_nchw(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(conv2d, data):
if conv2d.op in s.outputs:
Out = conv2d
......@@ -37,7 +39,7 @@ def schedule_conv2d_nchw(outs):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule conv2d_nchw
elif OP.tag.startswith('conv2d_nchw'):
......@@ -50,5 +52,7 @@ def schedule_conv2d_nchw(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -22,6 +22,8 @@ def schedule_dense(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(Dense):
if Dense.op in s.outputs:
Out = Dense
......@@ -37,7 +39,7 @@ def schedule_dense(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
......@@ -46,5 +48,7 @@ def schedule_dense(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -21,6 +21,8 @@ def schedule_global_pool(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(Pool):
if Pool.op in s.outputs:
Out = Pool
......@@ -36,7 +38,7 @@ def schedule_global_pool(outs):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('global_pool'):
......@@ -45,6 +47,8 @@ def schedule_global_pool(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -66,6 +70,8 @@ def schedule_pool(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(PaddedInput, Pool):
if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
s[PaddedInput].opengl()
......@@ -82,7 +88,7 @@ def schedule_pool(outs):
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
for tensor in OP.input_tensors and tensor.op not in scheduled_ops:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule pool
......@@ -93,5 +99,7 @@ def schedule_pool(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -23,6 +23,7 @@ def schedule_binary_dense(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(A, B, C):
s[C].split(s[C].op.reduce_axis[0], factor=8)
......@@ -41,7 +42,7 @@ def schedule_binary_dense(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule binary_dense
elif OP.tag == 'binary_dense':
......@@ -52,5 +53,7 @@ def schedule_binary_dense(outs):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -71,6 +71,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
def schedule_bitserial_conv2d(outs):
"""CPU schedule for bitserial convolutions NCHW and NHWC"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
......@@ -79,7 +80,7 @@ def schedule_bitserial_conv2d(outs):
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
for tensor in op.input_tensors and tensor.op not in scheduled_ops:
if tensor.op.input_tensors:
traverse(tensor.op)
......@@ -111,6 +112,7 @@ def schedule_bitserial_conv2d(outs):
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......
......@@ -188,6 +188,7 @@ def schedule_conv2d(outs):
}
s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target(allow_none=False)
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
......@@ -196,7 +197,7 @@ def schedule_conv2d(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_nchw' in op.tag:
......@@ -223,6 +224,8 @@ def schedule_conv2d(outs):
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec,
kernel, kernel_vec, conv_out, output, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -232,6 +235,7 @@ def schedule_conv2d_nhwc(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
......@@ -246,7 +250,7 @@ def schedule_conv2d_nhwc(outs):
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_nhwc' in op.tag:
......@@ -275,6 +279,8 @@ def schedule_conv2d_nhwc(outs):
fused = s[C].fuse(n, h, w)
s[C].parallel(fused)
scheduled_ops.append(op)
traverse(output_op)
return s
......@@ -288,6 +294,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
}
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
......@@ -296,7 +303,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_NCHWc' in op.tag:
......@@ -322,5 +329,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
kernel, conv_out, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -53,6 +53,7 @@ def schedule_dense(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
......@@ -61,7 +62,7 @@ def schedule_dense(outs):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'dense' in op.tag:
......@@ -89,5 +90,7 @@ def schedule_dense(outs):
# Parallelization
s[C].parallel(yo)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -32,6 +32,7 @@ def schedule_pool(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(PaddedInput, Pool):
if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
......@@ -45,7 +46,7 @@ def schedule_pool(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
......@@ -54,6 +55,9 @@ def schedule_pool(outs):
_schedule(PaddedInput, Pool)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
......@@ -75,6 +79,8 @@ def schedule_global_pool(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
......@@ -82,7 +88,7 @@ def schedule_global_pool(outs):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('global_pool'):
......@@ -90,5 +96,8 @@ def schedule_global_pool(outs):
_parallel_sch(s[Pool])
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment