Unverified Commit 9c12ec81 by Wei Pan Committed by GitHub

[cuDNN] Add cuDNN grouped convolutions support (#5319)

Signed-off-by: Wei Pan <weip@nvidia.com>
parent a3b13973
...@@ -182,7 +182,8 @@ def conv_output_shape(tensor_format, ...@@ -182,7 +182,8 @@ def conv_output_shape(tensor_format,
x_shape, x_shape,
w_shape, w_shape,
data_dtype, data_dtype,
conv_dtype): conv_dtype,
groups=1):
"""Get output shape of 2D or 3D convolution """Get output shape of 2D or 3D convolution
Paramters Paramters
...@@ -205,6 +206,8 @@ def conv_output_shape(tensor_format, ...@@ -205,6 +206,8 @@ def conv_output_shape(tensor_format,
data type data type
conv_dtype: str conv_dtype: str
convolution type convolution type
groups: int
number of groups
Returns Returns
------- -------
...@@ -228,7 +231,8 @@ def conv_output_shape(tensor_format, ...@@ -228,7 +231,8 @@ def conv_output_shape(tensor_format,
_get_np_int32_array_handle(wshape), _get_np_int32_array_handle(wshape),
_get_np_int32_array_handle(oshape), _get_np_int32_array_handle(oshape),
data_dtype, data_dtype,
conv_dtype) conv_dtype,
groups)
return list(oshape) return list(oshape)
...@@ -240,7 +244,8 @@ def conv_find_algo(tensor_format, ...@@ -240,7 +244,8 @@ def conv_find_algo(tensor_format,
w_shape, w_shape,
y_shape, y_shape,
data_dtype, data_dtype,
conv_dtype): conv_dtype,
groups=1):
"""Choose the best algo for the given input. """Choose the best algo for the given input.
Paramters Paramters
...@@ -265,6 +270,8 @@ def conv_find_algo(tensor_format, ...@@ -265,6 +270,8 @@ def conv_find_algo(tensor_format,
data type data type
conv_dtype: str conv_dtype: str
convolution type convolution type
groups: int
number of groups
Returns Returns
------- -------
...@@ -287,7 +294,8 @@ def conv_find_algo(tensor_format, ...@@ -287,7 +294,8 @@ def conv_find_algo(tensor_format,
_get_np_int32_array_handle(wshape), _get_np_int32_array_handle(wshape),
_get_np_int32_array_handle(yshape), _get_np_int32_array_handle(yshape),
data_dtype, data_dtype,
conv_dtype) conv_dtype,
groups)
def conv_forward(x, def conv_forward(x,
...@@ -298,7 +306,8 @@ def conv_forward(x, ...@@ -298,7 +306,8 @@ def conv_forward(x,
conv_mode, conv_mode,
tensor_format, tensor_format,
algo, algo,
conv_dtype): conv_dtype,
groups=1):
"""Create an extern op that compute 2D or 3D convolution with CuDNN """Create an extern op that compute 2D or 3D convolution with CuDNN
Parameters Parameters
...@@ -325,6 +334,8 @@ def conv_forward(x, ...@@ -325,6 +334,8 @@ def conv_forward(x,
if algo == -1, the best algo will be chosen by CUDNN if algo == -1, the best algo will be chosen by CUDNN
conv_dtype: str conv_dtype: str
convolution type convolution type
groups: int
the number of groups
Returns Returns
------- -------
...@@ -335,8 +346,7 @@ def conv_forward(x, ...@@ -335,8 +346,7 @@ def conv_forward(x,
assert dims in (4, 5) assert dims in (4, 5)
conv_dtype = x.dtype if conv_dtype is None else conv_dtype conv_dtype = x.dtype if conv_dtype is None else conv_dtype
pad, stride, dilation, _, _ = \ pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)
_prepare_global_func_params(dims - 2, pad, stride, dilation)
oshape = conv_output_shape(tensor_format, oshape = conv_output_shape(tensor_format,
pad, pad,
...@@ -345,7 +355,8 @@ def conv_forward(x, ...@@ -345,7 +355,8 @@ def conv_forward(x,
list(x.shape), list(x.shape),
list(w.shape), list(w.shape),
x.dtype, x.dtype,
conv_dtype) conv_dtype,
groups)
if algo == -1: if algo == -1:
# For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
# using INT8 data type, CuDNN will crash down. # using INT8 data type, CuDNN will crash down.
...@@ -361,7 +372,8 @@ def conv_forward(x, ...@@ -361,7 +372,8 @@ def conv_forward(x,
list(w.shape), list(w.shape),
oshape, oshape,
x.dtype, x.dtype,
conv_dtype) conv_dtype,
groups)
if dims == 4: if dims == 4:
return te.extern( return te.extern(
...@@ -380,7 +392,8 @@ def conv_forward(x, ...@@ -380,7 +392,8 @@ def conv_forward(x,
ins[0], ins[0],
ins[1], ins[1],
outs[0], outs[0],
conv_dtype), name="y") conv_dtype,
groups), name="y")
return te.extern( return te.extern(
oshape, [x, w], oshape, [x, w],
...@@ -401,7 +414,8 @@ def conv_forward(x, ...@@ -401,7 +414,8 @@ def conv_forward(x,
ins[0], ins[0],
ins[1], ins[1],
outs[0], outs[0],
conv_dtype), name="y") conv_dtype,
groups), name="y")
def softmax(x, axis=-1): def softmax(x, axis=-1):
"""Compute softmax using CuDNN """Compute softmax using CuDNN
......
...@@ -161,7 +161,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ...@@ -161,7 +161,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \ if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]: padding[1] == padding[3]:
strategy.add_implementation( strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True), wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
need_data_layout=True,
has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn), wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
name="conv2d_cudnn.cuda", name="conv2d_cudnn.cuda",
plevel=15) plevel=15)
...@@ -181,6 +183,20 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ...@@ -181,6 +183,20 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
else: else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d else: # group_conv2d
# add cudnn implementation, if any
cudnn_impl = False
if target.target_name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
need_data_layout=True,
has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
name="conv2d_cudnn.cuda",
plevel=15)
cudnn_impl = True
if layout == 'NCHW': if layout == 'NCHW':
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8. # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW" assert kernel_layout == "OIHW"
...@@ -194,7 +210,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ...@@ -194,7 +210,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True), wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8), wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
name="group_conv2d_NCHWc_int8.cuda") name="group_conv2d_NCHWc_int8.cuda")
else: elif not cudnn_impl:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy return strategy
......
...@@ -35,6 +35,7 @@ void ConvolutionForward( ...@@ -35,6 +35,7 @@ void ConvolutionForward(
int format, int format,
int algo, int algo,
int dims, int dims,
int groups,
const int pad[], const int pad[],
const int stride[], const int stride[],
const int dilation[], const int dilation[],
...@@ -62,8 +63,10 @@ void ConvolutionForward( ...@@ -62,8 +63,10 @@ void ConvolutionForward(
// Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
// in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
if (dims == 2) { if (dims == 2) {
// Set Desc // Set Desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad[0], pad[0],
pad[1], pad[1],
...@@ -183,6 +186,7 @@ void ConvolutionForward( ...@@ -183,6 +186,7 @@ void ConvolutionForward(
void OutputShape( void OutputShape(
int format, int format,
int dims, int dims,
int groups,
const int pad[], const int pad[],
const int stride[], const int stride[],
const int dilation[], const int dilation[],
...@@ -202,6 +206,7 @@ void OutputShape( ...@@ -202,6 +206,7 @@ void OutputShape(
int full_dims = dims + 2; int full_dims = dims + 2;
// conv desc // conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
dims, dims,
pad, pad,
...@@ -240,6 +245,7 @@ void OutputShape( ...@@ -240,6 +245,7 @@ void OutputShape(
// Set Input // Set Input
std::vector<int> tensor_stride(full_dims); std::vector<int> tensor_stride(full_dims);
GetCudnnStride(full_dims, x_dim, tensor_stride.data()); GetCudnnStride(full_dims, x_dim, tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
data_type, data_type,
full_dims, full_dims,
...@@ -264,6 +270,7 @@ void OutputShape( ...@@ -264,6 +270,7 @@ void OutputShape(
void FindAlgo( void FindAlgo(
int format, int format,
int dims, int dims,
int groups,
const int pad[], const int pad[],
const int stride[], const int stride[],
const int dilation[], const int dilation[],
...@@ -284,6 +291,7 @@ void FindAlgo( ...@@ -284,6 +291,7 @@ void FindAlgo(
int full_dims = dims + 2; int full_dims = dims + 2;
// conv desc // conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
dims, dims,
pad, pad,
...@@ -360,16 +368,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") ...@@ -360,16 +368,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
int algo = args[2]; int algo = args[2];
int pad_v[2], stride_v[2], dilation_v[2]; int pad_v[2], stride_v[2], dilation_v[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
pad_v[i] = args[3 + i]; pad_v[i] = args[3 + i];
stride_v[i] = args[5 + i]; stride_v[i] = args[5 + i];
dilation_v[i] = args[7 + i]; dilation_v[i] = args[7 + i];
} }
DLTensor* x = args[9]; DLTensor* x = args[9];
DLTensor* w = args[10]; DLTensor* w = args[10];
DLTensor* y = args[11]; DLTensor* y = args[11];
std::string conv_dtype = args[12]; std::string conv_dtype = args[12];
int groups = args[13];
ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype); ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v,
dilation_v, x, w, y, conv_dtype);
}); });
...@@ -380,17 +390,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") ...@@ -380,17 +390,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
int algo = args[2]; int algo = args[2];
int pad_v[3], stride_v[3], dilation_v[3]; int pad_v[3], stride_v[3], dilation_v[3];
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
pad_v[i] = args[3 + i]; pad_v[i] = args[3 + i];
stride_v[i] = args[6 + i]; stride_v[i] = args[6 + i];
dilation_v[i] = args[9 + i]; dilation_v[i] = args[9 + i];
} }
DLTensor *x = args[12]; DLTensor *x = args[12];
DLTensor *w = args[13]; DLTensor *w = args[13];
DLTensor *y = args[14]; DLTensor *y = args[14];
std::string conv_dtype = args[15]; std::string conv_dtype = args[15];
int groups = args[16];
ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y, ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v,
conv_dtype); dilation_v, x, w, y, conv_dtype);
}); });
...@@ -406,8 +417,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") ...@@ -406,8 +417,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape")
void* out_shape = args[7]; void* out_shape = args[7];
std::string data_dtype = args[8]; std::string data_dtype = args[8];
std::string conv_dtype = args[9]; std::string conv_dtype = args[9];
int groups = args[10];
OutputShape(format, dims, pad, stride, dilation, x_dim, OutputShape(format, dims, groups, pad, stride, dilation, x_dim,
w_dim, out_shape, data_dtype, conv_dtype); w_dim, out_shape, data_dtype, conv_dtype);
}); });
...@@ -424,8 +436,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") ...@@ -424,8 +436,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
int* y_dim = static_cast<int*>(static_cast<void*>(args[7])); int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
std::string data_dtype = args[8]; std::string data_dtype = args[8];
std::string conv_dtype = args[9]; std::string conv_dtype = args[9];
int groups = args[10];
FindAlgo(format, dims, pad, stride, dilation, x_dim, FindAlgo(format, dims, groups, pad, stride, dilation, x_dim,
w_dim, y_dim, data_dtype, conv_dtype, ret); w_dim, y_dim, data_dtype, conv_dtype, ret);
}); });
......
...@@ -78,7 +78,6 @@ struct ConvEntry { ...@@ -78,7 +78,6 @@ struct ConvEntry {
runtime::DeviceAPI *cuda_api; runtime::DeviceAPI *cuda_api;
void *workspace{nullptr}; void *workspace{nullptr};
size_t workspace_size{0}; size_t workspace_size{0};
int group_count {0};
ConvEntry(); ConvEntry();
~ConvEntry(); ~ConvEntry();
void UpdateWorkspace(const size_t wsize); void UpdateWorkspace(const size_t wsize);
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
import tvm import tvm
from tvm import te from tvm import te
from tvm.contrib import cudnn from tvm.contrib import cudnn
from tvm.contrib.nvcc import have_fp16
import numpy as np import numpy as np
import topi.testing import topi.testing
def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
in_channel = 4 in_channel = 4
out_channel = 16 out_channel = 16
filter_h = 3 filter_h = 3
...@@ -34,7 +34,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): ...@@ -34,7 +34,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
dilation_w = 1 dilation_w = 1
batch = 3 batch = 3
height = 32 height = 32
weight = 32 width = 32
if not tvm.runtime.enabled("cuda"): if not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled...") print("skip because cuda is not enabled...")
...@@ -42,12 +42,17 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): ...@@ -42,12 +42,17 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
print("skip because cudnn is not enabled...") print("skip because cudnn is not enabled...")
return return
if data_dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
# schedule
if tensor_format == 0: if tensor_format == 0:
xshape = [batch, in_channel, height, weight] xshape = [batch, in_channel, height, width]
wshape = [out_channel, in_channel, filter_h, filter_w] wshape = [out_channel, in_channel // groups, filter_h, filter_w]
else: else:
xshape = [batch, height, weight, in_channel] xshape = [batch, height, width, in_channel]
wshape = [out_channel, filter_h, filter_w, in_channel] wshape = [out_channel, filter_h, filter_w, in_channel // groups]
X = te.placeholder(xshape, name='X', dtype=data_dtype) X = te.placeholder(xshape, name='X', dtype=data_dtype)
W = te.placeholder(wshape, name='W', dtype=data_dtype) W = te.placeholder(wshape, name='W', dtype=data_dtype)
...@@ -59,39 +64,41 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): ...@@ -59,39 +64,41 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0):
conv_mode=1, conv_mode=1,
tensor_format=tensor_format, tensor_format=tensor_format,
conv_dtype=conv_dtype, conv_dtype=conv_dtype,
algo=-1) algo=-1,
groups=groups)
yshape = [x.value for x in Y.shape] yshape = [x.value for x in Y.shape]
s = te.create_schedule(Y.op) s = te.create_schedule(Y.op)
def verify(): # validation
ctx = tvm.gpu(0) ctx = tvm.gpu(0)
f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d")
x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
y_np = np.zeros(yshape).astype(data_dtype) y_np = np.zeros(yshape).astype(data_dtype)
x = tvm.nd.array(x_np, ctx) x = tvm.nd.array(x_np, ctx)
w = tvm.nd.array(w_np, ctx) w = tvm.nd.array(w_np, ctx)
y = tvm.nd.array(y_np, ctx) y = tvm.nd.array(y_np, ctx)
if tensor_format == 0: if tensor_format == 0:
c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1) c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1, groups=groups)
elif tensor_format == 1: elif tensor_format == 1:
wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO
c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1) c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1, groups=groups)
f(x, w, y) f(x, w, y)
tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-3) tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-2, rtol=1e-2)
verify()
def test_conv2d(): def test_conv2d():
verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float32", "float32", tensor_format=0)
verify_conv2d("float16", "float32", tensor_format=1) verify_conv2d("float16", "float32", tensor_format=1)
#Not pass accuracy test, need check verify_conv2d("float16", "float16", tensor_format=0)
#verify_conv2d("float16", "float16", tensor_format=0)
verify_conv2d("int8", "int32", tensor_format=1) verify_conv2d("int8", "int32", tensor_format=1)
verify_conv2d("float32", "float32", tensor_format=0, groups=2)
verify_conv2d("float16", "float32", tensor_format=1, groups=2)
verify_conv2d("float16", "float16", tensor_format=0, groups=2)
verify_conv2d("int8", "int32", tensor_format=1, groups=2)
def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
in_channel = 4 in_channel = 4
out_channel = 16 out_channel = 16
filter_d = 3 filter_d = 3
...@@ -109,7 +116,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): ...@@ -109,7 +116,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
batch = 3 batch = 3
depth = 32 depth = 32
height = 32 height = 32
weight = 32 width = 32
if not tvm.runtime.enabled("cuda"): if not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled...") print("skip because cuda is not enabled...")
...@@ -118,8 +125,9 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): ...@@ -118,8 +125,9 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
print("skip because cudnn is not enabled...") print("skip because cudnn is not enabled...")
return return
xshape = [batch, in_channel, depth, height, weight] # schedule
wshape = [out_channel, in_channel, filter_d, filter_h, filter_w] xshape = [batch, in_channel, depth, height, width]
wshape = [out_channel, in_channel // groups, filter_d, filter_h, filter_w]
X = te.placeholder(xshape, name='X', dtype=data_dtype) X = te.placeholder(xshape, name='X', dtype=data_dtype)
W = te.placeholder(wshape, name='W', dtype=data_dtype) W = te.placeholder(wshape, name='W', dtype=data_dtype)
...@@ -131,33 +139,31 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): ...@@ -131,33 +139,31 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0):
conv_mode=1, conv_mode=1,
tensor_format=tensor_format, tensor_format=tensor_format,
algo=-1, algo=-1,
conv_dtype=conv_dtype) conv_dtype=conv_dtype,
groups=groups)
yshape = [x.value for x in Y.shape] yshape = [x.value for x in Y.shape]
s = te.create_schedule(Y.op) s = te.create_schedule(Y.op)
def verify(): # validation
ctx = tvm.gpu(0) ctx = tvm.gpu(0)
f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d") f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d")
x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
y_np = np.zeros(yshape).astype(data_dtype) y_np = np.zeros(yshape).astype(data_dtype)
x = tvm.nd.array(x_np, ctx) x = tvm.nd.array(x_np, ctx)
w = tvm.nd.array(w_np, ctx) w = tvm.nd.array(w_np, ctx)
y = tvm.nd.array(y_np, ctx) y = tvm.nd.array(y_np, ctx)
if tensor_format == 0: if tensor_format == 0:
c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1) c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1, groups)
else: else:
raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)") raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)")
f(x, w, y)
tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4)
verify()
f(x, w, y)
tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4)
def test_conv3d(): def test_conv3d():
verify_conv3d("float32", "float32", tensor_format=0) verify_conv3d("float32", "float32", tensor_format=0)
verify_conv3d("float32", "float32", tensor_format=0, groups=2)
def verify_softmax(shape, axis, dtype="float32"): def verify_softmax(shape, axis, dtype="float32"):
A = te.placeholder(shape, dtype=dtype, name='A') A = te.placeholder(shape, dtype=dtype, name='A')
......
...@@ -66,8 +66,8 @@ def schedule_conv2d_nhwc(cfg, outs): ...@@ -66,8 +66,8 @@ def schedule_conv2d_nhwc(cfg, outs):
@autotvm.register_topi_compute("conv2d_cudnn.cuda") @autotvm.register_topi_compute("conv2d_cudnn.cuda")
def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1,
out_dtype='float32'): layout='NCHW', out_dtype='float32'):
"""Compute conv2d using CuDNN library""" """Compute conv2d using CuDNN library"""
if layout == 'NCHW': if layout == 'NCHW':
tensor_format = 0 # CUDNN_TENSOR_NCHW tensor_format = 0 # CUDNN_TENSOR_NCHW
...@@ -89,7 +89,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ...@@ -89,7 +89,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW',
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
OH = (H + pt + pb - KH) // stride_h + 1 OH = (H + pt + pb - KH) // stride_h + 1
OW = (W + pl + pr - KW) // stride_w + 1 OW = (W + pl + pr - KW) // stride_w + 1
cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \ cfg.add_flop(groups * 2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \
((KW - 1) * dilation_w + 1)) ((KW - 1) * dilation_w + 1))
if data.dtype == "int8" or kernel.dtype == "int8": if data.dtype == "int8" or kernel.dtype == "int8":
...@@ -107,7 +107,8 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ...@@ -107,7 +107,8 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW',
conv_mode=1, conv_mode=1,
tensor_format=tensor_format, tensor_format=tensor_format,
algo=-1, # let CUDNN choose the best algo algo=-1, # let CUDNN choose the best algo
conv_dtype=dtype) conv_dtype=dtype,
groups=groups)
@autotvm.register_topi_schedule("conv2d_cudnn.cuda") @autotvm.register_topi_schedule("conv2d_cudnn.cuda")
......
...@@ -21,7 +21,7 @@ import scipy.signal ...@@ -21,7 +21,7 @@ import scipy.signal
from topi.nn.util import get_pad_tuple from topi.nn.util import get_pad_tuple
def conv2d_nhwc_python(a_np, w_np, stride, padding): def _conv2d_nhwc_python(a_np, w_np, stride, padding):
"""Convolution operator in NHWC layout. """Convolution operator in NHWC layout.
Parameters Parameters
...@@ -77,3 +77,38 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): ...@@ -77,3 +77,38 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding):
apad, np.rot90(np.rot90(wt[f, c])), mode='valid') apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
bt[n, f] += out[::stride_h, ::stride_w] bt[n, f] += out[::stride_h, ::stride_w]
return bt.transpose((0, 2, 3, 1)) return bt.transpose((0, 2, 3, 1))
def conv2d_nhwc_python(a_np, w_np, stride, padding, groups=1):
"""Convolution operator in NHWC layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]
w_np : numpy.ndarray
4-D with shape [filter_height, filter_width, in_channel // groups, num_filter]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str or a list/tuple of 2 or 4 ints
Padding size, or ['VALID', 'SAME'], or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 2 ints
groups : int
Number of groups
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_height, out_width, out_channel]
"""
a_slices = np.array_split(a_np, groups, axis=3)
w_slices = np.array_split(w_np, groups, axis=3)
b_slices = [_conv2d_nhwc_python(a_slice, w_slice, stride, padding)
for a_slice, w_slice in zip(a_slices, w_slices)]
b_np = np.concatenate(b_slices, axis=3)
return b_np
...@@ -73,6 +73,7 @@ def conv3d_ncdhw_python(a_np, w_np, stride, padding, groups=1): ...@@ -73,6 +73,7 @@ def conv3d_ncdhw_python(a_np, w_np, stride, padding, groups=1):
padding : int or str or a list/tuple of three ints padding : int or str or a list/tuple of three ints
Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width] Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width]
groups : int groups : int
Number of groups Number of groups
......
...@@ -75,7 +75,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -75,7 +75,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
with tvm.target.create(device): with tvm.target.create(device):
if "cudnn" in device: if "cudnn" in device:
C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), "NCHW", dtype) C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype)
else: else:
C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
if add_bias: if add_bias:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment