From af9f69a70a51d89d09e118a854385e532e32939d Mon Sep 17 00:00:00 2001 From: Yuwei Hu <huyuwei1995@gmail.com> Date: Fri, 12 Jan 2018 01:05:19 +0800 Subject: [PATCH] [INTRIN] enable popcount on cuda, opencl, metal (#774) --- src/codegen/intrin_rule.h | 12 ++---- src/codegen/intrin_rule_cuda.cc | 16 ++++++++ src/codegen/intrin_rule_metal.cc | 13 ++++--- src/codegen/intrin_rule_opencl.cc | 13 ++++--- tests/python/integration/test_ewise.py | 51 +++++++++++++++++--------- 5 files changed, 69 insertions(+), 36 deletions(-) diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h index e66b55dcc..c900c9088 100644 --- a/src/codegen/intrin_rule.h +++ b/src/codegen/intrin_rule.h @@ -30,18 +30,14 @@ struct FloatSuffix { } }; -// Add float suffix to the intrinsics -struct FloatDirect { +// Return the intrinsic name +struct Direct { std::string operator()(Type t, std::string name) const { - if (t.is_float()) { - return name; - } else { - return ""; - } + return name; } }; -// Directly call pure extern function for floats. +// Call pure extern function. template<typename T> inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { Expr e = args[0]; diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc index a2441d597..9abb99d7c 100644 --- a/src/codegen/intrin_rule_cuda.cc +++ b/src/codegen/intrin_rule_cuda.cc @@ -36,6 +36,19 @@ struct CUDAFastMath : public CUDAMath { } }; +struct CUDAPopcount { + std::string operator()(Type t, std::string name) const { + if (t.lanes() == 1 && t.is_uint()) { + switch (t.bits()) { + case 32: return "__popc"; + case 64: return "__popcll"; + default: return ""; + } + } + return ""; + } +}; + TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") .set_body(DispatchExtern<CUDAFastMath>); @@ -51,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow") .set_body(DispatchExtern<CUDAMath>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount") +.set_body(DispatchExtern<CUDAPopcount>); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/codegen/intrin_rule_metal.cc b/src/codegen/intrin_rule_metal.cc index fbadf3a19..b0e41770e 100644 --- a/src/codegen/intrin_rule_metal.cc +++ b/src/codegen/intrin_rule_metal.cc @@ -10,19 +10,22 @@ namespace codegen { namespace intrin { TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount") +.set_body(DispatchExtern<Direct>); } // namespace intrin } // namespace codegen diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc index a947715ac..924abcade 100644 --- a/src/codegen/intrin_rule_opencl.cc +++ b/src/codegen/intrin_rule_opencl.cc @@ -10,19 +10,22 @@ namespace codegen { namespace intrin { TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow") -.set_body(DispatchExtern<FloatDirect>); +.set_body(DispatchExtern<Direct>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") +.set_body(DispatchExtern<Direct>); } // namespace intrin } // namespace codegen diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 24adf6ff2..f8dc43da8 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -60,25 +60,40 @@ def test_log_pow_llvm(): b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5) -def test_popcount_llvm(): - # graph - n = tvm.var('n') - A = tvm.placeholder((n,), name='A', dtype="uint32") - B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B') - s = tvm.create_schedule(B.op) +def test_popcount(): + def run(dtype): + # graph + n = tvm.convert(1024) + A = tvm.placeholder((n,), name='A', dtype=dtype) + B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B') + s = tvm.create_schedule(B.op) + # simple schedule + num_thread = 8 + bx, tx = s[B].split(B.op.axis[0], factor=num_thread) - if not tvm.module.enabled("llvm"): - return - f = tvm.build(s, [A, B], "llvm") - ctx = tvm.cpu(0) - # launch the kernel. - n = 1024 - a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx) - b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx) - f(a, b) - np.testing.assert_allclose( - b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5) + def check_device(device): + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." % device) + return + ctx = tvm.context(device, 0) + if str(ctx).startswith('gpu'): + s[B].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B].bind(tx, tvm.thread_axis("threadIdx.x")) + func = tvm.build(s, [A, B], device) + # launch the kernel. + n = 1024 + a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx) + b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx) + func(a, b) + np.testing.assert_allclose( + b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5) + check_device("llvm") + check_device("cuda") + check_device("opencl") + check_device("metal") + run('uint32') + run('uint64') def test_add(): @@ -133,5 +148,5 @@ def test_add(): if __name__ == "__main__": test_add() test_log_pow_llvm() - test_popcount_llvm() + test_popcount() test_exp() -- GitLab