From d50f7b66477097639012411ef8f13db193bfd027 Mon Sep 17 00:00:00 2001 From: Sergey Mironov <grrwlf@gmail.com> Date: Fri, 7 Dec 2018 17:43:28 +0300 Subject: [PATCH] Fix missing sigmoid intrinsic in C++ (#2231) --- python/tvm/intrin.py | 3 --- src/codegen/intrin_rule.cc | 10 ++++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 3207b6112..cd9a108c5 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -492,6 +492,3 @@ def _rule_float_direct(op): register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) # default pattern for exp register_intrin_rule("default", "exp", _rule_float_suffix, override=True) - -# default pattern for sigmoid -register_intrin_rule("default", "sigmoid", lambda op: 1.0 / (1.0 + exp(-op.args[0]))) diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc index 822d515fb..f326fceb6 100644 --- a/src/codegen/intrin_rule.cc +++ b/src/codegen/intrin_rule.cc @@ -24,6 +24,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") .set_body(DispatchExtern<FloatSuffix>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") +.set_body([](const TVMArgs& args, TVMRetValue* rv){ + Expr e = args[0]; + const Call* call = e.as<Call>(); + CHECK(call != nullptr); + + auto one = make_const(call->args[0].type(), 1); + *rv = one / (one + exp(-call->args[0])); + }); + } // namespace intrin } // namespace codegen } // namespace tvm -- GitLab