/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file Use external cudnn utils function
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <tvm/runtime/device_api.h>
#include "cudnn_utils.h"

namespace tvm {
namespace contrib {

using namespace runtime;

/*!
 * \brief Translate a cuDNN forward-convolution algorithm id into a printable name.
 *
 * The table is indexed by the numeric value of the enumerator. Ids that fall
 * outside the table (for example, values this file does not know about) yield
 * a placeholder instead of indexing out of bounds.
 *
 * \param algo Algorithm id as reported by cuDNN.
 * \return Pointer to a string literal with static storage duration.
 */
static const char* FwdAlgoName(cudnnConvolutionFwdAlgo_t algo) {
  static const char* const kFwdAlgoNames[] = {
    "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
    "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
    "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
    "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
    "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
    "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
    "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
    "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
  };
  const int num_names = static_cast<int>(sizeof(kFwdAlgoNames) / sizeof(kFwdAlgoNames[0]));
  const int index = static_cast<int>(algo);
  if (index < 0 || index >= num_names) {
    return "CUDNN_CONVOLUTION_FWD_ALGO_UNKNOWN";
  }
  return kFwdAlgoNames[index];
}

/*!
 * \brief Packed function running a cuDNN 2-D forward convolution.
 *
 * Arguments, in order:
 *   0: convolution mode (cast to cudnnConvolutionMode_t)
 *   1: tensor format (cast to cudnnTensorFormat_t)
 *   2: forward algorithm (cast to cudnnConvolutionFwdAlgo_t)
 *   3-4: pad_h, pad_w
 *   5-6: stride_h, stride_w
 *   7-8: dilation_h, dilation_w
 *   9: input tensor x
 *   10: filter tensor w
 *   11: output tensor y
 *
 * The data type is taken from x->dtype. All three tensors must be 4-D because
 * their first four extents are used to fill the cuDNN descriptors.
 * Nothing is written to the return value.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  int mode = args[0];
  int format = args[1];
  int algo = args[2];
  int pad_h = args[3];
  int pad_w = args[4];
  int stride_h = args[5];
  int stride_w = args[6];
  int dilation_h = args[7];
  int dilation_w = args[8];
  DLTensor *x = args[9];
  DLTensor *w = args[10];
  DLTensor *y = args[11];

  // shape[0..3] is read below for every tensor; reject anything that is not 4-D
  // instead of reading past the end of the shape array.
  CHECK_EQ(x->ndim, 4) << "cudnn conv2d forward expects a 4-D input tensor";
  CHECK_EQ(w->ndim, 4) << "cudnn conv2d forward expects a 4-D filter tensor";
  CHECK_EQ(y->ndim, 4) << "cudnn conv2d forward expects a 4-D output tensor";

  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  auto& conv = entry_ptr->conv_entry;

  // Set Mode
  conv.mode = static_cast<cudnnConvolutionMode_t>(mode);
  // Set Format
  conv.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // Set Algo
  conv.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
  // Set Ctx
  conv.ctx = x->ctx;
  // Set Data Type
  conv.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

  // Set Desc
  CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv.conv_desc,
                                             pad_h,
                                             pad_w,
                                             stride_h,
                                             stride_w,
                                             dilation_h,
                                             dilation_w,
                                             conv.mode,
                                             conv.data_type));
  // Set Filter (the filter layout is fixed to NCHW regardless of `format`)
  CUDNN_CALL(cudnnSetFilter4dDescriptor(conv.filter_desc,
                                        conv.data_type,
                                        CUDNN_TENSOR_NCHW,
                                        static_cast<int>(w->shape[0]),
                                        static_cast<int>(w->shape[1]),
                                        static_cast<int>(w->shape[2]),
                                        static_cast<int>(w->shape[3])));
  // Set Input
  CUDNN_CALL(cudnnSetTensor4dDescriptor(conv.input_desc,
                                        conv.tensor_format,
                                        conv.data_type,
                                        static_cast<int>(x->shape[0]),
                                        static_cast<int>(x->shape[1]),
                                        static_cast<int>(x->shape[2]),
                                        static_cast<int>(x->shape[3])));
  // Set Output
  CUDNN_CALL(cudnnSetTensor4dDescriptor(conv.output_desc,
                                        conv.tensor_format,
                                        conv.data_type,
                                        static_cast<int>(y->shape[0]),
                                        static_cast<int>(y->shape[1]),
                                        static_cast<int>(y->shape[2]),
                                        static_cast<int>(y->shape[3])));

  // Set workspace: query the size required by the chosen algorithm, then make
  // sure the thread-local workspace is at least that large.
  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle,
                                                     conv.input_desc,
                                                     conv.filter_desc,
                                                     conv.conv_desc,
                                                     conv.output_desc,
                                                     conv.fwd_algo,
                                                     &workspace_size));
  conv.UpdateWorkspace(workspace_size);

  // y = 1 * conv(x, w) + 0 * y
  CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle,
                                     CuDNNDataType::GetConst<1>(conv.data_type),
                                     conv.input_desc,
                                     x->data,
                                     conv.filter_desc,
                                     w->data,
                                     conv.conv_desc,
                                     conv.fwd_algo,
                                     conv.workspace,
                                     workspace_size,
                                     CuDNNDataType::GetConst<0>(conv.data_type),
                                     conv.output_desc,
                                     y->data));
});


/*!
 * \brief Packed function computing the output shape of a 2-D forward convolution.
 *
 * Arguments, in order:
 *   0: tensor format (cast to cudnnTensorFormat_t)
 *   1-2: pad_h, pad_w
 *   3-4: stride_h, stride_w
 *   5-6: dilation_h, dilation_w
 *   7-10: the four input extents
 *   11-14: the four filter extents
 *   15: pointer to storage for four ints that receives the output extents
 *
 * No dtype is passed in, so every descriptor is built with CUDNN_DATA_FLOAT.
 * Nothing is written to the return value.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  int format = args[0];
  int pad_h = args[1];
  int pad_w = args[2];
  int stride_h = args[3];
  int stride_w = args[4];
  int dilation_h = args[5];
  int dilation_w = args[6];
  int x_dim0 = args[7];
  int x_dim1 = args[8];
  int x_dim2 = args[9];
  int x_dim3 = args[10];
  int w_dim0 = args[11];
  int w_dim1 = args[12];
  int w_dim2 = args[13];
  int w_dim3 = args[14];
  void *out_shape = args[15];
  CHECK(out_shape != nullptr) << "cudnn conv2d output_shape requires a non-null output buffer";
  int* out_dims = static_cast<int*>(out_shape);

  auto& conv = entry_ptr->conv_entry;
  // conv.data_type is only assigned by the forward function, so reading it here
  // would pick up a stale or never-assigned value. Use the same fixed type as
  // the input and filter descriptors below.
  const cudnnDataType_t data_type = CUDNN_DATA_FLOAT;

  // Set Format
  conv.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // conv desc
  CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv.conv_desc,
                                             pad_h,
                                             pad_w,
                                             stride_h,
                                             stride_w,
                                             dilation_h,
                                             dilation_w,
                                             CUDNN_CROSS_CORRELATION,
                                             data_type));
  // input desc
  CUDNN_CALL(cudnnSetTensor4dDescriptor(conv.input_desc,
                                        conv.tensor_format,
                                        data_type,
                                        x_dim0,
                                        x_dim1,
                                        x_dim2,
                                        x_dim3));
  // filter desc
  CUDNN_CALL(cudnnSetFilter4dDescriptor(conv.filter_desc,
                                        data_type,
                                        CUDNN_TENSOR_NCHW,
                                        w_dim0,
                                        w_dim1,
                                        w_dim2,
                                        w_dim3));

  CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(conv.conv_desc,
                                                   conv.input_desc,
                                                   conv.filter_desc,
                                                   out_dims,
                                                   out_dims + 1,
                                                   out_dims + 2,
                                                   out_dims + 3));
});


/*!
 * \brief Packed function asking cuDNN for the best forward-convolution algorithm.
 *
 * Arguments, in order:
 *   0: tensor format (cast to cudnnTensorFormat_t)
 *   1-2: pad_h, pad_w
 *   3-4: stride_h, stride_w
 *   5-6: dilation_h, dilation_w
 *   7-10: the four input extents
 *   11-14: the four filter extents
 *   15-18: the four output extents
 *
 * No dtype is passed in, so every descriptor is built with CUDNN_DATA_FLOAT.
 * Every candidate reported by cuDNN is logged; the first entry of the result
 * list is returned as an integer algorithm id.
 */
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  int format = args[0];
  int pad_h = args[1];
  int pad_w = args[2];
  int stride_h = args[3];
  int stride_w = args[4];
  int dilation_h = args[5];
  int dilation_w = args[6];
  int x_dim0 = args[7];
  int x_dim1 = args[8];
  int x_dim2 = args[9];
  int x_dim3 = args[10];
  int w_dim0 = args[11];
  int w_dim1 = args[12];
  int w_dim2 = args[13];
  int w_dim3 = args[14];
  int y_dim0 = args[15];
  int y_dim1 = args[16];
  int y_dim2 = args[17];
  int y_dim3 = args[18];

  auto& conv = entry_ptr->conv_entry;
  // conv.data_type is only assigned by the forward function, so reading it here
  // would pick up a stale or never-assigned value and could disagree with the
  // float input/filter descriptors. Use one fixed type for all descriptors.
  const cudnnDataType_t data_type = CUDNN_DATA_FLOAT;

  // Set Format
  conv.tensor_format = static_cast<cudnnTensorFormat_t>(format);
  // conv desc
  CUDNN_CALL(cudnnSetConvolution2dDescriptor(conv.conv_desc,
                                             pad_h,
                                             pad_w,
                                             stride_h,
                                             stride_w,
                                             dilation_h,
                                             dilation_w,
                                             CUDNN_CROSS_CORRELATION,
                                             data_type));
  // input desc
  CUDNN_CALL(cudnnSetTensor4dDescriptor(conv.input_desc,
                                        conv.tensor_format,
                                        data_type,
                                        x_dim0,
                                        x_dim1,
                                        x_dim2,
                                        x_dim3));
  // filter desc
  CUDNN_CALL(cudnnSetFilter4dDescriptor(conv.filter_desc,
                                        data_type,
                                        CUDNN_TENSOR_NCHW,
                                        w_dim0,
                                        w_dim1,
                                        w_dim2,
                                        w_dim3));
  // output desc
  CUDNN_CALL(cudnnSetTensor4dDescriptor(conv.output_desc,
                                        conv.tensor_format,
                                        data_type,
                                        y_dim0,
                                        y_dim1,
                                        y_dim2,
                                        y_dim3));

  int returned_algo_count = 0;
  cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
  CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle,
                                                  conv.input_desc,
                                                  conv.filter_desc,
                                                  conv.conv_desc,
                                                  conv.output_desc,
                                                  CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
                                                  &returned_algo_count,
                                                  perf_results));

  // perf_results[0] is only meaningful when cuDNN filled in at least one entry.
  CHECK_GT(returned_algo_count, 0)
      << "cuDNN did not return any forward convolution algorithm";

  auto best_algo = perf_results[0].algo;
  LOG(INFO) << "\tCUDNN Found " << returned_algo_count
            << " fwd algorithms, choosing " << FwdAlgoName(best_algo);
  for (int i = 0; i < returned_algo_count; ++i) {
    LOG(INFO) << "\t\t" << i << ") " << FwdAlgoName(perf_results[i].algo)
              << " - time: " << perf_results[i].time << " ms"
              << ", Memory: " << perf_results[i].memory;
  }

  *ret = static_cast<int>(best_algo);
});

}  // namespace contrib
}  // namespace tvm