diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index 1173e0367718c05d9a718aa4e235fd3336e7a01f..755379b5061fc6e1161a73dbcd133bd9327f1f0c 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -51,6 +51,7 @@ Expr min(Expr source, Array<IterVar> axis); TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(sigmoid); +TVM_DECLARE_INTRIN_UNARY(sqrt); } // namespace tvm #endif // TVM_IR_OPERATOR_H_ diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 274643a712be049edbf161763a8d7886e6f69a12..5a088a2e1040ff2dbca358e9162ecfae98a3f9e9 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -166,6 +166,22 @@ def log(x): return call_pure_intrin(x.dtype, "log", x) +def sqrt(x): + """Take log of input x. + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. + """ + return call_pure_intrin(x.dtype, "sqrt", x) + + # Intrinsic rule related code def register_intrin_rule(target, intrin, f=None, override=False): """Register an intrinsic function generation rule. diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc index d8c9c92a4bcace24e6b689d9f4fcef775e386414..dfe877b6b9e89866ad01bd5acf8030bc67b4a8ea 100644 --- a/src/codegen/intrin_rule.cc +++ b/src/codegen/intrin_rule.cc @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") .set_body(DispatchExtern<FloatSuffix>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") +.set_body(DispatchExtern<FloatSuffix>); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc index 0ec2b0a268a7cf98623404786d2fbffa78fd64ba..edfe8c1127ac8c7dec81caf672edb10fe29e301f 100644 --- a/src/codegen/intrin_rule_cuda.cc +++ b/src/codegen/intrin_rule_cuda.cc @@ -45,6 +45,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh") .set_body(DispatchExtern<CUDAMath>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt") +.set_body(DispatchExtern<CUDAMath>); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc index 9283b53f6f44722f9174158d6f1b7729db55f0ba..5025382258bbcf5ab25ebcc4b89024a09f410883 100644 --- a/src/codegen/intrin_rule_opencl.cc +++ b/src/codegen/intrin_rule_opencl.cc @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") .set_body(DispatchExtern<FloatDirect>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt") +.set_body(DispatchExtern<FloatDirect>); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index 13d90820abd04e9cecc7bd878a83d5bc4e5ea1e1..c20651538b867a612037c35721f36773509a06b0 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -37,6 +37,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") +.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::sqrt>); + } // namespace llvm } // namespace codegen } // namespace tvm diff --git a/tests/python/unittest/test_topi_basic.py b/tests/python/unittest/test_topi_basic.py index ec0f71df81b3cb96f9932deb3138137681da3dd4..748231aad9791576432046fa04618b380ff10e79 100644 --- a/tests/python/unittest/test_topi_basic.py +++ b/tests/python/unittest/test_topi_basic.py @@ -14,6 +14,8 @@ def test_ewise(): test_apply(topi.exp, "exp") test_apply(topi.tanh, "tanh") test_apply(topi.sigmoid, "sigmoid") + test_apply(topi.log, "log") + test_apply(topi.sqrt, "sqrt") if __name__ == "__main__": diff --git a/topi/include/topi/ewise.h b/topi/include/topi/ewise.h index 72c0005100f888976056533b6e3bfad870ccfb30..2909e726fe07c0e6cea0ac772099a91ddcdebe94 100644 --- a/topi/include/topi/ewise.h +++ b/topi/include/topi/ewise.h @@ -22,5 +22,6 @@ using namespace tvm; TOPI_DECLARE_UNARY_OP(exp); TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(sigmoid); +TOPI_DECLARE_UNARY_OP(sqrt); } // namespace topi #endif // TOPI_EWISE_H_ diff --git a/topi/python/topi/ewise.py b/topi/python/topi/ewise.py index f70288f1df2690314634ae43718483f38d641f60..a0c420c9f1f72fcd88bed0d6eb2e84311db20873 100644 --- a/topi/python/topi/ewise.py +++ b/topi/python/topi/ewise.py @@ -34,6 +34,38 @@ def tanh(x): return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i))) +def log(x): + """Take logarithm of input x. + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. + """ + return tvm.compute(x.shape, lambda *i: tvm.log(x(*i))) + + +def sqrt(x): + """Take square root of input x. + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. + """ + return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i))) + + def sigmoid(x): """Take sigmoid tanh of input x.