From 66fa0c3d62728d69ee034aa903b2d4079d149180 Mon Sep 17 00:00:00 2001
From: masahi <masahi129@gmail.com>
Date: Fri, 29 Dec 2017 10:53:05 +0900
Subject: [PATCH] Let CUDNN choose the best algo (#734)

* use cudnn findalgo to choose the best algo

* fix lint
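
Example (a minimal usage sketch against the Python API touched by this patch;
the shapes, names, and the "cuda" build target below are illustrative and not
part of the change):

    import tvm
    from tvm.contrib import cudnn

    # Hypothetical NCHW workload: batch 4, 3 input channels, 32x32 images,
    # 32 output channels with a 3x3 kernel.
    X = tvm.placeholder((4, 3, 32, 32), name="X")
    W = tvm.placeholder((32, 3, 3, 3), name="W")

    # algo=-1 (the new default) asks cuDNN to benchmark every forward
    # algorithm for this workload and pick the fastest one.
    Y = cudnn.conv2d_forward(X, W,
                             stride_h=1, stride_w=1,
                             pad_h=1, pad_w=1,
                             dilation_h=1, dilation_w=1,
                             conv_mode=1,
                             tensor_format=0,  # CUDNN_TENSOR_NCHW
                             algo=-1)

    s = tvm.create_schedule(Y.op)
    f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d")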
---
 python/tvm/contrib/cudnn.py       | 79 ++++++++++++++++++++++++-
 src/contrib/cudnn/conv_forward.cc | 98 ++++++++++++++++++++++++++++++-
 topi/python/topi/cuda/conv2d.py   |  2 +-
 3 files changed, 176 insertions(+), 3 deletions(-)

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index e728e42f6..5200f3193 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -220,6 +220,70 @@ def conv2d_output_shape(tensor_format,
     return list(oshape)
 
 
+def conv2d_find_algo(tensor_format,
+                     pad_h,
+                     pad_w,
+                     stride_h,
+                     stride_w,
+                     dilation_h,
+                     dilation_w,
+                     x_shape,
+                     w_shape,
+                     y_shape):
+    """Choose the best algo for the given input.
+
+    Paramters
+    ---------
+    tensor_format: int
+        0: CUDNN_TENSOR_NCHW
+        1: CUDNN_TENSOR_NHWC
+        2: CUDNN_TENSOR_NCHW_VECT_C
+    pad_h: int
+        height pad
+    pad_w: int
+        width pad
+    stride_h: int
+        height stride
+    stride_w: int
+        width stride
+    dilation_h: int
+        height dilation
+    dilation_w: int
+        width dilation
+    x_shape: list
+        input shape
+    w_shape: list
+        weight shape
+    y_shape: list
+        output shape
+
+    Returns
+    -------
+    algo: int
+        forward algorithm chosen by CUDNN
+    """
+    func = _get_global_func("tvm.contrib.cudnn.conv2d.find_algo")
+    return func(tensor_format,
+                pad_h,
+                pad_w,
+                stride_h,
+                stride_w,
+                dilation_h,
+                dilation_w,
+                x_shape[0].value,
+                x_shape[1].value,
+                x_shape[2].value,
+                x_shape[3].value,
+                w_shape[0].value,
+                w_shape[1].value,
+                w_shape[2].value,
+                w_shape[3].value,
+                y_shape[0],
+                y_shape[1],
+                y_shape[2],
+                y_shape[3])
+
+
 def conv2d_forward(x,
                    w,
                    stride_h=1,
@@ -230,7 +294,7 @@ def conv2d_forward(x,
                    dilation_w=1,
                    conv_mode=1,
                    tensor_format=0,
-                   algo=0):
+                   algo=-1):
     """Create an extern op that compute 2D convolution with CuDNN
 
     Parameters
@@ -260,6 +324,7 @@ def conv2d_forward(x,
         2: CUDNN_TENSOR_NCHW_VECT_C
     algo: int
         Forward algorithm, get index from ```algo_to_index``` function
+        If algo == -1, the best algorithm will be chosen by CUDNN
 
     Returns
     -------
@@ -275,6 +340,18 @@ def conv2d_forward(x,
                                  dilation_w,
                                  list(x.shape),
                                  list(w.shape))
+    if algo == -1:
+        algo = conv2d_find_algo(tensor_format,
+                                pad_h,
+                                pad_w,
+                                stride_h,
+                                stride_w,
+                                dilation_h,
+                                dilation_w,
+                                list(x.shape),
+                                list(w.shape),
+                                oshape)
+
     return _api.extern(
         oshape, [x, w],
         lambda ins, outs: _intrin.call_packed(
diff --git a/src/contrib/cudnn/conv_forward.cc b/src/contrib/cudnn/conv_forward.cc
index 480a78930..4cd25f0c2 100644
--- a/src/contrib/cudnn/conv_forward.cc
+++ b/src/contrib/cudnn/conv_forward.cc
@@ -153,7 +153,103 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape")
                                                    static_cast<int*>(out_shape) + 1,
                                                    static_cast<int*>(out_shape) + 2,
                                                    static_cast<int*>(out_shape) + 3));
-  });
+});
+
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  int format = args[0];
+  int pad_h = args[1];
+  int pad_w = args[2];
+  int stride_h = args[3];
+  int stride_w = args[4];
+  int dilation_h = args[5];
+  int dilation_w = args[6];
+  int x_dim0 = args[7];
+  int x_dim1 = args[8];
+  int x_dim2 = args[9];
+  int x_dim3 = args[10];
+  int w_dim0 = args[11];
+  int w_dim1 = args[12];
+  int w_dim2 = args[13];
+  int w_dim3 = args[14];
+  int y_dim0 = args[15];
+  int y_dim1 = args[16];
+  int y_dim2 = args[17];
+  int y_dim3 = args[18];
+
+  // Set Format
+  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
+  // conv desc
+  CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
+                                             pad_h,
+                                             pad_w,
+                                             stride_h,
+                                             stride_w,
+                                             dilation_h,
+                                             dilation_w,
+                                             CUDNN_CROSS_CORRELATION,
+                                             entry_ptr->conv_entry.data_type));
+  // input desc
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
+                                        entry_ptr->conv_entry.tensor_format,
+                                        CUDNN_DATA_FLOAT,
+                                        x_dim0,
+                                        x_dim1,
+                                        x_dim2,
+                                        x_dim3));
+  // filter desc
+  CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc,
+                                        CUDNN_DATA_FLOAT,
+                                        CUDNN_TENSOR_NCHW,
+                                        w_dim0,
+                                        w_dim1,
+                                        w_dim2,
+                                        w_dim3));
+
+  // output desc
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc,
+                                        entry_ptr->conv_entry.tensor_format,
+                                        entry_ptr->conv_entry.data_type,
+                                        y_dim0,
+                                        y_dim1,
+                                        y_dim2,
+                                        y_dim3));
+
+  int returned_algo_count = 0;
+  cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
+  CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle,
+                                                  entry_ptr->conv_entry.input_desc,
+                                                  entry_ptr->conv_entry.filter_desc,
+                                                  entry_ptr->conv_entry.conv_desc,
+                                                  entry_ptr->conv_entry.output_desc,
+                                                  CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
+                                                  &returned_algo_count,
+                                                  perf_results));
+
+  const std::vector<std::string> fwd_algo_names{
+      "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
+      "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
+      "CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
+      "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
+      "CUDNN_CONVOLUTION_FWD_ALGO_FFT",
+      "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
+      "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
+      "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"
+  };
+
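+  // perf_results is returned sorted by execution time, so the first entry
+  // is the fastest algorithm cuDNN found for these shapes and parameters.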
+  auto best_algo = perf_results[0].algo;
+  LOG(INFO) << "\tCUDNN Found " << returned_algo_count
+            << " fwd algorithms, choosing " << fwd_algo_names[best_algo];
+  for (int i = 0; i < returned_algo_count; ++i) {
+    LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
+              << " - time: " << perf_results[i].time << " ms"
+              << ", Memory: " << perf_results[i].memory;
+  }
+
+  ret[0] = best_algo;
+});
 
 }  // namespace contrib
 }  // namespace tvm
diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py
index 62b5642ab..2641bfe49 100644
--- a/topi/python/topi/cuda/conv2d.py
+++ b/topi/python/topi/cuda/conv2d.py
@@ -56,7 +56,7 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
                                     1,  # dilation_w
                                     conv_mode=1,
                                     tensor_format=tensor_format,
-                                    algo=0)
+                                    algo=-1)  # let CUDNN choose the best algo
     elif layout == 'NCHW':
         return topi.nn.conv2d_nchw(data, kernel, stride, padding, out_dtype)
     elif layout == 'HWCN':
-- 
GitLab