Commit fa53dbdf, authored by Yuwei HU and committed by Tianqi Chen

modify schedule_depthwise_conv2d_nchw (#350)

parent ed9f3897
...@@ -36,64 +36,62 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -36,64 +36,62 @@ def schedule_depthwise_conv2d_nchw(outs):
Output = outs[0].op.output(0) Output = outs[0].op.output(0)
s[DepthwiseConv2d].set_scope("local") s[DepthwiseConv2d].set_scope("local")
# schedule parameters # schedule parameters
num_thread_x = 8
num_thread_y = 8 num_thread_y = 8
num_vthread_x = 1 num_thread_x = 8
num_vthread_y = 1 num_vthread_y = 1
num_vthread_x = 1
blocking_h = out_height blocking_h = out_height
blocking_w = out_width blocking_w = out_width
if out_height % 32 == 0: if out_height % 32 == 0:
blocking_h = 32 blocking_h = 32
num_thread_x = 2
num_vthread_x = 2
if out_width % 32 == 0: if out_width % 32 == 0:
blocking_w = 32 blocking_w = 32
num_thread_y = 16 num_thread_x = 16
num_vthread_y = 2 num_vthread_x = 2
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y") block_y = tvm.thread_axis("blockIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x") block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y") thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx") thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy") thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
# split and bind # split and bind
bx, bxi = s[Output].split(Output.op.axis[1], factor=channel_multiplier) by, byi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
s[Output].reorder(Output.op.axis[2], Output.op.axis[3], bxi) s[Output].reorder(Output.op.axis[2], Output.op.axis[3], byi)
bx = s[Output].fuse(Output.op.axis[0], bx) by = s[Output].fuse(Output.op.axis[0], by)
s[Output].bind(bx, block_x) s[Output].bind(by, block_y)
by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h) bx1, x1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x) tvy, vyi = s[Output].split(x1i, nparts=num_vthread_y)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
by2, y2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread_y) ty, yi = s[Output].split(vyi, nparts=num_thread_y)
s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi) bx2, x2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
by = s[Output].fuse(by1, by2) tvx, vxi = s[Output].split(x2i, nparts=num_vthread_x)
s[Output].bind(tvx, thread_vx) tx, xi = s[Output].split(vxi, nparts=num_thread_x)
s[Output].reorder(bx1, bx2, tvy, tvx, ty, tx, yi, xi)
bx = s[Output].fuse(bx1, bx2)
s[Output].bind(bx, block_x)
s[Output].bind(tvy, thread_vy) s[Output].bind(tvy, thread_vy)
s[Output].bind(tx, thread_x) s[Output].bind(tvx, thread_vx)
s[Output].bind(ty, thread_y) s[Output].bind(ty, thread_y)
s[Output].bind(by, block_y) s[Output].bind(tx, thread_x)
# local memory load # local memory load
s[IL].compute_at(s[Output], ty) s[IL].compute_at(s[Output], tx)
s[FL].compute_at(s[Output], ty) s[FL].compute_at(s[Output], tx)
if DepthwiseConv2d.op in s.outputs: if DepthwiseConv2d.op in s.outputs:
s[CL].compute_at(s[Output], ty) s[CL].compute_at(s[Output], tx)
else: else:
s[DepthwiseConv2d].compute_at(s[Output], ty) s[DepthwiseConv2d].compute_at(s[Output], tx)
# input's shared memory load # input's shared memory load
s[IS].compute_at(s[Output], by) s[IS].compute_at(s[Output], bx)
tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread_x) ty, yi = s[IS].split(IS.op.axis[2], nparts=num_thread_y)
ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread_y) tx, xi = s[IS].split(IS.op.axis[3], nparts=num_thread_x)
s[IS].bind(tx, thread_x)
s[IS].bind(ty, thread_y) s[IS].bind(ty, thread_y)
s[IS].bind(tx, thread_x)
# filter's shared memory load # filter's shared memory load
s[FS].compute_at(s[Output], by) s[FS].compute_at(s[Output], bx)
s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1]) s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread_x) ty, yi = s[FS].split(FS.op.axis[2], nparts=num_thread_y)
ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread_y) tx, xi = s[FS].split(FS.op.axis[3], nparts=num_thread_x)
s[FS].bind(tx, thread_x)
s[FS].bind(ty, thread_y) s[FS].bind(ty, thread_y)
s[FS].bind(tx, thread_x)
def traverse(OP): def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output) # inline all one-to-one-mapping operators except the last stage (output)
......
...@@ -97,9 +97,9 @@ def test_depthwise_conv2d_nchw(): ...@@ -97,9 +97,9 @@ def test_depthwise_conv2d_nchw():
print("Stride = (%d, %d)" % (stride_h, stride_w)) print("Stride = (%d, %d)" % (stride_h, stride_w))
print("padding = %s\n" % padding) print("padding = %s\n" % padding)
print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape))) print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1) print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2) print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3) print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
# correctness # correctness
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape)) scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
...@@ -186,9 +186,9 @@ def test_depthwise_conv2d_nhwc(): ...@@ -186,9 +186,9 @@ def test_depthwise_conv2d_nhwc():
print("Stride = (%d, %d)" % (stride_h, stride_w)) print("Stride = (%d, %d)" % (stride_h, stride_w))
print("padding = %s\n" % padding) print("padding = %s\n" % padding)
print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape))) print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1) print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2) print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3) print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
# correctness # correctness
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape)) scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment