Commit 989e99e6 by wetliu, committed by Tianqi Chen

[WIP] [TOPI] Depthwise Conv for NHWC (#325)

* rename the nchw version and pass the unit test; going to do the same for nhwc depthwise

* bug with fusion

* nchw works fine; nhwc float32 problem remains

* still cannot bind them together

* fusion works

* syntax fix

* all bugs fixed; test cases pass

* minor fix on nn.h
parent 64870ffb
@@ -287,7 +287,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string tag = kDepthwiseConv2d) {
std::string tag = kDepthwiseConv2dNCHW) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[2];
@@ -313,6 +313,39 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
return tvm::compute(output_shape, l, name, tag);
}
inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
const tvm::Tensor& W,
int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string tag = kDepthwiseConv2dNHWC) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[1];
auto pW = I->shape[2];
auto pCM = W->shape[3]; // channel_multiplier (filter layout is HWCM)
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
(I->shape[1] - W->shape[0] + 2 * pad_h) / stride_h + 1, // H
(I->shape[2] - W->shape[1] + 2 * pad_w) / stride_w + 1, // W
I->shape[3] * W->shape[3], // C * M
};
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
auto T = (pad_h == 0 && pad_w == 0)
? I
: pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)});
// output channel o encodes (input channel, multiplier) as (o / pCM, o % pCM);
// depthwise convolution reduces over the kernel window only, never across channels
auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, o / pCM) *
W(kh, kw, o / pCM, o % pCM),
{kh, kw});
};
return tvm::compute(output_shape, l, name, tag);
}
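For reference, the spatial extents in output_shape follow the standard convolution arithmetic. With input height $H$, kernel height $K_h$, padding $p_h$, and stride $s_h$:

$$H_{out} = \left\lfloor \frac{H + 2\,p_h - K_h}{s_h} \right\rfloor + 1, \qquad W_{out} = \left\lfloor \frac{W + 2\,p_w - K_w}{s_w} \right\rfloor + 1$$

For example, $H = 96$, $K_h = 3$, $p_h = 1$, $s_h = 1$ gives $H_{out} = 96$, which is why 'SAME' padding preserves the spatial shape in the tests below.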
/*!
* \brief Creates an operation that performs a 2-D group convolution with
* an NGCHW-layout
......
@@ -13,7 +13,8 @@ constexpr auto kBroadcast = "bcast";
constexpr auto kMatMult = "matmult";
constexpr auto kConv2dNCHW = "conv2d_nchw";
constexpr auto kConv2dHWCN = "conv2d_hwcn";
constexpr auto kDepthwiseConv2d = "depthwise_conv2d";
constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
constexpr auto kDepthwiseConv2dNHWC = "depthwise_conv2d_nhwc";
constexpr auto kGroupConv2d = "group_conv2d";
} // namespace topi
......
@@ -4,6 +4,6 @@ from __future__ import absolute_import as _abs
from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
@@ -3,9 +3,8 @@
import tvm
from ..util import get_const_tuple
def schedule_depthwise_conv2d(outs):
"""Schedule for depthwise_conv2d.
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward.
Parameters
----------
@@ -16,7 +15,7 @@ def schedule_depthwise_conv2d(outs):
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d.
The computation schedule for depthwise_conv2d nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
@@ -105,7 +104,78 @@ def schedule_depthwise_conv2d(outs):
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d':
if OP.tag == 'depthwise_conv2d_nchw':
PaddedInput = OP.input_tensors[0]
Filter = OP.input_tensors[1]
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)
traverse(outs[0].op)
return s
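A minimal usage sketch of the renamed NCHW schedule (assuming the 0.x-era tvm/topi API that this diff targets; the shapes mirror the tests below):

import tvm
import topi
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw

# NCHW input, [in_channel, channel_multiplier, kh, kw] filter
Input = tvm.placeholder((1, 256, 96, 96), name='Input')
Filter = tvm.placeholder((256, 1, 3, 3), name='Filter')
Conv = topi.nn.depthwise_conv2d_nchw(Input, Filter, [1, 1], 'SAME')
s = schedule_depthwise_conv2d_nchw(Conv)         # dispatches on the 'depthwise_conv2d_nchw' tag
f = tvm.build(s, [Input, Filter, Conv], 'cuda')  # compile a CUDA kernel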
def schedule_depthwise_conv2d_nhwc(outs):
"""Schedule for depthwise_conv2d nhwc forward.
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nhwc.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(temp, Filter, DepthwiseConv2d):
s[temp].compute_inline()
FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
if DepthwiseConv2d.op in s.outputs:
Output = DepthwiseConv2d
CL = s.cache_write(DepthwiseConv2d, "local")
else:
Output = outs[0].op.output(0)
s[DepthwiseConv2d].set_scope("local")
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
b, h, w, c = s[Output].op.axis
ic_val = tvm.ir_pass.Simplify(temp.shape[3]).value
xoc, xic = s[Output].split(c, factor=ic_val)
s[Output].reorder(xoc, b, h, w, xic)
xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2)
fused = s[Output].fuse(yo, xo)
fused = s[Output].fuse(fused, b)
fused = s[Output].fuse(fused, xoc)
s[Output].bind(fused, block_x)
s[Output].bind(xic, thread_x)
if DepthwiseConv2d.op in s.outputs:
s[CL].compute_at(s[Output], xic)
else:
s[DepthwiseConv2d].compute_at(s[Output], xic)
_, _, ci, fi = s[FS].op.axis
s[FS].compute_at(s[Output], fused)
fused = s[FS].fuse(fi, ci)
s[FS].bind(fused, thread_x)
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in OP.tag or 'bcast' in OP.tag:
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nhwc':
PaddedInput = OP.input_tensors[0]
Filter = OP.input_tensors[1]
DepthwiseConv2d = OP.output(0)
......
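The _schedule above follows a split/fuse/bind pattern: the inner channel slice is bound to threadIdx.x while the remaining fused outer axes become blockIdx.x. A minimal sketch of the same pattern on a toy elementwise op (illustrative shapes, same-era tvm API):

import tvm

A = tvm.placeholder((64, 64), name='A')
B = tvm.compute((64, 64), lambda i, j: A[i, j] * 2, name='B')
s = tvm.create_schedule(B.op)
i, j = s[B].op.axis
jo, ji = s[B].split(j, factor=32)  # inner 32 lanes -> threads
fused = s[B].fuse(i, jo)           # remaining iteration space -> blocks
s[B].bind(fused, tvm.thread_axis("blockIdx.x"))
s[B].bind(ji, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))  # inspect the bound loop nest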
@@ -107,9 +107,8 @@ def conv2d_hwcn(Input, Filter, stride, padding):
name="Conv2dOutput", tag="conv2d_hwcn")
return Output
def depthwise_conv2d(Input, Filter, stride, padding):
"""Depthwise convolution operator.
def depthwise_conv2d_nchw(Input, Filter, stride, padding):
"""Depthwise convolution nchw forward operator.
Parameters
----------
@@ -153,5 +152,53 @@ def depthwise_conv2d(Input, Filter, stride, padding):
(PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d")
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output
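Spelled out, the NCHW compute above is, for channel multiplier $M$:

$$Out[b, c, i, j] = \sum_{d_i,\,d_j} PaddedInput\left[b,\ \lfloor c/M \rfloor,\ i\,s_h + d_i,\ j\,s_w + d_j\right] \cdot Filter\left[\lfloor c/M \rfloor,\ c \bmod M,\ d_i,\ d_j\right]$$

so each of the $C \cdot M$ output channels reads exactly one input channel, which is what distinguishes depthwise from regular convolution.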
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
stride : list / tuple of 2 ints
[stride_height, stride_width]
padding : str
'VALID' or 'SAME'
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = Input.shape
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output
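A small usage sketch of the new NHWC operator (shapes chosen to mirror the tests below; assumes the import paths used elsewhere in this diff):

import tvm
import topi
from topi.util import get_const_tuple

Input = tvm.placeholder((1, 96, 96, 256), name='Input')
Filter = tvm.placeholder((3, 3, 256, 1), name='Filter')  # HWCM layout
Output = topi.nn.depthwise_conv2d_nhwc(Input, Filter, stride=[1, 1], padding='SAME')
print(get_const_tuple(Output.shape))  # (1, 96, 96, 256): 'SAME' keeps H and W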
@@ -3,8 +3,8 @@
from __future__ import absolute_import as _abs
import tvm
@tvm.tag_scope(tag="bcast_scale_shift")
def scale_shift(Input, Scale, Shift):
@tvm.tag_scope(tag="bcast_scale_shift_nchw")
def scale_shift_nchw(Input, Scale, Shift):
"""Batch normalization operator in inference.
Parameters
@@ -24,3 +24,25 @@ def scale_shift(Input, Scale, Shift):
Output tensor, layout is NCHW
"""
return tvm.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift')
@tvm.tag_scope(tag="bcast_scale_shift_nhwc")
def scale_shift_nhwc(Input, Scale, Shift):
"""Batch normalization operator in inference.
Parameters
----------
Input : tvm.Tensor
Input tensor, layout is NHWC
Scale : tvm.Tensor
Scale tensor, 1-D of size channel number
Shift : tvm.Tensor
Shift tensor, 1-D of size channel number
Returns
-------
Output : tvm.Tensor
Output tensor, layout is NHWC
"""
return tvm.compute(Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name='ScaleShift')
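Because NHWC puts the channel on the trailing axis, the scale/shift is exactly numpy's default broadcasting; an illustrative numpy equivalent:

import numpy as np

x = np.random.uniform(size=(1, 4, 4, 8)).astype('float32')  # NHWC
scale = np.random.uniform(size=(8,)).astype('float32')
shift = np.random.uniform(size=(8,)).astype('float32')
y = x * scale + shift  # broadcasts over the trailing channel axis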
@@ -6,5 +6,5 @@ from __future__ import absolute_import as _abs
from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
@@ -3,8 +3,7 @@
import numpy as np
from scipy import signal
def depthwise_conv2d_python(input_np, filter_np, stride, padding):
def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
"""Depthwise convolution operator in NCHW layout.
Parameters
@@ -60,3 +59,60 @@ def depthwise_conv2d_python(input_np, filter_np, stride, padding):
mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
return output_np
def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
"""Depthwise convolution operator in nchw layout.
Parameters
----------
input_np : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]
filter_np : numpy.ndarray
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
stride : list / tuple of 2 ints
[stride_height, stride_width]
padding : str
'VALID' or 'SAME'
Returns
-------
output_np : np.ndarray
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = input_np.shape
filter_height, filter_width, _, channel_multiplier = filter_np.shape
stride_h, stride_w = stride
# calculate output shape
if padding == 'VALID':
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
output_np = np.zeros((batch, out_height, out_width, out_channel))
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \
np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \
mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
if padding == 'SAME':
out_channel = in_channel * channel_multiplier
out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
output_np = np.zeros((batch, out_height, out_width, out_channel))
pad_along_height = np.int(np.maximum((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = np.int(np.maximum((out_width - 1) * stride_w + filter_width - in_width, 0))
pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2))
pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2))
pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \
np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \
mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
return output_np
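As a quick cross-layout sanity check (a sketch, not from the commit's test suite), the two references should agree after transposing input, filter, and output between layouts:

import numpy as np

x_nhwc = np.random.uniform(size=(1, 8, 8, 4))
w_hwcm = np.random.uniform(size=(3, 3, 4, 1))
out_nhwc = depthwise_conv2d_python_nhwc(x_nhwc, w_hwcm, [1, 1], 'SAME')
out_nchw = depthwise_conv2d_python_nchw(x_nhwc.transpose(0, 3, 1, 2),  # NHWC -> NCHW
                                        w_hwcm.transpose(2, 3, 0, 1),  # HWCM -> CMHW
                                        [1, 1], 'SAME')
np.testing.assert_allclose(out_nhwc.transpose(0, 3, 1, 2), out_nchw, rtol=1e-5)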
@@ -6,14 +6,14 @@ from tvm.contrib import nvcc
import topi
from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
TASK = "depthwise_conv2d"
USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"]) # 37 for k80(ec2 instance)
return ptx
def write_code(code, fname):
@@ -29,7 +29,7 @@ def tvm_callback_cuda_postproc(code):
code = open("perf/%s_manual.cu" % TASK).read()
return code
def test_depthwise_conv2d():
def test_depthwise_conv2d_nchw():
"""You may test different settings."""
batch = 1
in_channel = 256
@@ -53,14 +53,13 @@ def test_depthwise_conv2d():
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# Declare
DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding)
ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# Schedule
s1 = schedule_depthwise_conv2d(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d(ScaleShift)
s3 = schedule_depthwise_conv2d(Relu)
s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = schedule_depthwise_conv2d_nchw(Relu)
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
@@ -80,6 +79,7 @@ def test_depthwise_conv2d():
filter_tvm = tvm.nd.array(filter_np, ctx)
scale_tvm = tvm.nd.array(scale_np, ctx)
shift_tvm = tvm.nd.array(shift_np, ctx)
depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
@@ -101,7 +101,7 @@ def test_depthwise_conv2d():
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
# correctness
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
@@ -118,5 +118,95 @@ def test_depthwise_conv2d():
restricted_func=True):
check_device("cuda")
def test_depthwise_conv2d_nhwc():
"""You may test different settings."""
batch = 1
in_channel = 256
in_height = 96
in_width = 96
filter_channel = in_channel
channel_multiplier = 1
filter_height = 3
filter_width = 3
stride_h = 1
stride_w = 1
padding = 'SAME' # or 'VALID'
# Placeholder
Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
Filter = tvm.placeholder((filter_height, filter_width, filter_channel, channel_multiplier), name='Filter')
Stride = [stride_h, stride_w]
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# Declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# Schedule
s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = schedule_depthwise_conv2d_nhwc(Relu)
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
# Build the kernel
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
# Prepare data
input_tvm = tvm.nd.array(input_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)
scale_tvm = tvm.nd.array(scale_np, ctx)
shift_tvm = tvm.nd.array(shift_np, ctx)
depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# Measure time cost of kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
# Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1000)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
# Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1000)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
print("Input shape = " + str(get_const_tuple(Input.shape)))
print("Filter shape = " + str(get_const_tuple(Filter.shape)))
print("Stride = (%d, %d)" % (stride_h, stride_w))
print("padding = %s\n" % padding)
print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
# correctness
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
relu_scipy = np.maximum(scale_shift_scipy, 0)
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
print("success")
with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0,
unroll_explicit=False,
detect_global_barrier=False,
restricted_func=True):
check_device("cuda")
if __name__ == "__main__":
test_depthwise_conv2d()
test_depthwise_conv2d_nchw()
test_depthwise_conv2d_nhwc()
@@ -3,9 +3,9 @@ import topi
import numpy as np
from scipy import signal
from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
in_width = in_height
filter_channel = in_channel
filter_width = filter_height
@@ -17,13 +17,13 @@ def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multipl
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding)
ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = schedule_depthwise_conv2d(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d(ScaleShift)
s3 = schedule_depthwise_conv2d(Relu)
s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = schedule_depthwise_conv2d_nchw(Relu)
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
@@ -57,7 +57,7 @@ def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multipl
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
@@ -70,17 +70,91 @@ def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multipl
check_device("cuda")
check_device("metal")
def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
in_width = in_height
filter_channel = in_channel
filter_width = filter_height
stride_w = stride_h
# placeholder
Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
Filter = tvm.placeholder((filter_height, filter_width, filter_channel, channel_multiplier), name='Filter')
Stride = [stride_h, stride_w]
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = schedule_depthwise_conv2d_nhwc(Relu)
def test_depthwise_conv2d():
depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "VALID")
depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "VALID")
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
# build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
# prepare data
input_tvm = tvm.nd.array(input_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)
scale_tvm = tvm.nd.array(scale_np, ctx)
shift_tvm = tvm.nd.array(shift_np, ctx)
depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# launch kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
# launch kernel 2 (depthwise_conv2d + scale_shift)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
# launch kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
# correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
relu_scipy = np.maximum(scale_shift_scipy, 0)
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
check_device("opencl")
check_device("cuda")
check_device("metal")
def test_depthwise_conv2d():
print("testing nchw")
depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID")
depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID")
print("testing nhwc")
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID")
if __name__ == "__main__":
test_depthwise_conv2d()