From a6b4a219e26f2ad0cff5a1a2629aa418aacca1e7 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Thu, 22 Feb 2018 18:26:32 -0800 Subject: [PATCH] [RUNTIME] Refactor extension type handling, now it is header only (#924) * [RUNTIME] Refactor extension type handling, now it is header only --- apps/extension/src/tvm_ext.cc | 20 +++- apps/extension/tests/test_ext.py | 6 ++ include/tvm/runtime/c_runtime_api.h | 18 ++++ include/tvm/runtime/module.h | 12 ++- include/tvm/runtime/packed_func.h | 122 ++++++++++++----------- python/tvm/_ffi/function.py | 25 +++++ python/tvm/api.py | 2 +- python/tvm/build_module.py | 32 +++--- src/runtime/c_runtime_api.cc | 8 ++ src/runtime/module.cc | 13 --- src/runtime/registry.cc | 29 ------ tests/cpp/packed_func_test.cc | 3 - tests/scripts/task_python_integration.sh | 1 + topi/src/topi.cc | 2 - 14 files changed, 166 insertions(+), 127 deletions(-) diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index 6d7f4bdf7..8b086863f 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -22,13 +22,10 @@ struct extension_class_info<tvm_ext::IntVector> { } // namespace tvm } // namespace runtime - -namespace tvm_ext { - using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_EXT_TYPE(IntVector); +namespace tvm_ext { TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -66,3 +63,18 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); }); } // namespace tvm_ext + +// This callback approach allows extension allows tvm to extract +// This way can be helpful when we want to use a header only +// minimum version of TVM Runtime. +extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) { + const PackedFunc& fregister = + *static_cast<PackedFunc*>(pregister); + auto mul = [](TVMArgs args, TVMRetValue *rv) { + int x = args[0]; + int y = args[1]; + *rv = x * y; + }; + fregister("mul", PackedFunc(mul)); + return 0; +} diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index 0bbfff14e..628602f0b 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -44,8 +44,14 @@ def test_ext_vec(): tvm.convert(ivec_cb)(ivec) +def test_extract_ext(): + fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare) + assert fdict["mul"](3, 4) == 12 + + if __name__ == "__main__": test_ext_dev() test_ext_vec() test_bind_add() test_sym_add() + test_extract_ext() diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index edade00c7..e4a06b39d 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -24,6 +24,13 @@ #define TVM_EXTERN_C #endif +// Macros to do weak linking +#ifdef _MSC_VER +#define TVM_WEAK __declspec(selectany) +#else +#define TVM_WEAK __attribute__((weak)) +#endif + #ifdef __EMSCRIPTEN__ #include <emscripten/emscripten.h> #define TVM_DLL EMSCRIPTEN_KEEPALIVE @@ -313,6 +320,17 @@ typedef int (*TVMPackedCFunc)( */ typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle); +/*! + * \brief Signature for extension function declarer. + * + * TVM call this function to get the extension functions + * The declarer will call register_func to register function and their name. + * + * \param resource_func_handle The register function + * \return 0 if success, -1 if failure happens + */ +typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); + /*! * \brief Wrap a TVMPackedCFunc to become a FunctionHandle. * diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 3d0991034..f8e5069f5 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -38,8 +38,14 @@ class Module { * \param query_imports Whether also query dependency modules. * \return The result function. * This function will return PackedFunc(nullptr) if function do not exist. + * \note Implemented in packed_func.cc */ - TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false); + inline PackedFunc GetFunction(const std::string& name, bool query_imports = false); + /*! \return internal container */ + inline ModuleNode* operator->(); + /*! \return internal container */ + inline const ModuleNode* operator->() const; + // The following functions requires link with runtime. /*! * \brief Import another module into this module. * \param other The module to be imported. @@ -57,10 +63,6 @@ class Module { */ TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = ""); - /*! \return internal container */ - inline ModuleNode* operator->(); - /*! \return internal container */ - inline const ModuleNode* operator->() const; private: std::shared_ptr<ModuleNode> node_; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index b01e662b9..ca2e020ba 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -183,31 +183,17 @@ struct extension_class_info { }; /*! - * \brief Runtime function table about extension type. + * \brief Capsule structure holding extension types + * Capsule is self-contained and include + * all the information to clone and destroy the type. */ -class ExtTypeVTable { - public: +struct TVMExtTypeCapsule { + /*! \brief The pointer to the object */ + void* ptr; /*! \brief function to be called to delete a handle */ void (*destroy)(void* handle); /*! \brief function to be called when clone a handle */ void* (*clone)(void* handle); - /*! - * \brief Register type - * \tparam T The type to be register. - * \return The registered vtable. - */ - template <typename T> - static inline ExtTypeVTable* Register_(); - /*! - * \brief Get a vtable based on type code. - * \param type_code The type code - * \return The registered vtable. - */ - TVM_DLL static ExtTypeVTable* Get(int type_code); - - private: - // Internal registration function. - TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt); }; /*! @@ -255,8 +241,9 @@ class TVMPODValue_ { } template<typename TExtension> const TExtension& AsExtension() const { - CHECK_LT(type_code_, kExtEnd); - return static_cast<TExtension*>(value_.v_handle)[0]; + CHECK_EQ(type_code_, extension_class_info<TExtension>::code); + return static_cast<TExtension*>( + static_cast<TVMExtTypeCapsule*>(value_.v_handle)->ptr)[0]; } int type_code() const { return type_code_; @@ -488,14 +475,6 @@ class TVMRetValue : public TVMPODValue_ { this->Assign(other); return *this; } - template<typename T, - typename = typename std::enable_if< - extension_class_info<T>::code != 0>::type> - TVMRetValue& operator=(const T& other) { - this->SwitchToClass<T>( - extension_class_info<T>::code, other); - return *this; - } /*! * \brief Move the value back to front-end via C API. * This marks the current container as null. @@ -521,6 +500,11 @@ class TVMRetValue : public TVMPODValue_ { type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; return value_; } + // assign extension + template<typename T, + typename = typename std::enable_if< + extension_class_info<T>::code != 0>::type> + inline TVMRetValue& operator=(const T& other); // NodeRef related extenstions: in tvm/packed_func_ext.h template<typename T, typename = typename std::enable_if< @@ -564,11 +548,9 @@ class TVMRetValue : public TVMPODValue_ { SwitchToPOD(other.type_code()); value_ = other.value_; } else { - this->Clear(); - type_code_ = other.type_code(); - value_.v_handle = - (*(ExtTypeVTable::Get(other.type_code())->clone))( - other.value().v_handle); + TVMExtTypeCapsule cap = *other.template ptr<TVMExtTypeCapsule>(); + cap.ptr = cap.clone(cap.ptr); + SwitchToClass<TVMExtTypeCapsule>(other.type_code(), cap); } break; } @@ -600,7 +582,9 @@ class TVMRetValue : public TVMPODValue_ { case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break; } if (type_code_ > kExtBegin) { - (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle); + TVMExtTypeCapsule *cap = ptr<TVMExtTypeCapsule>(); + cap->destroy(cap->ptr); + delete cap; } type_code_ = kNull; } @@ -716,8 +700,10 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*) /* \brief argument settter to PackedFunc */ class TVMArgsSetter { public: - TVMArgsSetter(TVMValue* values, int* type_codes) - : values_(values), type_codes_(type_codes) {} + TVMArgsSetter(TVMValue* values, + int* type_codes, + TVMExtTypeCapsule* exts) + : values_(values), type_codes_(type_codes), exts_(exts) {} // setters for POD types template<typename T, typename = typename std::enable_if< @@ -807,15 +793,21 @@ class TVMArgsSetter { TVMValue* values_; /*! \brief The type code fields */ int* type_codes_; + /*! \brief Temporary storage for extension types */ + TVMExtTypeCapsule* exts_; }; template<typename... Args> inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { const int kNumArgs = sizeof...(Args); + // Compiler will remove an static array when it is not touched. const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; - detail::for_each(TVMArgsSetter(values, type_codes), + // If the function call does not contain extension type, + // exts will get optimized away by compiler. + TVMExtTypeCapsule exts[kArraySize]; + detail::for_each(TVMArgsSetter(values, type_codes, exts), std::forward<Args>(args)...); TVMRetValue rv; body_(TVMArgs(values, type_codes, kNumArgs), &rv); @@ -853,14 +845,6 @@ inline TVMRetValue::operator T() const { ::Apply(this); } -template<typename T, typename> -inline void TVMArgsSetter::operator()(size_t i, const T& value) const { - static_assert(extension_class_info<T>::code != 0, - "Need to have extesion code"); - type_codes_[i] = extension_class_info<T>::code; - values_[i].v_handle = const_cast<T*>(&value); -} - // extension type handling template<typename T> struct ExtTypeInfo { @@ -872,16 +856,42 @@ struct ExtTypeInfo { } }; -template<typename T> -inline ExtTypeVTable* ExtTypeVTable::Register_() { - const int code = extension_class_info<T>::code; - static_assert(code != 0, - "require extension_class_info traits to be declared with non-zero code"); - ExtTypeVTable vt; - vt.clone = ExtTypeInfo<T>::clone; - vt.destroy = ExtTypeInfo<T>::destroy; - return ExtTypeVTable::RegisterInternal(code, vt); +template<typename T, typename> +inline TVMRetValue& TVMRetValue::operator=(const T& other) { + TVMExtTypeCapsule cap; + cap.clone = ExtTypeInfo<T>::clone; + cap.destroy = ExtTypeInfo<T>::destroy; + cap.ptr = new T(other); + SwitchToClass<TVMExtTypeCapsule>( + extension_class_info<T>::code, cap); + return *this; +} + +template<typename T, typename> +inline void TVMArgsSetter::operator()(size_t i, const T& value) const { + static_assert(extension_class_info<T>::code != 0, + "Need to have extesion code"); + type_codes_[i] = extension_class_info<T>::code; + exts_[i].clone = ExtTypeInfo<T>::clone; + exts_[i].destroy = ExtTypeInfo<T>::destroy; + exts_[i].ptr = const_cast<T*>(&value); + values_[i].v_handle = &exts_[i]; } + +// Implement Module::GetFunction +// Put implementation in this file so we have seen the PackedFunc +inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { + PackedFunc pf = node_->GetFunction(name, node_); + if (pf != nullptr) return pf; + if (query_imports) { + for (const Module& m : node_->imports_) { + pf = m.node_->GetFunction(name, m.node_); + if (pf != nullptr) return pf; + } + } + return pf; +} + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 2edb355fb..526d972f6 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -234,6 +234,31 @@ def list_global_func_names(): return fnames +def extract_ext_funcs(finit): + """ + Extract the extension PackedFuncs from a C module. + + Parameters + ---------- + finit : ctypes function + a ctypes that takes signature of TVMExtensionDeclarer + + Returns + ------- + fdict : dict of str to Function + The extracted functions + """ + fdict = {} + def _list(name, func): + fdict[name] = func + myf = convert_to_tvm_func(_list) + ret = finit(myf.handle) + _ = myf + if ret != 0: + raise RuntimeError("cannot initialize with %s" % finit) + return fdict + + def _get_api(f): flocal = f flocal.is_global = True diff --git a/python/tvm/api.py b/python/tvm/api.py index 7c90b0ec9..66c154bc9 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -8,7 +8,7 @@ from ._ffi.base import string_types from ._ffi.node import register_node, NodeBase from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.function import Function -from ._ffi.function import _init_api, register_func, get_global_func +from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from ._ffi.runtime_ctypes import TVMType from . import _api_internal diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 86d150c08..d868e2e0d 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -23,16 +23,16 @@ from . import target as _target from . import make class DumpIR(object): - """Dump IR for each pass. - With it, you can dump ir just like gcc/llvm. - - How to use: - ----------- - .. code-block:: python + """ + Dump IR for each pass. + With it, you can dump ir just like gcc/llvm. - with tvm.build_config(dump_pass_ir=True) - run() + How to use: + ----------- + .. code-block:: python + with tvm.build_config(dump_pass_ir=True) + run() """ scope_level = 0 def __init__(self): @@ -40,9 +40,9 @@ class DumpIR(object): self._recover_list = [] def decorate(self, func): - ''' decorate the pass function''' + """ decorate the pass function""" def dump(*args, **kwargs): - '''dump function''' + """dump function""" retv = func(*args, **kwargs) if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)): return retv @@ -59,7 +59,7 @@ class DumpIR(object): return dump def decorate_irpass(self): - '''decorate ir_pass and ScheduleOps''' + """decorate ir_pass and ScheduleOps""" self._old_sgpass = schedule.ScheduleOps schedule.ScheduleOps = self.decorate(schedule.ScheduleOps) vset = vars(ir_pass) @@ -71,7 +71,7 @@ class DumpIR(object): vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v def decorate_custompass(self): - ''' decorate add_lower_pass pass in BuildConfig''' + """ decorate add_lower_pass pass in BuildConfig""" cfg = BuildConfig.current self._old_custom_pass = cfg.add_lower_pass custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] @@ -79,7 +79,7 @@ class DumpIR(object): BuildConfig.current.add_lower_pass = pass_list def enter(self): - '''only decorate outermost nest''' + """only decorate outermost nest""" if DumpIR.scope_level > 0: return self.decorate_irpass() @@ -88,7 +88,7 @@ class DumpIR(object): DumpIR.scope_level += 1 def exit(self): - '''recover outermost nest''' + """recover outermost nest""" if DumpIR.scope_level > 1: return # recover decorated functions @@ -163,6 +163,7 @@ class BuildConfig(NodeBase): "'%s' object cannot set attribute '%s'" % (str(type(self)), name)) return super(BuildConfig, self).__setattr__(name, value) + def build_config(**kwargs): """Configure the build behavior by setting config variables. @@ -226,6 +227,7 @@ def build_config(**kwargs): setattr(config, k, kwargs[k]) return config + if not _RUNTIME_ONLY: # BuildConfig is not available in tvm_runtime BuildConfig.current = build_config() @@ -352,8 +354,10 @@ def lower(sch, stmt = f(stmt) if simple_mode: return stmt + return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) + def build(sch, args=None, target=None, diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 5e3b3e803..9a7005f2e 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -347,6 +347,14 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, API_END(); } +int TVMExtTypeFree(void* handle, int type_code) { + API_BEGIN(); + TVMExtTypeCapsule* cap = static_cast<TVMExtTypeCapsule*>(handle); + cap->destroy(cap->ptr); + delete cap; + API_END(); +} + int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, diff --git a/src/runtime/module.cc b/src/runtime/module.cc index c0796a61a..d5ece6560 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -13,19 +13,6 @@ namespace tvm { namespace runtime { -PackedFunc Module::GetFunction( - const std::string& name, bool query_imports) { - PackedFunc pf = node_->GetFunction(name, node_); - if (pf != nullptr) return pf; - if (query_imports) { - for (const Module& m : node_->imports_) { - pf = m.node_->GetFunction(name, m.node_); - if (pf != nullptr) return pf; - } - } - return pf; -} - void Module::Import(Module other) { // specially handle rpc if (!std::strcmp((*this)->type_key(), "rpc")) { diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index d7587b6ce..563731a60 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -22,15 +22,10 @@ struct Registry::Manager { // and the resource can become invalid because of indeterminstic order of destruction. // The resources will only be recycled during program exit. std::unordered_map<std::string, Registry*> fmap; - // vtable for extension type - std::array<ExtTypeVTable, kExtEnd> ext_vtable; // mutex std::mutex mutex; Manager() { - for (auto& x : ext_vtable) { - x.destroy = nullptr; - } } static Manager* Global() { @@ -88,24 +83,6 @@ std::vector<std::string> Registry::ListNames() { return keys; } -ExtTypeVTable* ExtTypeVTable::Get(int type_code) { - CHECK(type_code > kExtBegin && type_code < kExtEnd); - Registry::Manager* m = Registry::Manager::Global(); - ExtTypeVTable* vt = &(m->ext_vtable[type_code]); - CHECK(vt->destroy != nullptr) - << "Extension type not registered"; - return vt; -} - -ExtTypeVTable* ExtTypeVTable::RegisterInternal( - int type_code, const ExtTypeVTable& vt) { - CHECK(type_code > kExtBegin && type_code < kExtEnd); - Registry::Manager* m = Registry::Manager::Global(); - std::lock_guard<std::mutex>(m->mutex); - ExtTypeVTable* pvt = &(m->ext_vtable[type_code]); - pvt[0] = vt; - return pvt; -} } // namespace runtime } // namespace tvm @@ -120,12 +97,6 @@ struct TVMFuncThreadLocalEntry { /*! \brief Thread local store that can be used to hold return values. */ typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore; -int TVMExtTypeFree(void* handle, int type_code) { - API_BEGIN(); - tvm::runtime::ExtTypeVTable::Get(type_code)->destroy(handle); - API_END(); -} - int TVMFuncRegisterGlobal( const char* name, TVMFunctionHandle f, int override) { API_BEGIN(); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 00e428f25..8771a04e5 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -126,9 +126,6 @@ struct extension_class_info<test::IntVector> { } // runtime } // tvm -// do registration, this need to be in cc file -TVM_REGISTER_EXT_TYPE(test::IntVector); - TEST(PackedFunc, ExtensionType) { using namespace tvm; using namespace tvm::runtime; diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index d10c9a6b1..7cdade714 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -6,6 +6,7 @@ rm -rf python/tvm/*.pyc python/tvm/*/*.pyc # Test TVM make cython || exit -1 +make cython3 || exit -1 # Test extern package package cd apps/extension diff --git a/topi/src/topi.cc b/topi/src/topi.cc index d6b67c74b..1d73a8fc6 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -54,8 +54,6 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_EXT_TYPE(tvm::Target); - /*! \brief Canonicalize an argument that may be Array<Expr> or int to Array<Expr> */ Array<Expr> ArrayOrInt(TVMArgValue arg) { if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) { -- GitLab