Commit bdfcec0e by masahi Committed by Tianqi Chen

update topi schedules (#1556)

parent 7b59b8ef
...@@ -327,6 +327,8 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, ...@@ -327,6 +327,8 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
def schedule_bitserial_conv2d_nhwc(outs): def schedule_bitserial_conv2d_nhwc(outs):
"""Raspverry pi schedule for bitserial conv2d""" """Raspverry pi schedule for bitserial conv2d"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -334,7 +336,7 @@ def schedule_bitserial_conv2d_nhwc(outs): ...@@ -334,7 +336,7 @@ def schedule_bitserial_conv2d_nhwc(outs):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if 'spatial_bitserial_conv_nhwc' in op.tag: if 'spatial_bitserial_conv_nhwc' in op.tag:
...@@ -360,6 +362,7 @@ def schedule_bitserial_conv2d_nhwc(outs): ...@@ -360,6 +362,7 @@ def schedule_bitserial_conv2d_nhwc(outs):
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, conv_out, output, outs[0]) kernel, kernel_q, kernel_vec, conv_out, output, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -39,10 +39,11 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype): ...@@ -39,10 +39,11 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
def schedule_conv2d_nchw_arm_cpu(cfg, outs): def schedule_conv2d_nchw_arm_cpu(cfg, outs):
"""TOPI schedule callback""" """TOPI schedule callback"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _callback(op): def _callback(op):
# schedule conv2d # schedule conv2d
if 'spatial_conv_output' in op.tag: if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
output = op.output(0) output = op.output(0)
conv = op.input_tensors[0] conv = op.input_tensors[0]
...@@ -64,6 +65,8 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs): ...@@ -64,6 +65,8 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
output = op.output(0) output = op.output(0)
_schedule_winograd(cfg, s, output, outs[0]) _schedule_winograd(cfg, s, output, outs[0])
scheduled_ops.append(op)
traverse_inline(s, outs[0].op, _callback) traverse_inline(s, outs[0].op, _callback)
return s return s
......
...@@ -79,8 +79,10 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs): ...@@ -79,8 +79,10 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
return s return s
scheduled_ops = []
def _callback(op): def _callback(op):
if op.tag == 'depthwise_conv2d_nchw': if op.tag == 'depthwise_conv2d_nchw' and op not in scheduled_ops:
output = op.output(0) output = op.output(0)
kernel = op.input_tensors[1] kernel = op.input_tensors[1]
data = op.input_tensors[0] data = op.input_tensors[0]
...@@ -90,5 +92,7 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs): ...@@ -90,5 +92,7 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
data = data_pad.op.input_tensors[0] data = data_pad.op.input_tensors[0]
_schedule(cfg, s, data, data_pad, kernel, output) _schedule(cfg, s, data, data_pad, kernel, output)
scheduled_ops.append(op)
traverse_inline(s, outs[0].op, _callback) traverse_inline(s, outs[0].op, _callback)
return s return s
...@@ -99,13 +99,15 @@ def schedule_conv2d_hwcn(outs): ...@@ -99,13 +99,15 @@ def schedule_conv2d_hwcn(outs):
sch[WW].bind(tx, thread_x) sch[WW].bind(tx, thread_x)
sch[WW].vectorize(fi) sch[WW].vectorize(fi)
scheduled_ops = []
def traverse(operator): def traverse(operator):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
if tag.is_broadcast(operator.tag): if tag.is_broadcast(operator.tag):
if operator not in sch.outputs: if operator not in sch.outputs:
sch[operator].compute_inline() sch[operator].compute_inline()
for tensor in operator.input_tensors: for tensor in operator.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
elif operator.tag == 'conv2d_hwcn': elif operator.tag == 'conv2d_hwcn':
Apad = operator.input_tensors[0] Apad = operator.input_tensors[0]
...@@ -117,5 +119,7 @@ def schedule_conv2d_hwcn(outs): ...@@ -117,5 +119,7 @@ def schedule_conv2d_hwcn(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % operator.tag) raise RuntimeError("Unsupported operator: %s" % operator.tag)
scheduled_ops.append(operator)
traverse(outs[0].op) traverse(outs[0].op)
return sch return sch
...@@ -492,6 +492,8 @@ def schedule_conv2d_small_batch(outs): ...@@ -492,6 +492,8 @@ def schedule_conv2d_small_batch(outs):
else: else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L) conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -499,7 +501,7 @@ def schedule_conv2d_small_batch(outs): ...@@ -499,7 +501,7 @@ def schedule_conv2d_small_batch(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule conv2d # schedule conv2d
if 'conv2d_nchw' in OP.tag: if 'conv2d_nchw' in OP.tag:
...@@ -510,6 +512,8 @@ def schedule_conv2d_small_batch(outs): ...@@ -510,6 +512,8 @@ def schedule_conv2d_small_batch(outs):
Output = OP.output(0) Output = OP.output(0)
schedule(temp, Filter, Output) schedule(temp, Filter, Output)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
......
...@@ -73,6 +73,8 @@ def schedule_conv2d_transpose_small_batch(outs): ...@@ -73,6 +73,8 @@ def schedule_conv2d_transpose_small_batch(outs):
else: else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L) conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -80,7 +82,7 @@ def schedule_conv2d_transpose_small_batch(outs): ...@@ -80,7 +82,7 @@ def schedule_conv2d_transpose_small_batch(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule conv2d_transpose_nchw # schedule conv2d_transpose_nchw
if 'conv2d_transpose_nchw' in OP.tag: if 'conv2d_transpose_nchw' in OP.tag:
...@@ -91,6 +93,8 @@ def schedule_conv2d_transpose_small_batch(outs): ...@@ -91,6 +93,8 @@ def schedule_conv2d_transpose_small_batch(outs):
Output = OP.output(0) Output = OP.output(0)
schedule(temp, Filter, Output) schedule(temp, Filter, Output)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
......
...@@ -86,6 +86,8 @@ def schedule_dense(outs): ...@@ -86,6 +86,8 @@ def schedule_dense(outs):
s[Dense].set_store_predicate(thread_x.var.equal(0)) s[Dense].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0)) s[Out].set_store_predicate(thread_x.var.equal(0))
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -93,7 +95,7 @@ def schedule_dense(outs): ...@@ -93,7 +95,7 @@ def schedule_dense(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule dense # schedule dense
elif OP.tag == 'dense': elif OP.tag == 'dense':
...@@ -102,5 +104,7 @@ def schedule_dense(outs): ...@@ -102,5 +104,7 @@ def schedule_dense(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -101,6 +101,8 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -101,6 +101,8 @@ def schedule_depthwise_conv2d_nchw(outs):
s[FS].bind(ty, thread_y) s[FS].bind(ty, thread_y)
s[FS].bind(tx, thread_x) s[FS].bind(tx, thread_x)
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -108,7 +110,7 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -108,7 +110,7 @@ def schedule_depthwise_conv2d_nchw(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule depthwise_conv2d # schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nchw': if OP.tag == 'depthwise_conv2d_nchw':
...@@ -119,6 +121,8 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -119,6 +121,8 @@ def schedule_depthwise_conv2d_nchw(outs):
DepthwiseConv2d = OP.output(0) DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d) _schedule(PaddedInput, Filter, DepthwiseConv2d)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -180,6 +184,8 @@ def schedule_depthwise_conv2d_nhwc(outs): ...@@ -180,6 +184,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
fused = s[FS].fuse(fi, ci) fused = s[FS].fuse(fi, ci)
s[FS].bind(fused, thread_x) s[FS].bind(fused, thread_x)
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -187,7 +193,7 @@ def schedule_depthwise_conv2d_nhwc(outs): ...@@ -187,7 +193,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule depthwise_conv2d # schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nhwc': if OP.tag == 'depthwise_conv2d_nhwc':
...@@ -198,6 +204,8 @@ def schedule_depthwise_conv2d_nhwc(outs): ...@@ -198,6 +204,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
DepthwiseConv2d = OP.output(0) DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d) _schedule(PaddedInput, Filter, DepthwiseConv2d)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
......
...@@ -45,6 +45,8 @@ def schedule_global_pool(outs): ...@@ -45,6 +45,8 @@ def schedule_global_pool(outs):
else: else:
s[Pool].compute_at(s[Out], tx) s[Pool].compute_at(s[Out], tx)
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -52,7 +54,7 @@ def schedule_global_pool(outs): ...@@ -52,7 +54,7 @@ def schedule_global_pool(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule global_pool # schedule global_pool
elif OP.tag.startswith('global_pool'): elif OP.tag.startswith('global_pool'):
...@@ -61,6 +63,8 @@ def schedule_global_pool(outs): ...@@ -61,6 +63,8 @@ def schedule_global_pool(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -101,6 +105,8 @@ def schedule_pool(outs): ...@@ -101,6 +105,8 @@ def schedule_pool(outs):
else: else:
s[Pool].compute_at(s[Out], tx) s[Pool].compute_at(s[Out], tx)
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -108,7 +114,7 @@ def schedule_pool(outs): ...@@ -108,7 +114,7 @@ def schedule_pool(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule pool # schedule pool
elif OP.tag.startswith('pool'): elif OP.tag.startswith('pool'):
...@@ -118,5 +124,7 @@ def schedule_pool(outs): ...@@ -118,5 +124,7 @@ def schedule_pool(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -88,6 +88,7 @@ def schedule_reduce(outs): ...@@ -88,6 +88,7 @@ def schedule_reduce(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs]) sch = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse_before_reduce(operator): def traverse_before_reduce(operator):
"""Internal travserse function""" """Internal travserse function"""
...@@ -96,10 +97,13 @@ def schedule_reduce(outs): ...@@ -96,10 +97,13 @@ def schedule_reduce(outs):
elif tag.is_injective(operator.tag): elif tag.is_injective(operator.tag):
sch[operator].compute_inline() sch[operator].compute_inline()
for tensor in operator.input_tensors: for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op) if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
else: else:
raise RuntimeError("Unsupported operator: %s" % operator.tag) raise RuntimeError("Unsupported operator: %s" % operator.tag)
scheduled_ops.append(operator)
def traverse_after_reduce(operator): def traverse_after_reduce(operator):
"""Internal travserse function""" """Internal travserse function"""
if tag.is_broadcast(operator.tag): if tag.is_broadcast(operator.tag):
...@@ -107,13 +111,18 @@ def schedule_reduce(outs): ...@@ -107,13 +111,18 @@ def schedule_reduce(outs):
elif operator.tag == 'comm_reduce': elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch, is_idx_reduce=False) _schedule_reduce(operator, sch, is_idx_reduce=False)
for tensor in operator.input_tensors: for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op) if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
elif operator.tag == 'comm_reduce_idx': elif operator.tag == 'comm_reduce_idx':
_schedule_reduce(operator, sch, is_idx_reduce=True) _schedule_reduce(operator, sch, is_idx_reduce=True)
for tensor in operator.input_tensors[0].op.input_tensors: input_tensors = operator.input_tensors[0].op.input_tensors
traverse_before_reduce(tensor.op) for tensor in input_tensors:
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
else: else:
raise RuntimeError("Unsupported operator: %s" % operator.tag) raise RuntimeError("Unsupported operator: %s" % operator.tag)
scheduled_ops.append(operator)
traverse_after_reduce(outs[0].op) traverse_after_reduce(outs[0].op)
return sch return sch
...@@ -11,6 +11,8 @@ def _default_schedule(outs): ...@@ -11,6 +11,8 @@ def _default_schedule(outs):
target = tvm.target.current_target() target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)""" """inline all one-to-one-mapping operators except the last stage (output)"""
if "nms" in op.tag: if "nms" in op.tag:
...@@ -32,9 +34,11 @@ def _default_schedule(outs): ...@@ -32,9 +34,11 @@ def _default_schedule(outs):
s[x].bind(bx, tvm.thread_axis("blockIdx.x")) s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
s[x].bind(tx, tvm.thread_axis("threadIdx.x")) s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
......
...@@ -113,6 +113,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_ ...@@ -113,6 +113,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)""" """inline all one-to-one-mapping operators except the last stage (output)"""
...@@ -120,12 +121,14 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_ ...@@ -120,12 +121,14 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \ if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
or "1_16" in op.tag: or "1_16" in op.tag:
_schedule_cl_spatialpack_NCHWc(s, op) _schedule_cl_spatialpack_NCHWc(s, op)
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -360,6 +363,7 @@ def schedule_conv2d_nchw(outs): ...@@ -360,6 +363,7 @@ def schedule_conv2d_nchw(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)""" """inline all one-to-one-mapping operators except the last stage (output)"""
...@@ -367,12 +371,14 @@ def schedule_conv2d_nchw(outs): ...@@ -367,12 +371,14 @@ def schedule_conv2d_nchw(outs):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \ if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
or "1_16" in op.tag: or "1_16" in op.tag:
_schedule_cl_spatialpack(s, op) _schedule_cl_spatialpack(s, op)
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
......
...@@ -144,6 +144,7 @@ def schedule_conv2d_nchw(outs): ...@@ -144,6 +144,7 @@ def schedule_conv2d_nchw(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)""" """inline all one-to-one-mapping operators except the last stage (output)"""
...@@ -151,7 +152,7 @@ def schedule_conv2d_nchw(outs): ...@@ -151,7 +152,7 @@ def schedule_conv2d_nchw(outs):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if 'im2col_conv_output' in op.tag: if 'im2col_conv_output' in op.tag:
...@@ -163,6 +164,8 @@ def schedule_conv2d_nchw(outs): ...@@ -163,6 +164,8 @@ def schedule_conv2d_nchw(outs):
if 'winograd_conv_output' in op.tag: if 'winograd_conv_output' in op.tag:
_schedule_winograd(s, op) _schedule_winograd(s, op)
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
......
...@@ -81,6 +81,8 @@ def schedule_dense(outs): ...@@ -81,6 +81,8 @@ def schedule_dense(outs):
# bias = s[outs[0]].op.input_tensors[1] # bias = s[outs[0]].op.input_tensors[1]
# print(tvm.lower(s, [data, weight, bias, outs[0]], simple_mode=True)) # print(tvm.lower(s, [data, weight, bias, outs[0]], simple_mode=True))
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -88,7 +90,7 @@ def schedule_dense(outs): ...@@ -88,7 +90,7 @@ def schedule_dense(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule dense # schedule dense
elif OP.tag == 'dense': elif OP.tag == 'dense':
...@@ -97,5 +99,7 @@ def schedule_dense(outs): ...@@ -97,5 +99,7 @@ def schedule_dense(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -86,6 +86,8 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -86,6 +86,8 @@ def schedule_depthwise_conv2d_nchw(outs):
s[conv].vectorize(xi) s[conv].vectorize(xi)
s[conv].compute_at(s[output], ji) s[conv].compute_at(s[output], ji)
scheduled_ops = []
def traverse(op): def traverse(op):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -93,7 +95,7 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -93,7 +95,7 @@ def schedule_depthwise_conv2d_nchw(outs):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule depthwise_conv2d # schedule depthwise_conv2d
...@@ -105,5 +107,7 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -105,5 +107,7 @@ def schedule_depthwise_conv2d_nchw(outs):
conv = op.output(0) conv = op.output(0)
_schedule(pad_data, kernel, conv) _schedule(pad_data, kernel, conv)
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -21,6 +21,8 @@ def schedule_conv2d_nchw(outs): ...@@ -21,6 +21,8 @@ def schedule_conv2d_nchw(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(conv2d, data): def _schedule(conv2d, data):
if conv2d.op in s.outputs: if conv2d.op in s.outputs:
Out = conv2d Out = conv2d
...@@ -37,7 +39,7 @@ def schedule_conv2d_nchw(outs): ...@@ -37,7 +39,7 @@ def schedule_conv2d_nchw(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].opengl() s[OP].opengl()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule conv2d_nchw # schedule conv2d_nchw
elif OP.tag.startswith('conv2d_nchw'): elif OP.tag.startswith('conv2d_nchw'):
...@@ -50,5 +52,7 @@ def schedule_conv2d_nchw(outs): ...@@ -50,5 +52,7 @@ def schedule_conv2d_nchw(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -22,6 +22,8 @@ def schedule_dense(outs): ...@@ -22,6 +22,8 @@ def schedule_dense(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(Dense): def _schedule(Dense):
if Dense.op in s.outputs: if Dense.op in s.outputs:
Out = Dense Out = Dense
...@@ -37,7 +39,7 @@ def schedule_dense(outs): ...@@ -37,7 +39,7 @@ def schedule_dense(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule dense # schedule dense
elif OP.tag == 'dense': elif OP.tag == 'dense':
...@@ -46,5 +48,7 @@ def schedule_dense(outs): ...@@ -46,5 +48,7 @@ def schedule_dense(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -21,6 +21,8 @@ def schedule_global_pool(outs): ...@@ -21,6 +21,8 @@ def schedule_global_pool(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(Pool): def _schedule(Pool):
if Pool.op in s.outputs: if Pool.op in s.outputs:
Out = Pool Out = Pool
...@@ -36,7 +38,7 @@ def schedule_global_pool(outs): ...@@ -36,7 +38,7 @@ def schedule_global_pool(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].opengl() s[OP].opengl()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule global_pool # schedule global_pool
elif OP.tag.startswith('global_pool'): elif OP.tag.startswith('global_pool'):
...@@ -45,6 +47,8 @@ def schedule_global_pool(outs): ...@@ -45,6 +47,8 @@ def schedule_global_pool(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -66,6 +70,8 @@ def schedule_pool(outs): ...@@ -66,6 +70,8 @@ def schedule_pool(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(PaddedInput, Pool): def _schedule(PaddedInput, Pool):
if isinstance(PaddedInput.op, tvm.tensor.ComputeOp): if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
s[PaddedInput].opengl() s[PaddedInput].opengl()
...@@ -82,7 +88,7 @@ def schedule_pool(outs): ...@@ -82,7 +88,7 @@ def schedule_pool(outs):
if tag.is_broadcast(OP.tag): if tag.is_broadcast(OP.tag):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors and tensor.op not in scheduled_ops:
if tensor.op.input_tensors: if tensor.op.input_tensors:
traverse(tensor.op) traverse(tensor.op)
# schedule pool # schedule pool
...@@ -93,5 +99,7 @@ def schedule_pool(outs): ...@@ -93,5 +99,7 @@ def schedule_pool(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -23,6 +23,7 @@ def schedule_binary_dense(outs): ...@@ -23,6 +23,7 @@ def schedule_binary_dense(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(A, B, C): def _schedule(A, B, C):
s[C].split(s[C].op.reduce_axis[0], factor=8) s[C].split(s[C].op.reduce_axis[0], factor=8)
...@@ -41,7 +42,7 @@ def schedule_binary_dense(outs): ...@@ -41,7 +42,7 @@ def schedule_binary_dense(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule binary_dense # schedule binary_dense
elif OP.tag == 'binary_dense': elif OP.tag == 'binary_dense':
...@@ -52,5 +53,7 @@ def schedule_binary_dense(outs): ...@@ -52,5 +53,7 @@ def schedule_binary_dense(outs):
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -71,6 +71,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits ...@@ -71,6 +71,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
def schedule_bitserial_conv2d(outs): def schedule_bitserial_conv2d(outs):
"""CPU schedule for bitserial convolutions NCHW and NHWC""" """CPU schedule for bitserial convolutions NCHW and NHWC"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
...@@ -79,7 +80,7 @@ def schedule_bitserial_conv2d(outs): ...@@ -79,7 +80,7 @@ def schedule_bitserial_conv2d(outs):
if tag.is_broadcast(op.tag) or 'elemwise' in op.tag: if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors and tensor.op not in scheduled_ops:
if tensor.op.input_tensors: if tensor.op.input_tensors:
traverse(tensor.op) traverse(tensor.op)
...@@ -111,6 +112,7 @@ def schedule_bitserial_conv2d(outs): ...@@ -111,6 +112,7 @@ def schedule_bitserial_conv2d(outs):
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, kernel, kernel_q, kernel_vec,
conv_out, output, outs[0]) conv_out, output, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
......
...@@ -188,6 +188,7 @@ def schedule_conv2d(outs): ...@@ -188,6 +188,7 @@ def schedule_conv2d(outs):
} }
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target(allow_none=False) target = tvm.target.current_target(allow_none=False)
scheduled_ops = []
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
...@@ -196,7 +197,7 @@ def schedule_conv2d(outs): ...@@ -196,7 +197,7 @@ def schedule_conv2d(outs):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if 'conv2d_nchw' in op.tag: if 'conv2d_nchw' in op.tag:
...@@ -223,6 +224,8 @@ def schedule_conv2d(outs): ...@@ -223,6 +224,8 @@ def schedule_conv2d(outs):
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec, _AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec,
kernel, kernel_vec, conv_out, output, outs[0]) kernel, kernel_vec, conv_out, output, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -232,6 +235,7 @@ def schedule_conv2d_nhwc(outs): ...@@ -232,6 +235,7 @@ def schedule_conv2d_nhwc(outs):
"""Create schedule for tensors""" """Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op output_op = outs[0].op
scheduled_ops = []
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
...@@ -246,7 +250,7 @@ def schedule_conv2d_nhwc(outs): ...@@ -246,7 +250,7 @@ def schedule_conv2d_nhwc(outs):
s[op].parallel(fused) s[op].parallel(fused)
s[op].vectorize(c) s[op].vectorize(c)
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if 'conv2d_nhwc' in op.tag: if 'conv2d_nhwc' in op.tag:
...@@ -275,6 +279,8 @@ def schedule_conv2d_nhwc(outs): ...@@ -275,6 +279,8 @@ def schedule_conv2d_nhwc(outs):
fused = s[C].fuse(n, h, w) fused = s[C].fuse(n, h, w)
s[C].parallel(fused) s[C].parallel(fused)
scheduled_ops.append(op)
traverse(output_op) traverse(output_op)
return s return s
...@@ -288,6 +294,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, ...@@ -288,6 +294,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
} }
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
...@@ -296,7 +303,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, ...@@ -296,7 +303,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if 'conv2d_NCHWc' in op.tag: if 'conv2d_NCHWc' in op.tag:
...@@ -322,5 +329,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, ...@@ -322,5 +329,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec, _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
kernel, conv_out, outs[0]) kernel, conv_out, outs[0])
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -53,6 +53,7 @@ def schedule_dense(outs): ...@@ -53,6 +53,7 @@ def schedule_dense(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op): def traverse(op):
"""Traverse operators from computation graph""" """Traverse operators from computation graph"""
...@@ -61,7 +62,7 @@ def schedule_dense(outs): ...@@ -61,7 +62,7 @@ def schedule_dense(outs):
if op not in s.outputs: if op not in s.outputs:
s[op].compute_inline() s[op].compute_inline()
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
if 'dense' in op.tag: if 'dense' in op.tag:
...@@ -89,5 +90,7 @@ def schedule_dense(outs): ...@@ -89,5 +90,7 @@ def schedule_dense(outs):
# Parallelization # Parallelization
s[C].parallel(yo) s[C].parallel(yo)
scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -32,6 +32,7 @@ def schedule_pool(outs): ...@@ -32,6 +32,7 @@ def schedule_pool(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _schedule(PaddedInput, Pool): def _schedule(PaddedInput, Pool):
if isinstance(PaddedInput.op, tvm.tensor.ComputeOp): if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
...@@ -45,7 +46,7 @@ def schedule_pool(outs): ...@@ -45,7 +46,7 @@ def schedule_pool(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule pool # schedule pool
elif OP.tag.startswith('pool'): elif OP.tag.startswith('pool'):
...@@ -54,6 +55,9 @@ def schedule_pool(outs): ...@@ -54,6 +55,9 @@ def schedule_pool(outs):
_schedule(PaddedInput, Pool) _schedule(PaddedInput, Pool)
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -75,6 +79,8 @@ def schedule_global_pool(outs): ...@@ -75,6 +79,8 @@ def schedule_global_pool(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(OP): def traverse(OP):
"""Internal travserse function""" """Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
...@@ -82,7 +88,7 @@ def schedule_global_pool(outs): ...@@ -82,7 +88,7 @@ def schedule_global_pool(outs):
if OP not in s.outputs: if OP not in s.outputs:
s[OP].compute_inline() s[OP].compute_inline()
for tensor in OP.input_tensors: for tensor in OP.input_tensors:
if tensor.op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
# schedule pool # schedule pool
elif OP.tag.startswith('global_pool'): elif OP.tag.startswith('global_pool'):
...@@ -90,5 +96,8 @@ def schedule_global_pool(outs): ...@@ -90,5 +96,8 @@ def schedule_global_pool(outs):
_parallel_sch(s[Pool]) _parallel_sch(s[Pool])
else: else:
raise RuntimeError("Unsupported operator: %s" % OP.tag) raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment