diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h
index e66b55dcc8617cc2870b6cc33ac5f0d264a97b48..c900c9088880d77fdf139c10886c55e83ba81d7e 100644
--- a/src/codegen/intrin_rule.h
+++ b/src/codegen/intrin_rule.h
@@ -30,18 +30,14 @@ struct FloatSuffix {
   }
 };
 
-// Add float suffix to the intrinsics
-struct FloatDirect {
+// Return the intrinsic name
+struct Direct {
   std::string operator()(Type t, std::string name) const {
-    if (t.is_float()) {
-      return name;
-    } else {
-      return "";
-    }
+    return name;
   }
 };
 
-// Directly call pure extern function for floats.
+// Call pure extern function.
 template<typename T>
 inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
   Expr e = args[0];
diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc
index a2441d597d862d6b00ba6485c125f4920671bece..9abb99d7c7c5ffe371e536e52ecf3c9531a4f76f 100644
--- a/src/codegen/intrin_rule_cuda.cc
+++ b/src/codegen/intrin_rule_cuda.cc
@@ -36,6 +36,19 @@ struct CUDAFastMath : public CUDAMath {
   }
 };
 
+struct CUDAPopcount {
+  std::string operator()(Type t, std::string name) const {
+    if (t.lanes() == 1 && t.is_uint()) {
+      switch (t.bits()) {
+        case 32: return "__popc";
+        case 64: return "__popcll";
+        default: return "";
+      }
+    }
+    return "";
+  }
+};
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 .set_body(DispatchExtern<CUDAFastMath>);
 
@@ -51,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
 .set_body(DispatchExtern<CUDAMath>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
+.set_body(DispatchExtern<CUDAPopcount>);
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/codegen/intrin_rule_metal.cc b/src/codegen/intrin_rule_metal.cc
index fbadf3a19bdf624864fc874b5e43c66cda8def14..b0e41770ebff14a6eb1deab236ec59f9d39d60b6 100644
--- a/src/codegen/intrin_rule_metal.cc
+++ b/src/codegen/intrin_rule_metal.cc
@@ -10,19 +10,22 @@ namespace codegen {
 namespace intrin {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
+.set_body(DispatchExtern<Direct>);
 
 }  // namespace intrin
 }  // namespace codegen
diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc
index a947715acdacab78cb61cb8a41b84c348b4988aa..924abcade63f04a54c3a723360376881b5596fd3 100644
--- a/src/codegen/intrin_rule_opencl.cc
+++ b/src/codegen/intrin_rule_opencl.cc
@@ -10,19 +10,22 @@ namespace codegen {
 namespace intrin {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
-.set_body(DispatchExtern<FloatDirect>);
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
+.set_body(DispatchExtern<Direct>);
 
 }  // namespace intrin
 }  // namespace codegen
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 24adf6ff28af08ad0a3ad4884be4dc17261c7372..f8dc43da8d319920faa9b5d1ea33698e6c976da1 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -60,25 +60,40 @@ def test_log_pow_llvm():
         b.asnumpy(), np.power(np.log(a.asnumpy()), 2.0), rtol=1e-5)
 
 
-def test_popcount_llvm():
-    # graph
-    n = tvm.var('n')
-    A = tvm.placeholder((n,), name='A', dtype="uint32")
-    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
-    s = tvm.create_schedule(B.op)
+def test_popcount():
+    def run(dtype):
+        # graph
+        n = tvm.convert(1024)
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
+        s = tvm.create_schedule(B.op)
+        # simple schedule
+        num_thread = 8
+        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
 
-    if not tvm.module.enabled("llvm"):
-        return
-    f = tvm.build(s, [A, B], "llvm")
-    ctx = tvm.cpu(0)
-    # launch the kernel.
-    n = 1024
-    a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
-    b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
-    f(a, b)
-    np.testing.assert_allclose(
-        b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
+        def check_device(device):
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled.." % device)
+                return
+            ctx = tvm.context(device, 0)
+            if str(ctx).startswith('gpu'):
+                s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+                s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+            func = tvm.build(s, [A, B], device)
+            # launch the kernel.
+            n = 1024
+            a = tvm.nd.array(np.random.randint(low=0, high=1000, size=n, dtype=A.dtype), ctx)
+            b = tvm.nd.array(np.zeros(shape=n, dtype=B.dtype), ctx)
+            func(a, b)
+            np.testing.assert_allclose(
+                b.asnumpy(), list(map(lambda x: bin(x).count('1'), a.asnumpy())), rtol=1e-5)
+        check_device("llvm")
+        check_device("cuda")
+        check_device("opencl")
+        check_device("metal")
+    run('uint32')
+    run('uint64')
 
 
 def test_add():
@@ -133,5 +148,5 @@
 if __name__ == "__main__":
     test_add()
     test_log_pow_llvm()
-    test_popcount_llvm()
+    test_popcount()
     test_exp()
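
Note (illustrative, not part of the patch): every *.popcount registration above goes through the DispatchExtern<T> helper from intrin_rule.h, which asks the functor T for a target-specific extern name and leaves the intrinsic untouched when T returns an empty string. The self-contained C++ sketch below only demonstrates that name-mapping pattern with plain stand-in types; Type, LowerIntrinsic, and main here are simplified assumptions and are not TVM's actual classes or API.

#include <iostream>
#include <string>

// Stand-in for TVM's Type: just lanes/bits/uint-ness (hypothetical, for illustration only).
struct Type {
  int lanes_;
  int bits_;
  bool is_uint_;
  int lanes() const { return lanes_; }
  int bits() const { return bits_; }
  bool is_uint() const { return is_uint_; }
};

// Same shape as the CUDAPopcount functor added in the patch:
// map popcount on scalar uint32/uint64 to the CUDA builtins, else decline.
struct CUDAPopcount {
  std::string operator()(Type t, std::string name) const {
    if (t.lanes() == 1 && t.is_uint()) {
      switch (t.bits()) {
        case 32: return "__popc";
        case 64: return "__popcll";
        default: return "";
      }
    }
    return "";
  }
};

// Sketch of the DispatchExtern<T> idea: ask T for a target-specific extern
// name; an empty string means "no rule", so the generic name is kept.
template <typename T>
std::string LowerIntrinsic(Type t, const std::string& generic_name) {
  std::string name = T()(t, generic_name);
  return name.empty() ? generic_name : name;
}

int main() {
  Type u32{1, 32, true}, u64{1, 64, true}, f32{1, 32, false};
  std::cout << LowerIntrinsic<CUDAPopcount>(u32, "popcount") << "\n";  // __popc
  std::cout << LowerIntrinsic<CUDAPopcount>(u64, "popcount") << "\n";  // __popcll
  std::cout << LowerIntrinsic<CUDAPopcount>(f32, "popcount") << "\n";  // popcount (no rule)
  return 0;
}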