Commit fa53dbdf by Yuwei HU Committed by Tianqi Chen

modify schedule_depthwise_conv2d_nchw (#350)

parent ed9f3897
......@@ -36,64 +36,62 @@ def schedule_depthwise_conv2d_nchw(outs):
Output = outs[0].op.output(0)
s[DepthwiseConv2d].set_scope("local")
# schedule parameters
num_thread_x = 8
num_thread_y = 8
num_vthread_x = 1
num_thread_x = 8
num_vthread_y = 1
num_vthread_x = 1
blocking_h = out_height
blocking_w = out_width
if out_height % 32 == 0:
blocking_h = 32
num_thread_x = 2
num_vthread_x = 2
if out_width % 32 == 0:
blocking_w = 32
num_thread_y = 16
num_vthread_y = 2
block_x = tvm.thread_axis("blockIdx.x")
num_thread_x = 16
num_vthread_x = 2
block_y = tvm.thread_axis("blockIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
# split and bind
bx, bxi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
s[Output].reorder(Output.op.axis[2], Output.op.axis[3], bxi)
bx = s[Output].fuse(Output.op.axis[0], bx)
s[Output].bind(bx, block_x)
by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
by2, y2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
by, byi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
s[Output].reorder(Output.op.axis[2], Output.op.axis[3], byi)
by = s[Output].fuse(Output.op.axis[0], by)
s[Output].bind(by, block_y)
bx1, x1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
tvy, vyi = s[Output].split(x1i, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread_y)
s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
by = s[Output].fuse(by1, by2)
s[Output].bind(tvx, thread_vx)
bx2, x2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
tvx, vxi = s[Output].split(x2i, nparts=num_vthread_x)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
s[Output].reorder(bx1, bx2, tvy, tvx, ty, tx, yi, xi)
bx = s[Output].fuse(bx1, bx2)
s[Output].bind(bx, block_x)
s[Output].bind(tvy, thread_vy)
s[Output].bind(tx, thread_x)
s[Output].bind(tvx, thread_vx)
s[Output].bind(ty, thread_y)
s[Output].bind(by, block_y)
s[Output].bind(tx, thread_x)
# local memory load
s[IL].compute_at(s[Output], ty)
s[FL].compute_at(s[Output], ty)
s[IL].compute_at(s[Output], tx)
s[FL].compute_at(s[Output], tx)
if DepthwiseConv2d.op in s.outputs:
s[CL].compute_at(s[Output], ty)
s[CL].compute_at(s[Output], tx)
else:
s[DepthwiseConv2d].compute_at(s[Output], ty)
s[DepthwiseConv2d].compute_at(s[Output], tx)
# input's shared memory load
s[IS].compute_at(s[Output], by)
tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread_x)
ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread_y)
s[IS].bind(tx, thread_x)
s[IS].compute_at(s[Output], bx)
ty, yi = s[IS].split(IS.op.axis[2], nparts=num_thread_y)
tx, xi = s[IS].split(IS.op.axis[3], nparts=num_thread_x)
s[IS].bind(ty, thread_y)
s[IS].bind(tx, thread_x)
# filter's shared memory load
s[FS].compute_at(s[Output], by)
s[FS].compute_at(s[Output], bx)
s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread_x)
ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread_y)
s[FS].bind(tx, thread_x)
ty, yi = s[FS].split(FS.op.axis[2], nparts=num_thread_y)
tx, xi = s[FS].split(FS.op.axis[3], nparts=num_thread_x)
s[FS].bind(ty, thread_y)
s[FS].bind(tx, thread_x)
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
......
......@@ -97,9 +97,9 @@ def test_depthwise_conv2d_nchw():
print("Stride = (%d, %d)" % (stride_h, stride_w))
print("padding = %s\n" % padding)
print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
# correctness
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
......@@ -186,9 +186,9 @@ def test_depthwise_conv2d_nhwc():
print("Stride = (%d, %d)" % (stride_h, stride_w))
print("padding = %s\n" % padding)
print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
# correctness
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment