diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index ff3fe80629200ad34ab709295b0884f08ab1ed68..f21be567021daa667d6a6508a945cbb6cb7f0cf0 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -19,7 +19,8 @@ enum DeviceAttrKind : int { kExist = 0, kMaxThreadsPerBlock = 1, kWarpSize = 2, - kComputeVersion = 3, + kMaxSharedMemoryPerBlock = 3, + kComputeVersion = 4, }; /*! \brief Number of bytes each allocation must align to */ diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 23af85bf26df965d308dc0402a54f74fa574801f..2fdde23803b3f3cbc11a23aa7c61f0a4f57b1dd3 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -140,6 +140,12 @@ class TVMContext(ctypes.Structure): return _api_internal._GetDeviceAttr( self.device_type, self.device_id, 2) + @property + def max_shared_memory_per_block(self): + """Total amount of shared memory per block in bytes""" + return _api_internal._GetDeviceAttr( + self.device_type, self.device_id, 3) + @property def compute_version(self): """Get compute verison number in string. @@ -152,7 +158,7 @@ class TVMContext(ctypes.Structure): The version string in `major.minor` format. """ return _api_internal._GetDeviceAttr( - self.device_type, self.device_id, 3) + self.device_type, self.device_id, 4) def sync(self): """Synchronize until jobs finished at the context.""" diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 9b0fee5023f182d981462a02de58e31a899da327..8de06d902449ce6a776136bb6dc5f0f9dd1653ca 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -40,6 +40,11 @@ class CUDADeviceAPI final : public DeviceAPI { &value, cudaDevAttrWarpSize, ctx.device_id)); break; } + case kMaxSharedMemoryPerBlock: { + CUDA_CALL(cudaDeviceGetAttribute( + &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); + break; + } case kComputeVersion: { std::ostringstream os; CUDA_CALL(cudaDeviceGetAttribute(