Commit 310f56fa by Yuwei HU Committed by Tianqi Chen

[TOPI] update depthconv padding api; fix shared memory overflow (#365)

* modify depthconv padding

* fix shared memory overflow in depthconv schedule
parent c493baf3
...@@ -20,7 +20,10 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -20,7 +20,10 @@ def schedule_depthwise_conv2d_nchw(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
def _schedule(PaddedInput, Filter, DepthwiseConv2d): def _schedule(PaddedInput, Filter, DepthwiseConv2d):
in_shape = get_const_tuple(PaddedInput.shape)
out_shape = get_const_tuple(DepthwiseConv2d.shape) out_shape = get_const_tuple(DepthwiseConv2d.shape)
in_height = in_shape[2]
in_width = in_shape[3]
out_height = out_shape[2] out_height = out_shape[2]
out_width = out_shape[3] out_width = out_shape[3]
channel_multiplier = get_const_tuple(Filter.shape)[1] channel_multiplier = get_const_tuple(Filter.shape)[1]
...@@ -42,12 +45,14 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -42,12 +45,14 @@ def schedule_depthwise_conv2d_nchw(outs):
num_vthread_x = 1 num_vthread_x = 1
blocking_h = out_height blocking_h = out_height
blocking_w = out_width blocking_w = out_width
if out_height % 32 == 0: if out_height % 32 == 0 or in_height >= 108:
blocking_h = 32 blocking_h = 32
if out_width % 32 == 0: if out_width % 32 == 0:
blocking_w = 32 blocking_w = 32
num_thread_x = 16 num_thread_x = 16
num_vthread_x = 2 num_vthread_x = 2
elif in_width >= 108:
blocking_w = 32
block_y = tvm.thread_axis("blockIdx.y") block_y = tvm.thread_axis("blockIdx.y")
block_x = tvm.thread_axis("blockIdx.x") block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y") thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
......
...@@ -121,8 +121,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding): ...@@ -121,8 +121,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding):
stride : tuple of two ints stride : tuple of two ints
The spatial stride along height and width The spatial stride along height and width
padding : str padding : int or str
'VALID' or 'SAME' Padding size, or ['VALID', 'SAME']
Returns Returns
------- -------
...@@ -169,8 +169,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding): ...@@ -169,8 +169,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
Stride : tvm.Tensor Stride : tvm.Tensor
1-D of size 2 1-D of size 2
padding : str padding : int or str
'VALID' or 'SAME' Padding size, or ['VALID', 'SAME']
Returns Returns
------- -------
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment