diff --git a/CMakeLists.txt b/CMakeLists.txt index e072ced6fda19af1df638ee6e100e3683c52b4d6..eb52d1f827235f37d789cc3604154a6cdd92fad8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,8 @@ tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) +tvm_option(USE_CUDNN "Build with cuDNN" OFF) + include_directories("include") include_directories("HalideIR/src") include_directories("dlpack/include") @@ -126,6 +128,24 @@ find_library(CUDA_NVRTC_LIBRARIES nvrtc ${CUDA_TOOLKIT_ROOT_DIR}/lib) list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIB}) endif(MSVC) + + if(USE_CUDNN) + message(STATUS "Build with cuDNN support") + file(GLOB CONTRIB_CUDNN_SRCS src/contrib/cudnn/*.cc) + list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_SRCS}) + if(MSVC) + find_library(CUDA_CUDNN_LIB cudnn + ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib/win32) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIB}) + else(MSVC) + find_library(CUDA_CUDNN_LIB cudnn + ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIB}) + endif(MSVC) + endif(USE_CUDNN) + add_definitions(-DTVM_CUDA_RUNTIME=1) else(USE_CUDA) add_definitions(-DTVM_CUDA_RUNTIME=0)