Commit 310f56fa by Yuwei HU Committed by Tianqi Chen

[TOPI] update depthconv padding api; fix shared memory overflow (#365)

* modify depthconv padding

* fix shared memory overflow in depthconv schedule
parent c493baf3
......@@ -20,7 +20,10 @@ def schedule_depthwise_conv2d_nchw(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(PaddedInput, Filter, DepthwiseConv2d):
in_shape = get_const_tuple(PaddedInput.shape)
out_shape = get_const_tuple(DepthwiseConv2d.shape)
in_height = in_shape[2]
in_width = in_shape[3]
out_height = out_shape[2]
out_width = out_shape[3]
channel_multiplier = get_const_tuple(Filter.shape)[1]
......@@ -42,12 +45,14 @@ def schedule_depthwise_conv2d_nchw(outs):
num_vthread_x = 1
blocking_h = out_height
blocking_w = out_width
if out_height % 32 == 0:
if out_height % 32 == 0 or in_height >= 108:
blocking_h = 32
if out_width % 32 == 0:
blocking_w = 32
num_thread_x = 16
num_vthread_x = 2
elif in_width >= 108:
blocking_w = 32
block_y = tvm.thread_axis("blockIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
......
......@@ -121,8 +121,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding):
stride : tuple of two ints
The spatial stride along height and width
padding : str
'VALID' or 'SAME'
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
......@@ -169,8 +169,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
Stride : tvm.Tensor
1-D of size 2
padding : str
'VALID' or 'SAME'
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment