From 2e17e850052f40891e01fb0ccc561df760cf5ed5 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Sat, 7 Apr 2018 01:11:48 +0800
Subject: [PATCH] add query for shared memory size (#1083)

---
 include/tvm/runtime/device_api.h    | 3 ++-
 python/tvm/_ffi/runtime_ctypes.py   | 8 +++++++-
 src/runtime/cuda/cuda_device_api.cc | 5 +++++
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index ff3fe8062..f21be5670 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -19,7 +19,8 @@ enum DeviceAttrKind : int {
   kExist = 0,
   kMaxThreadsPerBlock = 1,
   kWarpSize = 2,
-  kComputeVersion = 3,
+  kMaxSharedMemoryPerBlock = 3,
+  kComputeVersion = 4,
 };
 
 /*! \brief Number of bytes each allocation must align to */
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 23af85bf2..2fdde2380 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -140,6 +140,12 @@ class TVMContext(ctypes.Structure):
         return _api_internal._GetDeviceAttr(
             self.device_type, self.device_id, 2)
 
+    @property
+    def max_shared_memory_per_block(self):
+        """Maximum amount of shared memory available per block, in bytes."""
+        return _api_internal._GetDeviceAttr(
+            self.device_type, self.device_id, 3)
+
     @property
     def compute_version(self):
         """Get compute verison number in string.
@@ -152,7 +158,7 @@ class TVMContext(ctypes.Structure):
             The version string in `major.minor` format.
         """
         return _api_internal._GetDeviceAttr(
-            self.device_type, self.device_id, 3)
+            self.device_type, self.device_id, 4)
 
     def sync(self):
         """Synchronize until jobs finished at the context."""
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 9b0fee502..8de06d902 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -40,6 +40,11 @@ class CUDADeviceAPI final : public DeviceAPI {
             &value, cudaDevAttrWarpSize, ctx.device_id));
         break;
       }
+      case kMaxSharedMemoryPerBlock: {
+        CUDA_CALL(cudaDeviceGetAttribute(
+            &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
+        break;
+      }
       case kComputeVersion: {
         std::ostringstream os;
         CUDA_CALL(cudaDeviceGetAttribute(
-- 
GitLab