diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 63e8ca7cd16b57c2d47d71a5d03e0408a6199af3..758d03b5b18bb4501ed1dc4c8297502836ac4eb9 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -118,6 +118,163 @@ class PackedFunc { FType body_; }; +/*! + * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args..)>" + */ +template<typename FType> +class TypedPackedFunc; + +/*! + * \anchor TypedPackedFuncAnchor + * \brief A PackedFunc wrapper to provide typed function signature. + * It is backed by a PackedFunc internally. + * + * TypedPackedFunc enables compile time type checking. + * TypedPackedFunc works with the runtime system: + * - It can be passed as an argument of PackedFunc. + * - It can be assigned to TVMRetValue. + * - It can be directly converted to a type-erased PackedFunc. + * + * Developers should prefer TypedPackedFunc over PackedFunc in C++ code + * as it enables compile time checking. + * We can construct a TypedPackedFunc from a lambda function + * with the same signature. + * + * \code + * // user defined lambda function. + * auto addone = [](int x)->int { + * return x + 1; + * }; + * // We can directly convert + * // lambda function to TypedPackedFunc + * TypedPackedFunc<int(int)> ftyped(addone); + * // invoke the function. + * int y = ftyped(1); + * // Can be directly converted to PackedFunc + * PackedFunc packed = ftype; + * \endcode + * \tparam R The return value of the function. + * \tparam Args The argument signature of the function. + */ +template<typename R, typename ...Args> +class TypedPackedFunc<R(Args...)> { + public: + /*! \brief short hand for this function type */ + using TSelf = TypedPackedFunc<R(Args...)>; + /*! \brief default constructor */ + TypedPackedFunc() {} + /*! + * \brief construct by wrap a PackedFunc + * + * Example usage: + * \code + * PackedFunc packed([](TVMArgs args, TVMRetValue *rv) { + * int x = args[0]; + * *rv = x + 1; + * }); + * // construct from packed function + * TypedPackedFunc<int(int)> ftyped(packed); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param packed The packed function + */ + explicit TypedPackedFunc(PackedFunc packed) + : packed_(packed) { + } + /*! + * \brief construct from a lambda function with the same signature. + * + * Example usage: + * \code + * auto typed_lambda = [](int x)->int { return x + 1; } + * // construct from packed function + * TypedPackedFunc<int(int)> ftyped(typed_lambda); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + */ + template<typename FLambda, + typename = typename std::enable_if< + std::is_convertible<FLambda, + std::function<R(Args...)> + >::value>::type> + explicit TypedPackedFunc(const FLambda& typed_lambda) { + this->AssignTypedLambda(typed_lambda); + } + /*! + * \brief copy assignment operator from typed lambda + * + * Example usage: + * \code + * // construct from packed function + * TypedPackedFunc<int(int)> ftyped; + * ftyped = [](int x) { return x + 1; } + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + * \returns reference to self. + */ + template<typename FLambda, + typename = typename std::enable_if< + std::is_convertible<FLambda, + std::function<R(Args...)> + >::value>::type> + TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) + this->AssignTypedLambda(typed_lambda); + return *this; + } + /*! + * \brief copy assignment operator from PackedFunc. + * \param packed The packed function. + * \returns reference to self. + */ + TSelf& operator=(PackedFunc packed) { + packed_ = packed; + return *this; + } + /*! + * \brief Invoke the operator. + * \param args The arguments + * \returns The return value. + */ + inline R operator()(Args ...args) const; + /*! + * \brief convert to PackedFunc + * \return the internal PackedFunc + */ + operator PackedFunc() const { + return packed(); + } + /*! + * \return reference the internal PackedFunc + */ + const PackedFunc& packed() const { + return packed_; + } + + private: + friend class TVMRetValue; + /*! \brief The internal packed function */ + PackedFunc packed_; + /*! + * \brief Assign the packed field using a typed lambda function. + * + * \param flambda The lambda function. + * \tparam FLambda The lambda function type. + * \note We capture the lambda when possible for maximum efficiency. + */ + template<typename FLambda> + inline void AssignTypedLambda(FLambda flambda); +}; + /*! \brief Arguments into TVM functions. */ class TVMArgs { public: @@ -361,6 +518,10 @@ class TVMArgValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); return *ptr<PackedFunc>(); } + template<typename FType> + operator TypedPackedFunc<FType>() const { + return TypedPackedFunc<FType>(operator PackedFunc()); + } operator Module() const { TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); return *ptr<Module>(); @@ -446,6 +607,10 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); return *ptr<PackedFunc>(); } + template<typename FType> + operator TypedPackedFunc<FType>() const { + return TypedPackedFunc<FType>(operator PackedFunc()); + } operator Module() const { TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); return *ptr<Module>(); @@ -512,6 +677,10 @@ class TVMRetValue : public TVMPODValue_ { this->SwitchToClass(kFuncHandle, f); return *this; } + template<typename FType> + TVMRetValue& operator=(const TypedPackedFunc<FType>& f) { + return operator=(f.packed()); + } TVMRetValue& operator=(Module m) { this->SwitchToClass(kModuleHandle, m); return *this; @@ -847,6 +1016,10 @@ class TVMArgsSetter { values_[i].v_handle = const_cast<PackedFunc*>(&value); type_codes_[i] = kFuncHandle; } + template<typename FType> + void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*) + operator()(i, value.packed()); + } void operator()(size_t i, const Module& value) const { // NOLINT(*) values_[i].v_handle = const_cast<Module*>(&value); type_codes_[i] = kModuleHandle; @@ -894,6 +1067,84 @@ inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { return rv; } +namespace detail { +template<typename R, int nleft, int index, typename F> +struct unpack_call_dispatcher { + template<typename ...Args> + static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { + unpack_call_dispatcher<R, nleft - 1, index + 1, F> + ::run(f, args_pack, rv, + std::forward<Args>(unpacked_args)..., + args_pack[index]); + } +}; + +template<typename R, int index, typename F> +struct unpack_call_dispatcher<R, 0, index, F> { + template<typename ...Args> + static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { + *rv = R(f(std::forward<Args>(unpacked_args)...)); + } +}; + +template<int index, typename F> +struct unpack_call_dispatcher<void, 0, index, F> { + template<typename ...Args> + static void run(const F& f, + const TVMArgs& args_pack, + TVMRetValue* rv, + Args&&... unpacked_args) { + f(std::forward<Args>(unpacked_args)...); + } +}; + +template<typename R, int nargs, typename F> +inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { + unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv); +} + +template<typename R, typename ...Args> +inline R call_packed(const PackedFunc& pf, Args&& ...args) { + return R(pf(std::forward<Args>(args)...)); +} + +template<typename R> +struct typed_packed_call_dispatcher { + template<typename ...Args> + static inline R run(const PackedFunc& pf, Args&& ...args) { + return pf(std::forward<Args>(args)...); + } +}; + +template<> +struct typed_packed_call_dispatcher<void> { + template<typename ...Args> + static inline void run(const PackedFunc& pf, Args&& ...args) { + pf(std::forward<Args>(args)...); + } +}; +} // namespace detail + +template<typename R, typename ...Args> +template<typename FType> +inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) { + packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { + detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv); + }); +} + +template<typename R, typename ...Args> +inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { + return detail::typed_packed_call_dispatcher<R> + ::run(packed_, std::forward<Args>(args)...); +} + // extension and node type handling namespace detail { template<typename T, typename TSrc, bool is_ext> diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 9b2f1df737312093adbb1ed3cd63f035b9c9f568..abe26fabe9ea24cc90f0055d81181975ce6edcec 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -135,6 +135,29 @@ TEST(PackedFunc, Type) { CHECK(get_type2("float32x2").operator Type() == Float(32, 2)); } +TEST(TypedPackedFunc, HighOrder) { + using namespace tvm; + using namespace tvm::runtime; + using Int1Func = TypedPackedFunc<int(int)>; + using Int2Func = TypedPackedFunc<int(int, int)>; + using BindFunc = TypedPackedFunc<Int1Func(Int2Func, int value)>; + BindFunc ftyped; + ftyped = [](Int2Func f1, int value) -> Int1Func { + auto binded = [f1, value](int x) { + return f1(value, x); + }; + Int1Func x(binded); + return x; + }; + auto add = [](int x, int y) { return x + y; }; + CHECK_EQ(ftyped(Int2Func(add), 1)(2), 3); + PackedFunc f = ftyped(Int2Func(add), 1); + CHECK_EQ(f(3).operator int(), 4); + // call the type erased version. + Int1Func f1 = ftyped.packed()(Int2Func(add), 1); + CHECK_EQ(f1(3), 4); +} + // new namespoace namespace test { // register int vector as extension type