From c7e7e7f5e60cadc4be465497980ae3a0395dbbb6 Mon Sep 17 00:00:00 2001
From: Yida Wang <yidawa@gmail.com>
Date: Fri, 20 Apr 2018 20:39:31 -0700
Subject: [PATCH] add two more device properties (#1124)

---
 include/tvm/runtime/device_api.h        |  4 +++-
 python/tvm/_ffi/runtime_ctypes.py       | 12 ++++++++++++
 src/runtime/cuda/cuda_device_api.cc     | 10 ++++++++++
 src/runtime/metal/metal_device_api.mm   |  2 ++
 src/runtime/opencl/opencl_device_api.cc | 21 +++++++++++++++++++++
 src/runtime/opencl/opencl_module.cc     |  2 +-
 src/runtime/opengl/opengl_device_api.cc |  2 ++
 src/runtime/rocm/rocm_device_api.cc     |  2 ++
 src/runtime/vulkan/vulkan_device_api.cc |  2 ++
 9 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index aa8b43223..e43cc1f31 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -21,7 +21,9 @@ enum DeviceAttrKind : int {
   kWarpSize = 2,
   kMaxSharedMemoryPerBlock = 3,
   kComputeVersion = 4,
-  kDeviceName = 5
+  kDeviceName = 5,
+  kMaxClockRate = 6,
+  kMultiProcessorCount = 7
 };
 
 /*! \brief Number of bytes each allocation must align to */
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 3e947cb63..3fc020c87 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -166,6 +166,18 @@ class TVMContext(ctypes.Structure):
         return _api_internal._GetDeviceAttr(
             self.device_type, self.device_id, 5)
 
+    @property
+    def max_clock_rate(self):
+        """Return the max clock frequency of device."""
+        return _api_internal._GetDeviceAttr(
+            self.device_type, self.device_id, 6)
+
+    @property
+    def multi_processor_count(self):
+        """Return the number of compute units of device."""
+        return _api_internal._GetDeviceAttr(
+            self.device_type, self.device_id, 7)
+
     def sync(self):
         """Synchronize until jobs finished at the context."""
         check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index fe13e466b..3f697faab 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -62,6 +62,16 @@ class CUDADeviceAPI final : public DeviceAPI {
         *rv = std::string(props.name);
         return;
       }
+      case kMaxClockRate: {
+        CUDA_CALL(cudaDeviceGetAttribute(
+            &value, cudaDevAttrClockRate, ctx.device_id));
+        break;
+      }
+      case kMultiProcessorCount: {
+        CUDA_CALL(cudaDeviceGetAttribute(
+            &value, cudaDevAttrMultiProcessorCount, ctx.device_id));
+        break;
+      }
     }
     *rv = value;
   }
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 077d2546c..6d225ea7f 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -42,6 +42,8 @@ void MetalWorkspace::GetAttr(
     case kMaxSharedMemoryPerBlock: return;
     case kComputeVersion: return;
     case kDeviceName: return;
+    case kMaxClockRate: return;
+    case kMultiProcessorCount: return;
     case kExist: break;
   }
 }
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index 40f34a652..da527c761 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -42,6 +42,11 @@ void OpenCLWorkspace::GetAttr(
       break;
     }
     case kWarpSize: {
+      /* TODO: the warp size of an OpenCL device is not always 1,
+               e.g. Intel GPUs have a sub-group concept which contains 8 - 32 work items,
+               corresponding to the number of SIMD entries the hardware configures.
+               We need to figure out a way to query this information from the hardware.
+      */
       *rv = 1;
       break;
     }
@@ -62,6 +67,22 @@ void OpenCLWorkspace::GetAttr(
       *rv = std::string(value);
       break;
     }
+    case kMaxClockRate: {
+      cl_uint value;
+      OPENCL_CALL(clGetDeviceInfo(
+          devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY,
+          sizeof(cl_uint), &value, nullptr));
+      *rv = static_cast<int32_t>(value);
+      break;
+    }
+    case kMultiProcessorCount: {
+      cl_uint value;
+      OPENCL_CALL(clGetDeviceInfo(
+          devices[index], CL_DEVICE_MAX_COMPUTE_UNITS,
+          sizeof(cl_uint), &value, nullptr));
+      *rv = static_cast<int32_t>(value);
+      break;
+    }
     case kExist: break;
   }
 }
diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc
index bde7e6b27..d8831880f 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -176,7 +176,7 @@ class OpenCLModuleNode : public ModuleNode {
 
 class OpenCLWrappedFunc {
  public:
-  // initialize the CUDA function.
+  // initialize the OpenCL function.
   void Init(OpenCLModuleNode* m,
             std::shared_ptr<ModuleNode> sptr,
             OpenCLModuleNode::KTRefEntry entry,
diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc
index 0c6354e23..e925f863d 100644
--- a/src/runtime/opengl/opengl_device_api.cc
+++ b/src/runtime/opengl/opengl_device_api.cc
@@ -98,6 +98,8 @@ void OpenGLWorkspace::GetAttr(
       break;
     }
     case kDeviceName: return;
+    case kMaxClockRate: return;
+    case kMultiProcessorCount: return;
   }
 }
 
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index 256c715e9..55272561c 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -52,6 +52,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
         return;
       }
       case kDeviceName: return;
+      case kMaxClockRate: return;
+      case kMultiProcessorCount: return;
     }
     *rv = value;
   }
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index aaf658bba..a3e4fc294 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -74,6 +74,8 @@ void VulkanWorkspace::GetAttr(
       break;
     }
     case kDeviceName: return;
+    case kMaxClockRate: return;
+    case kMultiProcessorCount: return;
     case kExist: break;
   }
 }
-- 
GitLab