Commit 07e56b9a authored by Yuwei HU, committed by Tianqi Chen

update depthwise_conv2d schedule and testing (#328)

parent 8edd047b
@@ -4,6 +4,6 @@ from __future__ import absolute_import as _abs
 from .conv2d_nchw import schedule_conv2d_nchw
 from .conv2d_hwcn import schedule_conv2d_hwcn
-from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
+from .depthwise_conv2d import schedule_depthwise_conv2d
 from .reduction import schedule_reduce
 from .broadcast import schedule_broadcast_to
@@ -3,25 +3,24 @@
 import tvm
 from ..util import get_const_tuple
-def schedule_depthwise_conv2d_map(op):
-    """Schedule for depthwise_conv2d map ops.
-    This include scale-shift and relu.
+def schedule_depthwise_conv2d(outs):
+    """Schedule for depthwise_conv2d.
     Parameters
     ----------
-    op: Operation
-        The symbolic description of the operation, should be depthwise_conv2d or
-        depthwise_conv2d followed by a sequence of one-to-one-mapping operators.
+    outs: Array of Tensor
+        The computation graph description of depthwise_conv2d
+        in the format of an array of tensors.
     Returns
    -------
     s: Schedule
-        The computation schedule for the op.
+        The computation schedule for depthwise_conv2d.
     """
-    s = tvm.create_schedule(op)
-    def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
-        """Schedule for depthwise_conv2d declared in topi.nn.conv"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    def _schedule(PaddedInput, Filter, DepthwiseConv2d):
         out_shape = get_const_tuple(DepthwiseConv2d.shape)
         out_height = out_shape[2]
         out_width = out_shape[3]
@@ -35,27 +34,27 @@ def schedule_depthwise_conv2d_map(op):
             Output = DepthwiseConv2d
             CL = s.cache_write(DepthwiseConv2d, "local")
         else:
-            Output = op.output(0)
+            Output = outs[0].op.output(0)
             s[DepthwiseConv2d].set_scope("local")
         # schedule parameters
-        num_thread = 8
+        num_thread_x = 8
+        num_thread_y = 8
         num_vthread_x = 1
         num_vthread_y = 1
         blocking_h = out_height
         blocking_w = out_width
-        if out_height % 48 == 0:
-            blocking_h = 48
-        elif out_height % 32 == 0:
+        if out_height % 32 == 0:
             blocking_h = 32
-        if out_width % 48 == 0:
-            blocking_w = 48
-            num_vthread_y = 3
-        elif out_width % 32 == 0:
+            num_thread_x = 2
+            num_vthread_x = 2
+        if out_width % 32 == 0:
             blocking_w = 32
+            num_thread_y = 16
+            num_vthread_y = 2
         block_x = tvm.thread_axis("blockIdx.x")
         block_y = tvm.thread_axis("blockIdx.y")
-        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
+        thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
         thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
         thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
         # split and bind
@@ -65,10 +64,10 @@ def schedule_depthwise_conv2d_map(op):
         s[Output].bind(bx, block_x)
         by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
         tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
-        tx, xi = s[Output].split(vxi, nparts=num_thread)
+        tx, xi = s[Output].split(vxi, nparts=num_thread_x)
         by2, y2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
         tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
-        ty, yi = s[Output].split(vyi, nparts=num_thread)
+        ty, yi = s[Output].split(vyi, nparts=num_thread_y)
         s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
         by = s[Output].fuse(by1, by2)
         s[Output].bind(tvx, thread_vx)
@@ -85,21 +84,21 @@ def schedule_depthwise_conv2d_map(op):
         s[DepthwiseConv2d].compute_at(s[Output], ty)
         # input's shared memory load
         s[IS].compute_at(s[Output], by)
-        tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread)
-        ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread)
+        tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread_x)
+        ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread_y)
         s[IS].bind(tx, thread_x)
         s[IS].bind(ty, thread_y)
         # filter's shared memory load
         s[FS].compute_at(s[Output], by)
         s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
-        tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread)
-        ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread)
+        tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread_x)
+        ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread_y)
         s[FS].bind(tx, thread_x)
         s[FS].bind(ty, thread_y)
     def traverse(OP):
         # inline all one-to-one-mapping operators except the last stage (output)
-        if OP.tag == 'ewise' or OP.tag == 'scale_shift':
+        if 'ewise' in OP.tag or 'bcast' in OP.tag:
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
@@ -110,7 +109,7 @@ def schedule_depthwise_conv2d_map(op):
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
             DepthwiseConv2d = OP.output(0)
-            schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d)
-    traverse(op)
+            _schedule(PaddedInput, Filter, DepthwiseConv2d)
+    traverse(outs[0].op)
     return s
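Note: the schedule entry point now takes the output tensor(s) rather than an Operation, matching the other topi CUDA schedules. A minimal sketch of the new call pattern, using the tensor names that appear in the updated tests below:

    # Input/Filter are placeholders; DepthwiseConv2d, ScaleShift, Relu are topi.nn outputs (see the tests below)
    s1 = schedule_depthwise_conv2d(DepthwiseConv2d)   # was: schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
    s3 = schedule_depthwise_conv2d(Relu)              # fused depthwise_conv2d + scale_shift + relu
    f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], "cuda")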
@@ -3,7 +3,7 @@
 from __future__ import absolute_import as _abs
 import tvm
-@tvm.tag_scope(tag="scale_shift")
+@tvm.tag_scope(tag="bcast_scale_shift")
 def scale_shift(Input, Scale, Shift):
     """Batch normalization operator in inference.
...
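The tag rename matters because traverse() above now inlines any operator whose tag contains 'ewise' or 'bcast', instead of matching exact tag names. A small illustration of how tag_scope attaches such a tag (the add-one op here is a hypothetical example, not part of topi):

    import tvm

    @tvm.tag_scope(tag="bcast_add_one")  # hypothetical op; any tag containing 'bcast' is inlined by traverse()
    def add_one(x):
        return tvm.compute(x.shape, lambda *i: x(*i) + 1)

    A = tvm.placeholder((16, 16), name='A')
    B = add_one(A)
    print(B.op.tag)  # -> 'bcast_add_one'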
@@ -6,4 +6,5 @@ from __future__ import absolute_import as _abs
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
+from .depthwise_conv2d_python import depthwise_conv2d_python
 from .dilate_python import dilate_python
# pylint: disable=invalid-name, unused-variable, line-too-long
"""Depthwise convolution in python"""
import numpy as np
from scipy import signal

def depthwise_conv2d_python(input_np, filter_np, stride, padding):
    """Depthwise convolution operator in NCHW layout.

    Parameters
    ----------
    input_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]
    filter_np : numpy.ndarray
        4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]
    stride : list / tuple of 2 ints
        [stride_height, stride_width]
    padding : str
        'VALID' or 'SAME'

    Returns
    -------
    output_np : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, in_channel, in_height, in_width = input_np.shape
    _, channel_multiplier, filter_height, filter_width = filter_np.shape
    stride_h, stride_w = stride
    # calculate output shape
    if padding == 'VALID':
        out_channel = in_channel * channel_multiplier
        out_height = (in_height - filter_height) // stride_h + 1
        out_width = (in_width - filter_width) // stride_w + 1
        output_np = np.zeros((batch, out_channel, out_height, out_width))
        for i in range(batch):
            for j in range(out_channel):
                output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
                    np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
                    mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
    if padding == 'SAME':
        out_channel = in_channel * channel_multiplier
        out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
        out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
        output_np = np.zeros((batch, out_channel, out_height, out_width))
        pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
        pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
        pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2))
        pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2))
        pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
        pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
        index_h = pad_top_scipy - pad_top_tvm
        index_w = pad_left_scipy - pad_left_tvm
        for i in range(batch):
            for j in range(out_channel):
                output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
                    np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
                    mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
    return output_np
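A short usage sketch of the new reference implementation (shapes are illustrative; the updated tests below call it the same way through topi.testing):

    import numpy as np
    import topi

    input_np = np.random.uniform(size=(1, 256, 96, 96)).astype('float32')  # [batch, in_channel, height, width]
    filter_np = np.random.uniform(size=(256, 1, 3, 3)).astype('float32')   # [in_channel, multiplier, kh, kw]
    out_np = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[1, 1], padding='SAME')
    print(out_np.shape)  # (1, 256, 96, 96): SAME padding with stride 1 keeps the spatial size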
@@ -5,10 +5,10 @@ from scipy import signal
 from tvm.contrib import nvcc
 import topi
-from topi.nn.util import get_const_tuple
-from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
+from topi.util import get_const_tuple
+from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d
-TASK = "depthwise_conv2d_map"
+TASK = "depthwise_conv2d"
 USE_MANUAL_CODE = False
 @tvm.register_func
@@ -29,20 +29,20 @@ def tvm_callback_cuda_postproc(code):
     code = open("perf/%s_manual.cu" % TASK).read()
     return code
-def test_depthwise_conv2d_map():
+def test_depthwise_conv2d():
     """You may test different settings."""
-    batch = 2
+    batch = 1
     in_channel = 256
-    in_height = 32
-    in_width = 32
+    in_height = 96
+    in_width = 96
     filter_channel = in_channel
-    channel_multiplier = 2
-    filter_height = 5
-    filter_width = 5
-    stride_h = 2
-    stride_w = 2
+    channel_multiplier = 1
+    filter_height = 3
+    filter_width = 3
+    stride_h = 1
+    stride_w = 1
     padding = 'SAME' # or 'VALID'
@@ -57,40 +57,14 @@ def test_depthwise_conv2d_map():
     ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
     # Schedule
-    s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
-    s2 = schedule_depthwise_conv2d_map(ScaleShift.op)
-    s3 = schedule_depthwise_conv2d_map(Relu.op)
-    def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
-        out_shape = get_const_tuple(DepthwiseConv2d.shape)
-        out_channel = out_shape[1]
-        out_height = out_shape[2]
-        out_width = out_shape[3]
-        depthwise_conv2d_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=DepthwiseConv2d.dtype)
-        scale_shift_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=ScaleShift.dtype)
-        relu_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=Relu.dtype)
-        if padding == 'SAME':
-            pad_top_tvm = np.int(np.ceil(float(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) / 2))
-            pad_left_tvm = np.int(np.ceil(float(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) / 2))
-            pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
-            pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
-            index_h = pad_top_scipy - pad_top_tvm
-            index_w = pad_left_scipy - pad_left_tvm
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
-        if padding == 'VALID':
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w]
-        for c in range(out_channel):
-            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
-        relu_scipy[:,:,:,:] = np.maximum(scale_shift_scipy[:,:,:,:], 0)
-        return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy
+    s1 = schedule_depthwise_conv2d(DepthwiseConv2d)
+    s2 = schedule_depthwise_conv2d(ScaleShift)
+    s3 = schedule_depthwise_conv2d(Relu)
+    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
+    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
+    scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
+    shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
     def check_device(device):
         if not tvm.module.enabled(device):
@@ -102,35 +76,36 @@ def test_depthwise_conv2d_map():
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
         f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
         # Prepare data
-        input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
-        filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
         input_tvm = tvm.nd.array(input_np, ctx)
         filter_tvm = tvm.nd.array(filter_np, ctx)
-        scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
-        shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
         scale_tvm = tvm.nd.array(scale_np, ctx)
         shift_tvm = tvm.nd.array(shift_np, ctx)
         depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
         scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
         relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
         # Measure time cost of kernel 1 (depthwise_conv2d)
-        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=10000)
+        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
         tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
         # Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
-        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=10000)
+        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1000)
         tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
         # Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
-        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=10000)
+        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1000)
         tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
         print("Input shape = " + str(get_const_tuple(Input.shape)))
         print("Filter shape = " + str(get_const_tuple(Filter.shape)))
         print("Stride = (%d, %d)" % (stride_h, stride_w))
         print("padding = %s\n" % padding)
         print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
-        print("average time cost of 10000 runs (depthwise_conv2d) = %g sec" % tcost_1)
-        print("average time cost of 10000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
-        print("average time cost of 10000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
-        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
+        print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
+        # correctness
+        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
+        for c in range(in_channel * channel_multiplier):
+            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+        relu_scipy = np.maximum(scale_shift_scipy, 0)
         np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
         np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
@@ -138,10 +113,10 @@ def test_depthwise_conv2d_map():
     with tvm.build_config(auto_unroll_max_step=32,
                           auto_unroll_min_depth=0,
-                          unroll_explicit=True,
+                          unroll_explicit=False,
                           detect_global_barrier=False,
                           restricted_func=True):
         check_device("cuda")
 if __name__ == "__main__":
-    test_depthwise_conv2d_map()
+    test_depthwise_conv2d()
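The per-channel loop in the correctness check above can equivalently be written with NumPy broadcasting; a sketch (not part of the commit):

    # broadcast scale/shift across (batch, channel, height, width) instead of looping over channels
    scale_shift_scipy = depthwise_conv2d_scipy * scale_np.reshape(1, -1, 1, 1) + shift_np.reshape(1, -1, 1, 1)
    relu_scipy = np.maximum(scale_shift_scipy, 0)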
@@ -3,9 +3,9 @@ import topi
 import numpy as np
 from scipy import signal
 from topi.util import get_const_tuple
-from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
+from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d
-def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
+def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     in_width = in_height
     filter_channel = in_channel
     filter_width = filter_height
@@ -21,40 +21,14 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
     # schedule
-    s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
-    s2 = schedule_depthwise_conv2d_map(ScaleShift.op)
-    s3 = schedule_depthwise_conv2d_map(Relu.op)
-    def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
-        out_shape = get_const_tuple(DepthwiseConv2d.shape)
-        out_channel = out_shape[1]
-        out_height = out_shape[2]
-        out_width = out_shape[3]
-        depthwise_conv2d_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=DepthwiseConv2d.dtype)
-        scale_shift_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=ScaleShift.dtype)
-        relu_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=Relu.dtype)
-        if padding == 'SAME':
-            pad_top_tvm = np.int(np.ceil(float(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) / 2))
-            pad_left_tvm = np.int(np.ceil(float(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) / 2))
-            pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
-            pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
-            index_h = pad_top_scipy - pad_top_tvm
-            index_w = pad_left_scipy - pad_left_tvm
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
-        if padding == 'VALID':
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w]
-        for c in range(out_channel):
-            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
-        relu_scipy[:,:,:,:] = np.maximum(scale_shift_scipy[:,:,:,:], 0)
-        return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy
+    s1 = schedule_depthwise_conv2d(DepthwiseConv2d)
+    s2 = schedule_depthwise_conv2d(ScaleShift)
+    s3 = schedule_depthwise_conv2d(Relu)
+    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
+    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
+    scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
+    shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
     def check_device(device):
         if not tvm.module.enabled(device):
@@ -66,12 +40,8 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
         f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
         # prepare data
-        input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
-        filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
         input_tvm = tvm.nd.array(input_np, ctx)
         filter_tvm = tvm.nd.array(filter_np, ctx)
-        scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
-        shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
         scale_tvm = tvm.nd.array(scale_np, ctx)
         shift_tvm = tvm.nd.array(shift_np, ctx)
         depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
@@ -87,7 +57,11 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
         timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
         tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
         # correctness with scipy
-        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
+        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
+        for c in range(in_channel * channel_multiplier):
+            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+        relu_scipy = np.maximum(scale_shift_scipy, 0)
         np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
         np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
@@ -97,16 +71,16 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     check_device("metal")
-def test_depthwise_conv2d_map():
-    depthwise_conv2d_map_with_workload(1, 728, 64, 1, 3, 1, "SAME")
-    depthwise_conv2d_map_with_workload(1, 728, 32, 1, 3, 1, "SAME")
-    depthwise_conv2d_map_with_workload(4, 256, 64, 2, 5, 2, "SAME")
-    depthwise_conv2d_map_with_workload(4, 256, 32, 2, 5, 2, "SAME")
-    depthwise_conv2d_map_with_workload(1, 728, 64, 1, 3, 1, "VALID")
-    depthwise_conv2d_map_with_workload(1, 728, 32, 1, 3, 1, "VALID")
-    depthwise_conv2d_map_with_workload(4, 256, 64, 2, 5, 2, "VALID")
-    depthwise_conv2d_map_with_workload(4, 256, 32, 2, 5, 2, "VALID")
+def test_depthwise_conv2d():
+    depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "SAME")
+    depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "SAME")
+    depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "SAME")
+    depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "SAME")
+    depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "VALID")
+    depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "VALID")
+    depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "VALID")
+    depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "VALID")
 if __name__ == "__main__":
-    test_depthwise_conv2d_map()
+    test_depthwise_conv2d()