From 0c523787297039ce00b320c1d32e022e61e97ac2 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Wed, 29 Aug 2018 23:09:33 -0700
Subject: [PATCH] [PASS] Enhance gpu verify pass (#1660)

---
 src/pass/verify_gpu_code.cc                   | 16 ++++++++++++-
 .../unittest/test_pass_verify_gpu_code.py     | 24 +++++++++++++++++++
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc
index 363b7c4cf..70908eb43 100644
--- a/src/pass/verify_gpu_code.cc
+++ b/src/pass/verify_gpu_code.cc
@@ -86,17 +86,29 @@ class GPUCodeVerifier : public IRVisitor {
       // record the number of threads in a block
       std::string name = var.get()->name_hint;
       if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
+        size_t length = static_cast<size_t>(extent->value);
         if (!visited_threads_.count(name)) {
           visited_threads_.insert(name);
-          size_t length = static_cast<size_t>(extent->value);
           thread_per_block_ *= length;
 
           if (name == "threadIdx.x") {
             valid_ &= length <= max_thread_x_;
+            thread_x_extent_ = length;
           } else if (name == "threadIdx.y") {
             valid_ &= length <= max_thread_y_;
+            thread_y_extent_ = length;
           } else if (name == "threadIdx.z") {
             valid_ &= length <= max_thread_z_;
+            thread_z_extent_ = length;
+          }
+        } else {
+          // the thread should be bound to axes with the same length
+          if (name == "threadIdx.x") {
+            valid_ &= length == thread_x_extent_;
+          } else if (name == "threadIdx.y") {
+            valid_ &= length == thread_y_extent_;
+          } else if (name == "threadIdx.z") {
+            valid_ &= length == thread_z_extent_;
           }
         }
       }
@@ -111,6 +123,8 @@ class GPUCodeVerifier : public IRVisitor {
   std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
   std::unordered_set<std::string> visited_threads_;
 
+  size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;
+
   size_t local_memory_per_block_;
   size_t shared_memory_per_block_;
   size_t thread_per_block_;
diff --git a/tests/python/unittest/test_pass_verify_gpu_code.py b/tests/python/unittest/test_pass_verify_gpu_code.py
index 6fc0387cf..e3884a727 100644
--- a/tests/python/unittest/test_pass_verify_gpu_code.py
+++ b/tests/python/unittest/test_pass_verify_gpu_code.py
@@ -162,8 +162,32 @@ def test_multiple_kernels():
             tvm.build(s, [A, C], target)
         assert valid[0]
 
+def test_wrong_bind():
+    N = 1024
+
+    A = tvm.placeholder((N, N-1), name='A')
+    B = tvm.compute((N, N-1), lambda i, j: A[i, j])
+
+    s = tvm.create_schedule([B.op])
+
+    # bind a thread axis to two loop axes with different lengths
+    s[B].bind(s[B].op.axis[0], tvm.thread_axis("threadIdx.x"))
+    s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x"))
+
+    for target in ['opencl', 'cuda']:
+        if not tvm.context(target).exist:
+            continue
+
+        valid = [None]
+        with tvm.build_config(**{"add_lower_pass": [
+                (2, get_verify_pass(valid, max_threads_per_block=N*N))]}):
+            tvm.build(s, [A, B], target)
+        assert not valid[0]
+
+
 if __name__ == "__main__":
     test_local_memory()
     test_shared_memory()
     test_num_thread()
     test_multiple_kernels()
+    test_wrong_bind()
-- 
GitLab