Commit 66fa0c3d by masahi Committed by Tianqi Chen

Let CUDNN choose the best algo (#734)

* use cudnn findalgo to choose the best algo

* fix lint
parent f0cdb50e
...@@ -220,6 +220,70 @@ def conv2d_output_shape(tensor_format, ...@@ -220,6 +220,70 @@ def conv2d_output_shape(tensor_format,
return list(oshape) return list(oshape)
def conv2d_find_algo(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
x_shape,
w_shape,
y_shape):
"""Choose the best algo for the given input.
Paramters
---------
tensor_format: int
0: CUDNN_TENSOR_NCHW
1: CUDNN_TENSOR_NHWC
2: CUDNN_TENSOR_NCHW_VECT_C
pad_h: int
height pad
pad_w: int
weight pad
stride_h: int
height stride
stride_w: int
width stride
dilation_h: int
height dilation
dilation_w: int
width dilation
x_shape: list
input shape
w_shape: list
weight shape
y_shape: list
output shape
Returns
-------
algo: int
algo chosen by CUDNN
"""
func = _get_global_func("tvm.contrib.cudnn.conv2d.find_algo")
return func(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
x_shape[0].value,
x_shape[1].value,
x_shape[2].value,
x_shape[3].value,
w_shape[0].value,
w_shape[1].value,
w_shape[2].value,
w_shape[3].value,
y_shape[0],
y_shape[1],
y_shape[2],
y_shape[3])
def conv2d_forward(x, def conv2d_forward(x,
w, w,
stride_h=1, stride_h=1,
...@@ -230,7 +294,7 @@ def conv2d_forward(x, ...@@ -230,7 +294,7 @@ def conv2d_forward(x,
dilation_w=1, dilation_w=1,
conv_mode=1, conv_mode=1,
tensor_format=0, tensor_format=0,
algo=0): algo=-1):
"""Create an extern op that compute 2D convolution with CuDNN """Create an extern op that compute 2D convolution with CuDNN
Parameters Parameters
...@@ -260,6 +324,7 @@ def conv2d_forward(x, ...@@ -260,6 +324,7 @@ def conv2d_forward(x,
2: CUDNN_TENSOR_NCHW_VECT_C 2: CUDNN_TENSOR_NCHW_VECT_C
algo: int algo: int
Forward algorithm, get index from ```algo_to_index``` function Forward algorithm, get index from ```algo_to_index``` function
if algo == -1, the best algo will be chosen by CUDNN
Returns Returns
------- -------
...@@ -275,6 +340,18 @@ def conv2d_forward(x, ...@@ -275,6 +340,18 @@ def conv2d_forward(x,
dilation_w, dilation_w,
list(x.shape), list(x.shape),
list(w.shape)) list(w.shape))
if algo == -1:
algo = conv2d_find_algo(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape),
oshape)
return _api.extern( return _api.extern(
oshape, [x, w], oshape, [x, w],
lambda ins, outs: _intrin.call_packed( lambda ins, outs: _intrin.call_packed(
......
...@@ -153,7 +153,103 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape") ...@@ -153,7 +153,103 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
static_cast<int*>(out_shape) + 1, static_cast<int*>(out_shape) + 1,
static_cast<int*>(out_shape) + 2, static_cast<int*>(out_shape) + 2,
static_cast<int*>(out_shape) + 3)); static_cast<int*>(out_shape) + 3));
}); });
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
int format = args[0];
int pad_h = args[1];
int pad_w = args[2];
int stride_h = args[3];
int stride_w = args[4];
int dilation_h = args[5];
int dilation_w = args[6];
int x_dim0 = args[7];
int x_dim1 = args[8];
int x_dim2 = args[9];
int x_dim3 = args[10];
int w_dim0 = args[11];
int w_dim1 = args[12];
int w_dim2 = args[13];
int w_dim3 = args[14];
int y_dim0 = args[15];
int y_dim1 = args[16];
int y_dim2 = args[17];
int y_dim3 = args[18];
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// conv desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));
// input desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format,
CUDNN_DATA_FLOAT,
x_dim0,
x_dim1,
x_dim2,
x_dim3));
// filter desc
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
w_dim0,
w_dim1,
w_dim2,
w_dim3));
// output desc
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.tensor_format,
entry_ptr->conv_entry.data_type,
y_dim0,
y_dim1,
y_dim2,
y_dim3));
int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle,
entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.output_desc,
CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
&returned_algo_count,
perf_results));
const std::vector<std::string> fwd_algo_names{
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
};
auto best_algo = perf_results[0].algo;
LOG(INFO) << "\tCUDNN Found " << returned_algo_count
<< " fwd algorithms, choosing " << fwd_algo_names[best_algo];
for (int i = 0; i < returned_algo_count; ++i) {
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
<< " - time: " << perf_results[i].time << " ms"
<< ", Memory: " << perf_results[i].memory;
}
ret[0] = best_algo;
});
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
...@@ -56,7 +56,7 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 ...@@ -56,7 +56,7 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
1, # dilation_w 1, # dilation_w
conv_mode=1, conv_mode=1,
tensor_format=tensor_format, tensor_format=tensor_format,
algo=0) algo=-1) # let CUDNN choose the best algo
elif layout == 'NCHW': elif layout == 'NCHW':
return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype) return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
elif layout == 'HWCN': elif layout == 'HWCN':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment