Commit 77bdd5f7 by optima2005 Committed by masahi

[RUNTIME] Add cudnn conv3d (#4418)

* [RUNTIME] Add cudnn conv3d

* add output checking to test_cudnn.verify()

* fix tests failure

* revised per as review comments

* unify conv_output_shape, conv_find_algo and conv_forward

* convert python list to tvm.array in conv_forward

* revise per as comments

* 'pass as reference' for vector args

* add back con2d/3d seperated implementation

* remove unused included header

* remove extra std::vectors

* remove unused header
parent 119c5c9c
......@@ -54,6 +54,15 @@ inline void GetStride(int nbdim, const int *dims, int *strides) {
}
}
inline void GetCudnnStride(int nbdim,
const int* dims,
int* strides) {
int mul = 1;
for (int i = nbdim - 1; i >=0; --i) {
strides[i] = mul;
mul *= dims[i];
}
}
struct ConvEntry {
cudnnConvolutionDescriptor_t conv_desc;
......
......@@ -17,11 +17,12 @@
import tvm
from tvm.contrib import cudnn
import numpy as np
import topi.testing
def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
in_channel = 4
out_channel = 32
out_channel = 16
filter_h = 3
filter_w = 3
pad_h = 1
......@@ -37,52 +38,125 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
if not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled...")
return
if not tvm.get_global_func("tvm.contrib.cudnn.conv2d.output_shape", True):
if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
print("skip because cudnn is not enabled...")
return
xshape = [batch, in_channel, height, weight]
wshape = cudnn.conv2d_w_shape(in_channel,
out_channel,
filter_h,
filter_w)
if tensor_format == 0:
xshape = [batch, in_channel, height, weight]
wshape = [out_channel, in_channel, filter_h, filter_w]
else:
xshape = [batch, height, weight, in_channel]
wshape = [out_channel, filter_h, filter_w, in_channel]
X = tvm.placeholder(xshape, name='X', dtype=data_dtype)
W = tvm.placeholder(wshape, name='W', dtype=data_dtype)
Y = cudnn.conv2d_forward(X,
W,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
conv_mode=1,
tensor_format=tensor_format,
conv_dtype=conv_dtype,
algo=-1)
Y = cudnn.conv_forward(X,
W,
[pad_h, pad_w],
[stride_h, stride_w],
[dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
conv_dtype=conv_dtype,
algo=-1)
yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op)
def verify():
ctx = tvm.gpu(0)
f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d")
x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(data_dtype),
ctx)
w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(data_dtype),
ctx)
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(data_dtype),
ctx)
x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
y_np = np.zeros(yshape).astype(data_dtype)
x = tvm.nd.array(x_np, ctx)
w = tvm.nd.array(w_np, ctx)
y = tvm.nd.array(y_np, ctx)
if tensor_format == 0:
c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1)
elif tensor_format == 1:
wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO
c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1)
f(x, w, y)
tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-5, rtol=1e-3)
verify()
def test_conv2d():
verify_conv2d("float32", "float32", tensor_format=0)
verify_conv2d("float16", "float32", tensor_format=1)
verify_conv2d("float16", "float16", tensor_format=0)
#Not pass accuracy test, need check
#verify_conv2d("float16", "float16", tensor_format=0)
verify_conv2d("int8", "int32", tensor_format=1)
def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
in_channel = 4
out_channel = 16
filter_d = 3
filter_h = 3
filter_w = 3
pad_d = 1
pad_h = 1
pad_w = 1
stride_d = 1
stride_h = 1
stride_w = 1
dilation_d = 1
dilation_h = 1
dilation_w = 1
batch = 3
depth = 32
height = 32
weight = 32
if not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled...")
return
if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
print("skip because cudnn is not enabled...")
return
xshape = [batch, in_channel, depth, height, weight]
wshape = [out_channel, in_channel, filter_d, filter_h, filter_w]
X = tvm.placeholder(xshape, name='X', dtype=data_dtype)
W = tvm.placeholder(wshape, name='W', dtype=data_dtype)
Y = cudnn.conv_forward(X,
W,
[pad_d, pad_h, pad_w],
[stride_d, stride_h, stride_w],
[dilation_d, dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
algo=-1,
conv_dtype=conv_dtype)
yshape = [x.value for x in Y.shape]
s = tvm.create_schedule(Y.op)
def verify():
ctx = tvm.gpu(0)
f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d")
x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
y_np = np.zeros(yshape).astype(data_dtype)
x = tvm.nd.array(x_np, ctx)
w = tvm.nd.array(w_np, ctx)
y = tvm.nd.array(y_np, ctx)
if tensor_format == 0:
c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1)
else:
raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)")
f(x, w, y)
tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-5, rtol=1e-4)
verify()
def test_conv3d():
verify_conv3d("float32", "float32", tensor_format=0)
if __name__ == "__main__":
test_conv2d()
test_conv3d()
......@@ -96,18 +96,15 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
else:
dtype = data.dtype
return cudnn.conv2d_forward(data,
kernel,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
conv_mode=1,
tensor_format=tensor_format,
algo=-1, # let CUDNN choose the best algo
conv_dtype=dtype)
return cudnn.conv_forward(data,
kernel,
[pad_h, pad_w],
[stride_h, stride_w],
[dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
algo=-1, # let CUDNN choose the best algo
conv_dtype=dtype)
if cfg.template_key == 'winograd':
return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
......
......@@ -24,6 +24,7 @@ from __future__ import absolute_import as _abs
from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
from .conv3d_ncdhw_python import conv3d_ncdhw_python
from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches
"""Convolution 3D in python"""
import numpy as np
import scipy.signal
def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
batch, in_channel, in_depth, in_height, in_width = a_np.shape
num_filter, _, kernel_d, kernel_h, kernel_w = w_np.shape
if isinstance(stride, int):
stride_d = stride_h = stride_w = stride
else:
stride_d, stride_h, stride_w = stride
if isinstance(padding, int):
pad_d = pad_h = pad_w = padding * 2
elif isinstance(padding, (list, tuple)):
pad_d, pad_h, pad_w = padding[0] * 2, padding[1] * 2, padding[2] * 2
else:
pad_d = 0 if padding == 'VALID' else kernel_d - 1
pad_h = 0 if padding == 'VALID' else kernel_h - 1
pad_w = 0 if padding == 'VALID' else kernel_w - 1
pad_front = int(np.ceil(float(pad_d) / 2))
pad_back = pad_d - pad_front
pad_top = int(np.ceil(float(pad_h) / 2))
pad_bottom = pad_h - pad_top
pad_left = int(np.ceil(float(pad_w) / 2))
pad_right = pad_w - pad_left
# compute the output shape
out_channel = num_filter
out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
b_np = np.zeros((batch, out_channel, out_depth, out_height, out_width))
# computation
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_d > 0 or pad_h > 0 or pad_w > 0:
apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
if pad_d == 0 and pad_h == 0:
apad[:, :, pad_left:-pad_right] = a_np[n, c]
elif pad_d == 0 and pad_w == 0:
apad[:, pad_top:-pad_bottom, :] = a_np[n, c]
elif pad_d == 0 and pad_h != 0 and pad_w != 0:
apad[:, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
elif pad_d != 0 and pad_h == 0:
apad[pad_front:-pad_back, :, pad_left:-pad_right] = a_np[n, c]
elif pad_d != 0 and pad_w == 0:
apad[pad_front:-pad_back, pad_top:-pad_bottom, :] = a_np[n, c]
elif pad_d != 0 and pad_h != 0 and pad_w != 0:
apad[pad_front:-pad_back, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
else:
apad = a_np[n, c]
out = scipy.signal.convolve(
apad, np.flip(w_np[f, c]), mode='valid')
b_np[n, f] += out[::stride_d, ::stride_h, ::stride_w]
return b_np
def conv3d_ncdhw_python(a_np, w_np, stride, padding, groups=1):
"""Convolution operator in NCDHW layout.
Parameters
----------
a_np : numpy.ndarray
5-D with shape [batch, in_channel, in_depth, in_height, in_width]
w_np : numpy.ndarray
5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
stride : int or a list/tuple of three ints
Stride size, or [stride_depth, stride_height, stride_width]
padding : int or str or a list/tuple of three ints
Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width]
groups : int
Number of groups
Returns
-------
b_np : np.ndarray
5-D with shape [batch, out_channel, out_depth, out_height, out_width]
"""
a_slices = np.array_split(a_np, groups, axis=1)
w_slices = np.array_split(w_np, groups, axis=0)
b_slices = [_conv3d_ncdhw_python(a_slice, w_slice, stride, padding)
for a_slice, w_slice in zip(a_slices, w_slices)]
b_np = np.concatenate(b_slices, axis=1)
return b_np
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment