From af9f69a70a51d89d09e118a854385e532e32939d Mon Sep 17 00:00:00 2001
From: Yuwei Hu <huyuwei1995@gmail.com>
Date: Fri, 12 Jan 2018 01:05:19 +0800
Subject: [PATCH] [INTRIN] enable popcount on cuda, opencl, metal (#774)
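
tvm.popcount (population count, i.e. the number of set bits) previously
lowered only on llvm. This patch maps it onto the native builtins of the
GPU backends: __popc/__popcll on CUDA for scalar uint32/uint64, and the
builtin named popcount that both OpenCL and Metal provide, dispatched by
name through the new generic Direct rule.

A minimal usage sketch mirroring the updated test (the schedule and the
names here are illustrative, not part of this patch):

    import tvm

    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A', dtype='uint32')
    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=8)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    f = tvm.build(s, [A, B], "cuda")  # popcount lowers to __popc here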

---
 src/codegen/intrin_rule.h              | 12 ++----
 src/codegen/intrin_rule_cuda.cc        | 16 ++++++++
 src/codegen/intrin_rule_metal.cc       | 13 ++++---
 src/codegen/intrin_rule_opencl.cc      | 13 ++++---
 tests/python/integration/test_ewise.py | 51 +++++++++++++++++---------
 5 files changed, 69 insertions(+), 36 deletions(-)
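
Note: CUDAPopcount returns an empty name for anything other than scalar
uint32/uint64; an empty name from the dispatcher should leave the call
on the default lowering path. One way to sanity-check the CUDA lowering
is to inspect the generated kernel source (a sketch; f is the "cuda"
module built as in the commit message above):

    dev_src = f.imported_modules[0].get_source()
    assert "__popc" in dev_src  # hardware population-count intrinsic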

diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h
index e66b55dcc..c900c9088 100644
--- a/src/codegen/intrin_rule.h
+++ b/src/codegen/intrin_rule.h
@@ -30,18 +30,14 @@ struct FloatSuffix {
   }
 };
 
-// Add float suffix to the intrinsics
-struct FloatDirect {
+// Return the intrinsic name
+struct Direct {
   std::string operator()(Type t, std::string name) const {
-    if (t.is_float()) {
-      return name;
-    } else {
-      return "";
-    }
+    return name;
   }
 };
 
-// Directly call pure extern function for floats.
+// Call pure extern function.
 template<typename T>
 inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
   Expr e = args[0];
diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc
index a2441d597..9abb99d7c 100644
--- a/src/codegen/intrin_rule_cuda.cc
+++ b/src/codegen/intrin_rule_cuda.cc
@@ -36,6 +36,19 @@ struct CUDAFastMath : public CUDAMath {
   }
 };
 
+struct CUDAPopcount {
+  std::string operator()(Type t, std::string name) const {
+    if (t.lanes() == 1 && t.is_uint()) {
+      switch (t.bits()) {
+        case 32: return "__popc";
+        case 64: return "__popcll";
+        default: return "";
+      }
+    }
+    return "";
+  }
+};
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 .set_body(DispatchExtern<CUDAFastMath>);
 
@@ -51,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
 .set_body(DispatchExtern<CUDAMath>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
+.set_body(DispatchExtern<CUDAPopcount>);
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/codegen/intrin_rule_metal.cc b/src/codegen/intrin_rule_metal.cc
index fbadf3a19..b0e41770e 100644
--- a/src/codegen/intrin_rule_metal.cc
+++ b/src/codegen/intrin_rule_metal.cc
@@ -10,19 +10,22 @@ namespace codegen {
 namespace intrin {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
+.set_body(DispatchExtern<Direct>);
 
 }  // namespace intrin
 }  // namespace codegen
diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc
index a947715ac..924abcade 100644
--- a/src/codegen/intrin_rule_opencl.cc
+++ b/src/codegen/intrin_rule_opencl.cc
@@ -10,19 +10,22 @@ namespace codegen {
 namespace intrin {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
+.set_body(DispatchExtern<Direct>);
 
 }  // namespace intrin
 }  // namespace codegen
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 24adf6ff2..f8dc43da8 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -60,25 +60,40 @@ def test_log_pow_llvm():
         b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
 
 
-def test_popcount_llvm():
-    # graph
-    n = tvm.var('n')
-    A = tvm.placeholder((n,), name='A', dtype="uint32")
-    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
-    s = tvm.create_schedule(B.op)
+def test_popcount():
+    def run(dtype):
+        # graph
+        n = tvm.convert(1024)
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
+        s = tvm.create_schedule(B.op)
+        # simple schedule
+        num_thread = 8
+        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
 
-    if not tvm.module.enabled("llvm"):
-        return
-    f = tvm.build(s, [A, B], "llvm")
-    ctx = tvm.cpu(0)
-    # launch the kernel.
-    n = 1024
-    a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
-    b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
-    f(a, b)
-    np.testing.assert_allclose(
-        b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
+        def check_device(device):
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled." % device)
+                return
+            ctx = tvm.context(device, 0)
+            if device != "llvm":  # bind threads on every GPU backend
+                s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+                s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+            func = tvm.build(s, [A, B], device)
+            # launch the kernel.
+            n = 1024
+            a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
+            b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
+            func(a, b)
+            np.testing.assert_allclose(
+                b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
 
+        check_device("llvm")
+        check_device("cuda")
+        check_device("opencl")
+        check_device("metal")
+    run('uint32')
+    run('uint64')
 
 
 def test_add():
@@ -133,5 +148,5 @@ def test_add():
 if __name__ == "__main__":
     test_add()
     test_log_pow_llvm()
-    test_popcount_llvm()
+    test_popcount()
     test_exp()
-- 
GitLab