From 32f158e6c2fe48bf536e47b4c85a9c28fc26dc3a Mon Sep 17 00:00:00 2001
From: xqdan <danxiaoqiang@126.com>
Date: Mon, 29 Oct 2018 09:37:05 +0800
Subject: [PATCH] [intrin] Support fmod for CUDA (#1964)

Add a tvm.fmod intrinsic with lowering rules for the CUDA, Metal and
OpenCL backends, plus an integration test.
---
 python/tvm/intrin.py                    | 16 +++++++++++
 src/codegen/intrin_rule_cuda.cc         |  2 ++
 src/codegen/intrin_rule_metal.cc        |  3 ++
 src/codegen/intrin_rule_opencl.cc       |  3 ++
 src/lang/ir_operator.cc                 |  6 ++++
 tests/python/integration/test_ewise.py  | 40 ++++++++++++++++++++++++++
 6 files changed, 70 insertions(+)

diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py
index 30da873b5..3207b6112 100644
--- a/python/tvm/intrin.py
+++ b/python/tvm/intrin.py
@@ -376,6 +376,22 @@ def popcount(x):
     """
     return call_pure_intrin(x.dtype, "popcount", x)
 
+def fmod(x, y):
+    """Return the remainder of x divided by y with the same sign as x.
+
+    Parameters
+    ----------
+    x : Expr
+        The dividend.
+    y : Expr
+        The divisor.
+
+    Returns
+    -------
+    z : Expr
+        The remainder.
+    """
+    return call_pure_intrin(x.dtype, "fmod", x, y)
 
 # Intrinsic rule related code
 def register_intrin_rule(target, intrin, f=None, override=False):
diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc
index ee98a5432..a6867c7f2 100644
--- a/src/codegen/intrin_rule_cuda.cc
+++ b/src/codegen/intrin_rule_cuda.cc
@@ -91,6 +91,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
 .set_body(DispatchExtern<CUDAShuffle>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
+.set_body(DispatchExtern<CUDAMath>);
 }  // namespace intrin
 }  // namespace codegen
diff --git a/src/codegen/intrin_rule_metal.cc b/src/codegen/intrin_rule_metal.cc
index 8b499fb9e..2e65d5537 100644
--- a/src/codegen/intrin_rule_metal.cc
+++ b/src/codegen/intrin_rule_metal.cc
@@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
 .set_body(DispatchExtern<Direct>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod")
+.set_body(DispatchExtern<Direct>);
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc
index 1cb1aed01..e4cf11bf6 100644
--- a/src/codegen/intrin_rule_opencl.cc
+++ b/src/codegen/intrin_rule_opencl.cc
@@ -42,6 +42,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
 .set_body(DispatchExtern<Direct>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod")
+.set_body(DispatchExtern<Direct>);
+
 // There is no warp shuffle instruction in standard OpenCL
 // When shuffle is used, we assume it is intel's shuffle extension
 struct IntelShuffle {
diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc
index 275752644..9ae291290 100644
--- a/src/lang/ir_operator.cc
+++ b/src/lang/ir_operator.cc
@@ -450,4 +450,10 @@ Expr prod(Expr source, Array<IterVar> rdom) {
   return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
+Expr fmod(Expr x, Expr y) {
+  BinaryOpMatchTypes(x, y);
+  CHECK(x.type().is_float()) << "fmod only applies to floating-point types";
+  return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic);
+}
+
 }  // namespace tvm
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 0f58c2367..b3f17b7c1 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -38,6 +38,45 @@ def test_exp():
     check_device("cuda", "llvm")
     check_device("vulkan")
 
+def test_fmod():
+    # graph
+    def run(dtype):
+        n = tvm.var('n')
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        B = tvm.placeholder((n,), name='B', dtype=dtype)
+        C = tvm.compute(A.shape, lambda *i: tvm.fmod(A(*i), B(*i)), name='C')
+        s = tvm.create_schedule(C.op)
+        # create iter vars and assign them tags.
+        num_thread = 8
+        bx, tx = s[C].split(C.op.axis[0], factor=num_thread)
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("skip because %s is not enabled..." % device)
+                return
+            target = tvm.target.create(device)
+            if "cpu" not in target.keys:
+                s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+                s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+            fmod = tvm.build(s, [A, B, C], device, name="myfmod")
+
+            # launch the kernel once through the time evaluator.
+            n = 1024
+            a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx)
+            b = tvm.nd.array((np.random.uniform(size=n) * 256).astype(B.dtype), ctx)
+            c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+            ftimer = fmod.time_evaluator(fmod.entry_name, ctx, number=1)
+            tcost = ftimer(a, b, c).mean
+            # np.fmod matches C fmod semantics; inputs here are non-negative
+            np.testing.assert_allclose(
+                c.asnumpy(), np.fmod(a.asnumpy(), b.asnumpy()), rtol=1e-5)
+
+        check_device("cuda")
+        check_device("opencl -device=intel_graphics")
+        check_device("metal")
+
+    run("float32")
 
 def test_multiple_cache_write():
     # graph
@@ -245,3 +284,4 @@ if __name__ == "__main__":
     test_add()
     test_log_pow_llvm()
     test_popcount()
+    test_fmod()
-- 
GitLab
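
Note for readers: below is a minimal standalone sketch (not part of the patch)
showing the new intrinsic end to end on the CUDA backend. It is written
against the same 2018-era TVM API the patch itself uses (tvm.placeholder,
tvm.compute, tvm.build, tvm.fmod) and assumes a CUDA-enabled build; the name
fmod_demo and the shifted input range are illustrative choices only. The
reference result uses np.fmod because, like the C fmod these rules dispatch
to, its result takes the sign of the dividend x; np.mod agrees with it only
for non-negative operands.

import numpy as np
import tvm

def fmod_demo():
    ctx = tvm.context("cuda", 0)
    if not ctx.exist:
        print("skip because cuda is not enabled...")
        return
    n = 1024
    A = tvm.placeholder((n,), name='A', dtype='float32')
    B = tvm.placeholder((n,), name='B', dtype='float32')
    # elementwise remainder through the new intrinsic
    C = tvm.compute(A.shape, lambda *i: tvm.fmod(A(*i), B(*i)), name='C')
    s = tvm.create_schedule(C.op)
    bx, tx = s[C].split(C.op.axis[0], factor=64)
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
    func = tvm.build(s, [A, B, C], "cuda", name="fmod_demo")
    # keep divisors away from zero so the rtol comparison stays meaningful
    a = tvm.nd.array((np.random.uniform(size=n) * 255 + 1).astype('float32'), ctx)
    b = tvm.nd.array((np.random.uniform(size=n) * 255 + 1).astype('float32'), ctx)
    c = tvm.nd.array(np.zeros(n, dtype='float32'), ctx)
    func(a, b, c)
    np.testing.assert_allclose(
        c.asnumpy(), np.fmod(a.asnumpy(), b.asnumpy()), rtol=1e-5)

if __name__ == "__main__":
    fmod_demo()

On CUDA, DispatchExtern<CUDAMath> resolves the call to the math library's
fmodf (float32) or fmod (float64), while the Metal and OpenCL rules dispatch
directly to those languages' built-in fmod.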