Skip to content
Snippets Groups Projects
Commit 0c523787 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by Tianqi Chen
Browse files

[PASS] Enhance gpu verify pass (#1660)

parent 9f99a4fa
No related branches found
No related tags found
No related merge requests found
...@@ -86,17 +86,29 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -86,17 +86,29 @@ class GPUCodeVerifier : public IRVisitor {
// record the number of threads in a block // record the number of threads in a block
std::string name = var.get()->name_hint; std::string name = var.get()->name_hint;
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") { if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
size_t length = static_cast<size_t>(extent->value);
if (!visited_threads_.count(name)) { if (!visited_threads_.count(name)) {
visited_threads_.insert(name); visited_threads_.insert(name);
size_t length = static_cast<size_t>(extent->value);
thread_per_block_ *= length; thread_per_block_ *= length;
if (name == "threadIdx.x") { if (name == "threadIdx.x") {
valid_ &= length <= max_thread_x_; valid_ &= length <= max_thread_x_;
thread_x_extent_ = length;
} else if (name == "threadIdx.y") { } else if (name == "threadIdx.y") {
valid_ &= length <= max_thread_y_; valid_ &= length <= max_thread_y_;
thread_y_extent_ = length;
} else if (name == "threadIdx.z") { } else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_; valid_ &= length <= max_thread_z_;
thread_z_extent_ = length;
}
} else {
// the thread should be bound to axes with the same length
if (name == "threadIdx.x") {
valid_ &= length == thread_x_extent_;
} else if (name == "threadIdx.y") {
valid_ &= length == thread_y_extent_;
} else if (name == "threadIdx.z") {
valid_ &= length == thread_z_extent_;
} }
} }
} }
...@@ -111,6 +123,8 @@ class GPUCodeVerifier : public IRVisitor { ...@@ -111,6 +123,8 @@ class GPUCodeVerifier : public IRVisitor {
std::unordered_set<const tvm::Variable *> visited_shared_buffers_; std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
std::unordered_set<std::string> visited_threads_; std::unordered_set<std::string> visited_threads_;
size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;
size_t local_memory_per_block_; size_t local_memory_per_block_;
size_t shared_memory_per_block_; size_t shared_memory_per_block_;
size_t thread_per_block_; size_t thread_per_block_;
......
...@@ -162,8 +162,32 @@ def test_multiple_kernels(): ...@@ -162,8 +162,32 @@ def test_multiple_kernels():
tvm.build(s, [A, C], target) tvm.build(s, [A, C], target)
assert valid[0] assert valid[0]
def test_wrong_bind():
N = 1024
A = tvm.placeholder((N, N-1), name='A')
B = tvm.compute((N, N-1), lambda i, j: A[i, j])
s = tvm.create_schedule([B.op])
# bind a thread axis to two loop axes with different lengths
s[B].bind(s[B].op.axis[0], tvm.thread_axis("threadIdx.x"))
s[B].bind(s[B].op.axis[1], tvm.thread_axis("threadIdx.x"))
for target in ['opencl', 'cuda']:
if not tvm.context(target).exist:
continue
valid = [None]
with tvm.build_config(**{"add_lower_pass": [
(2, get_verify_pass(valid, max_threads_per_block=N*N))]}):
tvm.build(s, [A, B], target)
assert not valid[0]
if __name__ == "__main__": if __name__ == "__main__":
test_local_memory() test_local_memory()
test_shared_memory() test_shared_memory()
test_num_thread() test_num_thread()
test_multiple_kernels() test_multiple_kernels()
test_wrong_bind()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment