From 66fa0c3d62728d69ee034aa903b2d4079d149180 Mon Sep 17 00:00:00 2001
From: masahi <masahi129@gmail.com>
Date: Fri, 29 Dec 2017 10:53:05 +0900
Subject: [PATCH] Let CUDNN choose the best algo (#734)

* use cudnn findalgo to choose the best algo

* fix lint
---
 python/tvm/contrib/cudnn.py       | 79 ++++++++++++++++++++++++-
 src/contrib/cudnn/conv_forward.cc | 98 ++++++++++++++++++++++++++++++-
 topi/python/topi/cuda/conv2d.py   |  2 +-
 3 files changed, 176 insertions(+), 3 deletions(-)

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index e728e42f6..5200f3193 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -220,6 +220,70 @@ def conv2d_output_shape(tensor_format,
     return list(oshape)
 
 
+def conv2d_find_algo(tensor_format,
+                     pad_h,
+                     pad_w,
+                     stride_h,
+                     stride_w,
+                     dilation_h,
+                     dilation_w,
+                     x_shape,
+                     w_shape,
+                     y_shape):
+    """Choose the best algo for the given input.
+
+    Parameters
+    ----------
+    tensor_format: int
+        0: CUDNN_TENSOR_NCHW
+        1: CUDNN_TENSOR_NHWC
+        2: CUDNN_TENSOR_NCHW_VECT_C
+    pad_h: int
+        height pad
+    pad_w: int
+        width pad
+    stride_h: int
+        height stride
+    stride_w: int
+        width stride
+    dilation_h: int
+        height dilation
+    dilation_w: int
+        width dilation
+    x_shape: list
+        input shape
+    w_shape: list
+        weight shape
+    y_shape: list
+        output shape
+
+    Returns
+    -------
+    algo: int
+        algo chosen by CUDNN
+    """
+    func = _get_global_func("tvm.contrib.cudnn.conv2d.find_algo")
+    return func(tensor_format,
+                pad_h,
+                pad_w,
+                stride_h,
+                stride_w,
+                dilation_h,
+                dilation_w,
+                x_shape[0].value,
+                x_shape[1].value,
+                x_shape[2].value,
+                x_shape[3].value,
+                w_shape[0].value,
+                w_shape[1].value,
+                w_shape[2].value,
+                w_shape[3].value,
+                y_shape[0],
+                y_shape[1],
+                y_shape[2],
+                y_shape[3])
+
+
 def conv2d_forward(x,
                    w,
                    stride_h=1,
@@ -230,7 +294,7 @@ def conv2d_forward(x,
                    dilation_w=1,
                    conv_mode=1,
                    tensor_format=0,
-                   algo=0):
+                   algo=-1):
     """Create an extern op that compute 2D convolution with CuDNN
 
     Parameters
@@ -260,6 +324,7 @@ def conv2d_forward(x,
         2: CUDNN_TENSOR_NCHW_VECT_C
     algo: int
         Forward algorithm, get index from ```algo_to_index``` function
+        if algo == -1, the best algo will be chosen by CUDNN
 
     Returns
     -------
@@ -275,6 +340,18 @@ def conv2d_forward(x,
                                  dilation_w,
                                  list(x.shape),
                                  list(w.shape))
+    if algo == -1:
+        algo = conv2d_find_algo(tensor_format,
+                                pad_h,
+                                pad_w,
+                                stride_h,
+                                stride_w,
+                                dilation_h,
+                                dilation_w,
+                                list(x.shape),
+                                list(w.shape),
+                                oshape)
+
     return _api.extern(
         oshape, [x, w],
         lambda ins, outs: _intrin.call_packed(
diff --git a/src/contrib/cudnn/conv_forward.cc b/src/contrib/cudnn/conv_forward.cc
index 480a78930..4cd25f0c2 100644
--- a/src/contrib/cudnn/conv_forward.cc
+++ b/src/contrib/cudnn/conv_forward.cc
@@ -153,7 +153,103 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
                               static_cast<int*>(out_shape) + 1,
                               static_cast<int*>(out_shape) + 2,
                               static_cast<int*>(out_shape) + 3));
-  });
+});
+
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  int format = args[0];
+  int pad_h = args[1];
+  int pad_w = args[2];
+  int stride_h = args[3];
+  int stride_w = args[4];
+  int dilation_h = args[5];
+  int dilation_w = args[6];
+  int x_dim0 = args[7];
+  int x_dim1 = args[8];
+  int x_dim2 = args[9];
+  int x_dim3 = args[10];
+  int w_dim0 = args[11];
+  int w_dim1 = args[12];
+  int w_dim2 = args[13];
+  int w_dim3 = args[14];
+  int y_dim0 = args[15];
+  int y_dim1 = args[16];
+  int y_dim2 = args[17];
+  int y_dim3 = args[18];
+
+  // Set Format
+  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
+  // conv desc
+  CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
+                                             pad_h,
+                                             pad_w,
+                                             stride_h,
+                                             stride_w,
+                                             dilation_h,
+                                             dilation_w,
+                                             CUDNN_CROSS_CORRELATION,
+                                             entry_ptr->conv_entry.data_type));
+  // input desc
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
+                                        entry_ptr->conv_entry.tensor_format,
+                                        CUDNN_DATA_FLOAT,
+                                        x_dim0,
+                                        x_dim1,
+                                        x_dim2,
+                                        x_dim3));
+  // filter desc
+  CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
+                                        CUDNN_DATA_FLOAT,
+                                        CUDNN_TENSOR_NCHW,
+                                        w_dim0,
+                                        w_dim1,
+                                        w_dim2,
+                                        w_dim3));
+
+  // output desc
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
+                                        entry_ptr->conv_entry.tensor_format,
+                                        entry_ptr->conv_entry.data_type,
+                                        y_dim0,
+                                        y_dim1,
+                                        y_dim2,
+                                        y_dim3));
+
+  int returned_algo_count = 0;
+  cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
+  CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle,
+                                                  entry_ptr->conv_entry.input_desc,
+                                                  entry_ptr->conv_entry.filter_desc,
+                                                  entry_ptr->conv_entry.conv_desc,
+                                                  entry_ptr->conv_entry.output_desc,
+                                                  CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
+                                                  &returned_algo_count,
+                                                  perf_results));
+
+  const std::vector<std::string> fwd_algo_names{
+      "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
+      "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
+      "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
+      "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
+      "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
+      "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
+      "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
+      "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
+  };
+
+  auto best_algo = perf_results[0].algo;
+  LOG(INFO) << "\tCUDNN Found " << returned_algo_count
+            << " fwd algorithms, choosing " << fwd_algo_names[best_algo];
+  for (int i = 0; i < returned_algo_count; ++i) {
+    LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
+              << " - time: " << perf_results[i].time << " ms"
+              << ", Memory: " << perf_results[i].memory;
+  }
+
+  ret[0] = best_algo;
+});
 
 }  // namespace contrib
 }  // namespace tvm
diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py
index 62b5642ab..2641bfe49 100644
--- a/topi/python/topi/cuda/conv2d.py
+++ b/topi/python/topi/cuda/conv2d.py
@@ -56,7 +56,7 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
                               1,  # dilation_w
                               conv_mode=1,
                               tensor_format=tensor_format,
-                              algo=0)
+                              algo=-1)  # let CUDNN choose the best algo
     elif layout == 'NCHW':
         return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
     elif layout == 'HWCN':
--
GitLab
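
Usage note (editor's addition, not part of the patch): with algo=-1 now the
default, conv2d_forward first calls conv2d_find_algo, which invokes the new
"tvm.contrib.cudnn.conv2d.find_algo" packed function. On the C++ side,
cudnnFindConvolutionForwardAlgorithm actually executes each forward algorithm
on the device and returns the results sorted by runtime, which is why
perf_results[0].algo is the best choice for the given shapes. Below is a
minimal sketch of driving the new path from Python; the placeholder shapes
and the exact keyword arguments are illustrative assumptions, not taken from
the patch itself.

    import tvm
    from tvm.contrib import cudnn

    # NCHW input and OIHW filter; the concrete shapes are illustrative.
    x = tvm.placeholder((1, 3, 224, 224), name='x')
    w = tvm.placeholder((32, 3, 3, 3), name='w')

    # algo=-1 (the new default) makes conv2d_forward benchmark all cuDNN
    # forward algorithms for these shapes and pick the fastest one before
    # the extern op is created.
    y = cudnn.conv2d_forward(x, w,
                             stride_h=1, stride_w=1,
                             pad_h=1, pad_w=1,
                             dilation_h=1, dilation_w=1,
                             conv_mode=1,       # cross-correlation
                             tensor_format=0,   # CUDNN_TENSOR_NCHW
                             algo=-1)           # let cuDNN choose

    # The result is an ordinary extern tensor; schedule and build as usual,
    # e.g. s = tvm.create_schedule(y.op) followed by tvm.build for "cuda".

Because selection happens once, when the operator is constructed, the
benchmarking cost is not paid on every forward call; the trade-off is a
slower first build, since cuDNN times all eight forward algorithms listed
in fwd_algo_names on the device.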