From 01a7ce0cb6490c1233bc33c33d679e742b178241 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Mon, 23 Jan 2017 21:41:04 -0800 Subject: [PATCH] [RUNTIME] Add Function, Unify TVMTypeCode and TVMArgTypeID (#24) --- include/tvm/c_api.h | 60 +++++----- include/tvm/codegen.h | 5 + include/tvm/ir.h | 8 ++ include/tvm/{ => runtime}/c_runtime_api.h | 129 ++++++++++------------ include/tvm/runtime/runtime.h | 127 +++++++++++++++++++++ python/tvm/__init__.py | 2 +- python/tvm/_api_internal.py | 1 + python/tvm/_ctypes/_api.py | 127 +++++++++++---------- python/tvm/_ctypes/_runtime_api.py | 101 +++++++++-------- python/tvm/_ctypes/_types.py | 72 ++++++++++++ python/tvm/_function_internal.py | 1 - python/tvm/{function.py => api.py} | 26 ++--- python/tvm/collections.py | 14 +-- python/tvm/ndarray.py | 9 +- python/tvm/schedule.py | 20 ++-- python/tvm/tensor.py | 8 +- src/README.md | 3 +- src/c_api/c_api.cc | 76 +++++++------ src/c_api/c_api_codegen.cc | 24 +++- src/c_api/c_api_function.cc | 6 +- src/c_api/c_api_lang.cc | 36 +++--- src/c_api/c_api_registry.h | 115 +++++++++++-------- src/codegen/codegen_c.h | 1 - src/runtime/c_runtime_api.cc | 23 +++- src/runtime/device_api.h | 2 +- src/runtime/runtime_base.h | 2 +- tests/cpp/packed_func_test.cc | 26 +++++ tests/python/test_codegen_cuda.py | 2 +- tests/python/test_runtime_function.py | 17 +++ 29 files changed, 684 insertions(+), 359 deletions(-) rename include/tvm/{ => runtime}/c_runtime_api.h (72%) create mode 100644 include/tvm/runtime/runtime.h create mode 100644 python/tvm/_api_internal.py create mode 100644 python/tvm/_ctypes/_types.py delete mode 100644 python/tvm/_function_internal.py rename python/tvm/{function.py => api.py} (90%) create mode 100644 tests/cpp/packed_func_test.cc create mode 100644 tests/python/test_runtime_function.py diff --git a/include/tvm/c_api.h b/include/tvm/c_api.h index 2e7bb8545..30fd3ea13 100644 --- a/include/tvm/c_api.h +++ b/include/tvm/c_api.h @@ -6,11 +6,11 @@ #ifndef TVM_C_API_H_ #define TVM_C_API_H_ -#include "./c_runtime_api.h" +#include "./runtime/c_runtime_api.h" TVM_EXTERN_C { /*! \brief handle to functions */ -typedef void* APIFunctionHandle; +typedef void* APIFuncHandle; /*! \brief handle to node */ typedef void* NodeHandle; @@ -18,16 +18,18 @@ typedef void* NodeHandle; * \brief List all the node function name * \param out_size The number of functions * \param out_array The array of function names. + * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMListAPIFunctionNames(int *out_size, - const char*** out_array); +TVM_DLL int TVMListAPIFuncNames(int *out_size, + const char*** out_array); /*! * \brief get function handle by name * \param name The name of function * \param handle The returning function handle + * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMGetAPIFunctionHandle(const char* name, - APIFunctionHandle *handle); +TVM_DLL int TVMGetAPIFuncHandle(const char* name, + APIFuncHandle *handle); /*! * \brief Get the detailed information about function. @@ -42,24 +44,26 @@ TVM_DLL int TVMGetAPIFunctionHandle(const char* name, * \param return_type Return type of the function, if any. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMGetAPIFunctionInfo(APIFunctionHandle handle, - const char **real_name, - const char **description, - int *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); +TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle, + const char **real_name, + const char **description, + int *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type); /*! * \brief Push an argument to the function calling stack. * If push fails, the stack will be reset to empty * - * \param arg number of attributes - * \param type_id The typeid of attributes. + * \param arg The argument + * \param type_code The type_code of argument as in TVMTypeCode + * \return 0 when success, -1 when failure happens + * \note API calls always exchanges with type bits=64, lanes=1 */ -TVM_DLL int TVMAPIPushStack(TVMArg arg, - int type_id); +TVM_DLL int TVMAPIPushStack(TVMValue arg, + int type_code); /*! * \brief call a function by using arguments in the stack. @@ -67,15 +71,18 @@ TVM_DLL int TVMAPIPushStack(TVMArg arg, * * \param handle The function handle * \param ret_val The return value. - * \param ret_typeid the type id of return value. + * \param ret_type_code the type code of return value. + * \return 0 when success, -1 when failure happens + * \note API calls always exchanges with type bits=64, lanes=1 */ -TVM_DLL int TVMAPIFunctionCall(APIFunctionHandle handle, - TVMArg* ret_val, - int* ret_typeid); +TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle, + TVMValue* ret_val, + int* ret_type_code); /*! * \brief free the node handle * \param handle The node handle to be freed. + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMNodeFree(NodeHandle handle); @@ -84,13 +91,15 @@ TVM_DLL int TVMNodeFree(NodeHandle handle); * \param handle The node handle * \param key The attribute name * \param out_value The attribute value - * \param out_typeid The typeid of the attribute. + * \param out_type_code The type code of the attribute. * \param out_success Whether get is successful. + * \return 0 when success, -1 when failure happens + * \note API calls always exchanges with type bits=64, lanes=1 */ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, const char* key, - TVMArg* out_value, - int* out_typeid, + TVMValue* out_value, + int* out_type_code, int* out_success); /*! @@ -98,6 +107,7 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, * \param handle The node handle * \param out_size The number of functions * \param out_array The array of function names. + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, int *out_size, diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index b4a15e5a5..e58fab8cc 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -10,6 +10,8 @@ #include "./base.h" #include "./expr.h" #include "./module.h" +#include "./runtime/runtime.h" + namespace tvm { /*! \brief namespace for lowlevel IR pass and codegen */ @@ -62,6 +64,9 @@ Array<Var> UndefinedVars(const LoweredFunc& f); */ Array<LoweredFunc> SplitHostDevice(LoweredFunc func); + +runtime::PackedFunc BuildStackVM(LoweredFunc func); + } // namespace codegen } // namespace tvm diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 067610421..24be3f788 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -78,6 +78,14 @@ constexpr const char* tvm_array_get_field = "tvm_array_get_field"; * } */ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; +/*! + * \brief See pesudo code + * + * bool tvm_print(VType value) { + * LOG(INFO) << value; + * } + */ +constexpr const char* tvm_print = "tvm_print"; /*! \brief The field id of each field in array */ enum TVMArrayFieldKind { diff --git a/include/tvm/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h similarity index 72% rename from include/tvm/c_runtime_api.h rename to include/tvm/runtime/c_runtime_api.h index 25b81d80c..34db25539 100644 --- a/include/tvm/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -9,8 +9,8 @@ * So this is a minimum runtime code gluing, and some limited * memory management code to enable quick testing. */ -#ifndef TVM_C_RUNTIME_API_H_ -#define TVM_C_RUNTIME_API_H_ +#ifndef TVM_RUNTIME_C_RUNTIME_API_H_ +#define TVM_RUNTIME_C_RUNTIME_API_H_ #ifdef __cplusplus #define TVM_EXTERN_C extern "C" @@ -38,27 +38,51 @@ TVM_EXTERN_C { typedef uint32_t tvm_index_t; /*! - * \brief union type for arguments and return values - * in both runtime API and TVM API calls + * \brief Union type of values + * being passed through API and function calls. */ typedef union { - long v_long; // NOLINT(*) - double v_double; - const char* v_str; + int64_t v_int64; + double v_float64; void* v_handle; -} TVMArg; + const char* v_str; +} TVMValue; /*! - * \brief The type index in TVM. + * \brief The type code in TVMType + * \note TVMType is used in two places. */ typedef enum { - kNull = 0, - kLong = 1, - kDouble = 2, - kStr = 3, - kNodeHandle = 4, - kArrayHandle = 5 -} TVMArgTypeID; + kInt = 0U, + kUInt = 1U, + kFloat = 2U, + kHandle = 3U, + // The next few fields are extension types + // that is used by TVM API calls. + kNull = 4U, + kNodeHandle = 5U, + kStr = 6U, + kFuncHandle = 7U +} TVMTypeCode; + +/*! + * \brief The data type used in TVM Runtime. + * + * Examples + * - float: type_code = 2, bits = 32, lanes=1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 + * - int8: type_code = 0, bits = 8, lanes=1 + * + * \note Arguments TVM API function always takes bits=64 and lanes=1 + */ +typedef struct { + /*! \brief type code, in TVMTypeCode */ + uint8_t type_code; + /*! \brief number of bits of the type */ + uint8_t bits; + /*! \brief number of lanes, */ + uint16_t lanes; +} TVMType; /*! * \brief The device type @@ -82,29 +106,6 @@ typedef struct { int dev_id; } TVMContext; -/*! \brief The type code in TVMDataType */ -typedef enum { - kInt = 0U, - kUInt = 1U, - kFloat = 2U -} TVMTypeCode; - -/*! - * \brief the data type - * Examples - * - float: type_code = 2, bits = 32, lanes=1 - * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 - * - int8: type_code = 0, bits = 8, lanes=1 - */ -typedef struct { - /*! \brief type code, in TVMTypeCode */ - uint8_t type_code; - /*! \brief number of bits of the type */ - uint8_t bits; - /*! \brief number of lanes, */ - uint16_t lanes; -} TVMDataType; - /*! * \brief Data structure representing a n-dimensional array(tensor). * This is used to pass data specification into TVM. @@ -122,7 +123,7 @@ typedef struct { /*! \brief number of dimensions of the array */ tvm_index_t ndim; /*! \brief The data type flag */ - TVMDataType dtype; + TVMType dtype; /*! \brief The device context this array sits on */ TVMContext ctx; } TVMArray; @@ -191,7 +192,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx, */ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, tvm_index_t ndim, - TVMDataType dtype, + TVMType dtype, TVMContext ctx, TVMArrayHandle* out); /*! @@ -217,45 +218,27 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); /*! - * \brief TVM Function API: Get resource requirement - * - * By default TVM function try not to do internal allocations. - * Instead, TVMFuncRequirement can be called, given the input arguments. - * - * \param func function handle to be launched. - * \param args The arguments - * \param arg_type_ids The type id of the arguments - * \param num_args Number of arguments. - * \param out_workspace_size The workspace size needed to launch this function. - * \param out_workspace_align The alignment requirement of workspace. - * - * \note The data pointer in the arrays is not used by requirement. + * \brief Free the function when it is no longer needed. + * \param func The function handle + * \return whether */ -TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func, - TVMArg* args, - int* arg_type_ids, - int num_args, - size_t* out_workspace_size, - size_t* out_workspace_align); +TVM_DLL int TVMFuncFree(TVMFunctionHandle func); /*! - * \brief TVM Function API: Launch generated function. + * \brief Call a function whose parameters are all packed. * - * \param func function handle to be launched. + * \param func node handle of the function. * \param args The arguments - * \param arg_type_ids The type id of the arguments + * \param type_codes The type codes of the arguments * \param num_args Number of arguments. - * \param stream The stream this function to be launched on. - * \param workspace Additional workspace used to launch this function. * - * \sa TVMFuncRequirement + * \return 0 when success, -1 when failure happens + * \note TVM calls always exchanges with type bits=64, lanes=1 */ -TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func, - TVMArg* args, - int* arg_type_ids, - int num_args, - TVMStreamHandle stream, - TVMArrayHandle workspace); +TVM_DLL int TVMFuncCall(TVMFunctionHandle func, + TVMValue* args, + int* type_codes, + int num_args); } // TVM_EXTERN_C -#endif // TVM_C_RUNTIME_API_H_ +#endif // TVM_RUNTIME_C_RUNTIME_API_H_ diff --git a/include/tvm/runtime/runtime.h b/include/tvm/runtime/runtime.h new file mode 100644 index 000000000..ef53d6c5f --- /dev/null +++ b/include/tvm/runtime/runtime.h @@ -0,0 +1,127 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file runtime.h + * \brief Runtime related c++ class. + */ +#ifndef TVM_RUNTIME_RUNTIME_H_ +#define TVM_RUNTIME_RUNTIME_H_ + +#include <functional> +#include <tuple> +#include "./c_runtime_api.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief Packed function is a runtime function + * whose argument type_codes are erased by packed format. + * + * This is an useful unified interface to call generated functions. + */ +class PackedFunc { + public: + /*! \brief The internal std::function */ + using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>; + PackedFunc() {} + explicit PackedFunc(FType body) : body_(body) {} + /*! + * \brief invoke the packed function by directly passing in arguments. + * \param args Arguments to be passed. + * \tparam Args arguments to be passed. + * \return The first return value. + */ + template<typename... Args> + inline void operator()(Args&& ...args) const; + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param type_codes The type_codes of the arguments + * \param num_args Number of arguments. + */ + inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const; + /*! \return the internal body function */ + inline FType body() const { + return body_; + } + + private: + /*! \brief internal container of packed function */ + FType body_; +}; + +// implementations +inline void PackedFunc::CallPacked( + const TVMValue* args, const int* type_codes, int num_args) const { + body_(args, type_codes, num_args); +} + +template<bool stop, std::size_t I, typename F, typename ...Args> +struct for_each_dispatcher_ { + static inline void run(const std::tuple<Args...>& args, F f) { + f(I, std::get<I>(args)); + for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f); + } +}; + +template<std::size_t I, typename F, typename ...Args> +struct for_each_dispatcher_<true, I, F, Args...> { + static inline void run(const std::tuple<Args...>& args, F f) {} +}; + +template<typename F, typename ...Args> +inline void for_each(const std::tuple<Args...>& args, F f) { + for_each_dispatcher_<sizeof...(Args) == 0, 0, F, Args...>::run(args, f); +} + +namespace arg_setter { +template<typename T> +inline void Set(TVMValue& arg, int& t, T v); // NOLINT(*) +template<> +inline void Set<double>(TVMValue& arg, int& t, double value) { // NOLINT(*) + arg.v_float64 = value; + t = kFloat; +} +template<> +inline void Set<int>(TVMValue& arg, int& t, int value) { // NOLINT(*) + arg.v_int64 = value; + t = kInt; +} +template<> +inline void Set<long>(TVMValue& arg, int& t, long value) { // NOLINT(*) + arg.v_int64 = value; + t = kInt; +} +template<> +inline void Set<TVMArray*>(TVMValue& arg, int& t, TVMArray* value) { // NOLINT(*) + arg.v_handle = value; + t = kHandle; +} +template<> +inline void Set<void*>(TVMValue& arg, int& t, void* value) { // NOLINT(*) + arg.v_handle = value; + t = kHandle; +} +} // namespace arg_setter + +struct PackedFuncArgSetter { + TVMValue* args; + int* type_codes; + template<typename T> + inline void operator()(size_t i, T v) const { + arg_setter::Set(args[i], type_codes[i], v); + } +}; + +template<typename... Args> +inline void PackedFunc::operator()(Args&& ...args) const { + auto targ = std::make_tuple(std::forward<Args>(args)...); + const int kNumArgs = sizeof...(Args); + TVMValue tvm_args[kNumArgs]; + int tvm_arg_type_ids[kNumArgs]; + for_each(targ, PackedFuncArgSetter{tvm_args, tvm_arg_type_ids}); + body_(tvm_args, tvm_arg_type_ids, kNumArgs); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RUNTIME_H_ diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 91b5abb6c..6729a8bd2 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -16,4 +16,4 @@ from . import ndarray as nd from .ndarray import cpu, gpu, opencl, init_opencl from ._base import TVMError -from .function import * +from .api import * diff --git a/python/tvm/_api_internal.py b/python/tvm/_api_internal.py new file mode 100644 index 000000000..c0301ceea --- /dev/null +++ b/python/tvm/_api_internal.py @@ -0,0 +1 @@ +"""namespace of internal API""" diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index 4b9d9d493..19d638d1a 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -11,24 +11,26 @@ from numbers import Number, Integral from .._base import _LIB from .._base import c_str, py_str, string_types from .._base import check_call, ctypes2docstring -from .. import _function_internal - -class TVMArg(ctypes.Union): - """TVMArg in C API""" - _fields_ = [("v_long", ctypes.c_long), - ("v_double", ctypes.c_double), - ("v_str", ctypes.c_char_p), - ("v_handle", ctypes.c_void_p)] +from .. import _api_internal +from . import _runtime_api +from ._types import TVMValue, TypeCode # type definitions -APIFunctionHandle = ctypes.c_void_p +APIFuncHandle = ctypes.c_void_p NodeHandle = ctypes.c_void_p +FunctionHandle = ctypes.c_void_p + +class APIType(object): + """TVMType used in API calls""" + INT = ctypes.c_int(TypeCode.INT) + UINT = ctypes.c_int(TypeCode.UINT) + FLOAT = ctypes.c_int(TypeCode.FLOAT) + HANDLE = ctypes.c_int(TypeCode.HANDLE) + NULL = ctypes.c_int(TypeCode.NULL) + NODE_HANDLE = ctypes.c_int(TypeCode.NODE_HANDLE) + STR = ctypes.c_int(TypeCode.STR) + FUNC_HANDLE = ctypes.c_int(TypeCode.FUNC_HANDLE) -kNull = 0 -kLong = 1 -kDouble = 2 -kStr = 3 -kNodeHandle = 4 NODE_TYPE = { } @@ -37,22 +39,31 @@ def _return_node(x): handle = x.v_handle if not isinstance(handle, NodeHandle): handle = NodeHandle(handle) - ret_val = TVMArg() - ret_typeid = ctypes.c_int() + ret_val = TVMValue() + ret_type_code = ctypes.c_int() ret_success = ctypes.c_int() check_call(_LIB.TVMNodeGetAttr( handle, c_str("type_key"), ctypes.byref(ret_val), - ctypes.byref(ret_typeid), + ctypes.byref(ret_type_code), ctypes.byref(ret_success))) return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle) + +def _return_func(x): + handle = x.v_handle + if not isinstance(handle, FunctionHandle): + handle = FunctionHandle(handle) + return _runtime_api._function_cls(handle) + + RET_SWITCH = { - kNull: lambda x: None, - kLong: lambda x: x.v_long, - kDouble: lambda x: x.v_double, - kStr: lambda x: py_str(x.v_str), - kNodeHandle: _return_node + TypeCode.NULL: lambda x: None, + TypeCode.INT: lambda x: x.v_int64, + TypeCode.FLOAT: lambda x: x.v_float64, + TypeCode.STR: lambda x: py_str(x.v_str), + TypeCode.NODE_HANDLE: _return_node, + TypeCode.FUNC_HANDLE: _return_func } class SliceBase(object): @@ -74,28 +85,28 @@ class NodeBase(object): self.handle = handle def __repr__(self): - return _function_internal._format_str(self) + return _api_internal._format_str(self) def __del__(self): check_call(_LIB.TVMNodeFree(self.handle)) def __getattr__(self, name): - ret_val = TVMArg() - ret_typeid = ctypes.c_int() + ret_val = TVMValue() + ret_type_code = ctypes.c_int() ret_success = ctypes.c_int() check_call(_LIB.TVMNodeGetAttr( self.handle, c_str(name), ctypes.byref(ret_val), - ctypes.byref(ret_typeid), + ctypes.byref(ret_type_code), ctypes.byref(ret_success))) - value = RET_SWITCH[ret_typeid.value](ret_val) + value = RET_SWITCH[ret_type_code.value](ret_val) if not ret_success.value: raise AttributeError( "'%s' object has no attribute '%s'" % (str(type(self)), name)) return value def __hash__(self): - return _function_internal._raw_ptr(self) + return _api_internal._raw_ptr(self) def __eq__(self, other): if not isinstance(other, NodeBase): @@ -121,7 +132,7 @@ class NodeBase(object): def __getstate__(self): handle = self.handle if handle is not None: - return {'handle': _function_internal._save_json(self)} + return {'handle': _api_internal._save_json(self)} else: return {'handle': None} @@ -131,7 +142,7 @@ class NodeBase(object): if handle is not None: json_str = handle _push_arg(json_str) - other = _function_internal._load_json(json_str) + other = _api_internal._load_json(json_str) self.handle = other.handle other.handle = None else: @@ -145,7 +156,7 @@ def const(value, dtype=None): dtype = 'int32' else: dtype = 'float32' - return _function_internal._const(value, dtype) + return _api_internal._const(value, dtype) def convert(value): @@ -154,7 +165,7 @@ def convert(value): return const(value) elif isinstance(value, (list, tuple)): value = [convert(x) for x in value] - return _function_internal._Array(*value) + return _api_internal._Array(*value) elif isinstance(value, dict): vlist = [] for it in value.items(): @@ -162,7 +173,7 @@ def convert(value): raise ValueError("key of map must already been a container type") vlist.append(it[0]) vlist.append(convert(it[1])) - return _function_internal._Map(*vlist) + return _api_internal._Map(*vlist) elif isinstance(value, SliceBase): return value.tensor(*value.indices) else: @@ -172,21 +183,21 @@ def convert(value): def _push_arg(arg): - a = TVMArg() + a = TVMValue() if arg is None: - _LIB.TVMAPIPushStack(a, ctypes.c_int(kNull)) + _LIB.TVMAPIPushStack(a, APIType.NULL) elif isinstance(arg, NodeBase): a.v_handle = arg.handle - _LIB.TVMAPIPushStack(a, ctypes.c_int(kNodeHandle)) - elif isinstance(arg, int): - a.v_long = ctypes.c_long(arg) - _LIB.TVMAPIPushStack(a, ctypes.c_int(kLong)) + _LIB.TVMAPIPushStack(a, APIType.NODE_HANDLE) + elif isinstance(arg, Integral): + a.v_int64 = ctypes.c_int64(arg) + _LIB.TVMAPIPushStack(a, APIType.INT) elif isinstance(arg, Number): a.v_double = ctypes.c_double(arg) - _LIB.TVMAPIPushStack(a, ctypes.c_int(kDouble)) + _LIB.TVMAPIPushStack(a, APIType.FLOAT) elif isinstance(arg, string_types): a.v_str = c_str(arg) - _LIB.TVMAPIPushStack(a, ctypes.c_int(kStr)) + _LIB.TVMAPIPushStack(a, APIType.STR) else: raise TypeError("Don't know how to handle type %s" % type(arg)) @@ -201,7 +212,7 @@ def _make_function(handle, name): arg_descs = ctypes.POINTER(ctypes.c_char_p)() ret_type = ctypes.c_char_p() - check_call(_LIB.TVMGetAPIFunctionInfo( + check_call(_LIB.TVMGetAPIFuncInfo( handle, ctypes.byref(real_name), ctypes.byref(desc), ctypes.byref(num_args), ctypes.byref(arg_names), @@ -214,13 +225,7 @@ def _make_function(handle, name): desc = py_str(desc.value) doc_str = ('%s\n\n' + - '%s\n' + - 'name : string, optional.\n' + - ' Name of the resulting symbol.\n\n' + - 'Returns\n' + - '-------\n' + - 'symbol: Symbol\n' + - ' The result symbol.') + '%s\n') doc_str = doc_str % (desc, param_str) arg_names = [py_str(arg_names[i]) for i in range(num_args.value)] @@ -235,11 +240,11 @@ def _make_function(handle, name): for arg in cargs: _push_arg(arg) - ret_val = TVMArg() - ret_typeid = ctypes.c_int() - check_call(_LIB.TVMAPIFunctionCall( - handle, ctypes.byref(ret_val), ctypes.byref(ret_typeid))) - return RET_SWITCH[ret_typeid.value](ret_val) + ret_val = TVMValue() + ret_type_code = ctypes.c_int() + check_call(_LIB.TVMAPIFuncCall( + handle, ctypes.byref(ret_val), ctypes.byref(ret_type_code))) + return RET_SWITCH[ret_type_code.value](ret_val) func.__name__ = func_name func.__doc__ = doc_str @@ -265,19 +270,19 @@ def register_node(type_key=None): NODE_TYPE[cls.__name__] = cls return cls -def _init_function_module(root_namespace): +def _init_api_module(root_namespace): """List and add all the functions to current module.""" plist = ctypes.POINTER(ctypes.c_char_p)() size = ctypes.c_uint() - check_call(_LIB.TVMListAPIFunctionNames(ctypes.byref(size), - ctypes.byref(plist))) + check_call(_LIB.TVMListAPIFuncNames(ctypes.byref(size), + ctypes.byref(plist))) op_names = [] for i in range(size.value): op_names.append(py_str(plist[i])) - module_obj = sys.modules["%s.function" % root_namespace] - module_internal = sys.modules["%s._function_internal" % root_namespace] + module_obj = sys.modules["%s.api" % root_namespace] + module_internal = sys.modules["%s._api_internal" % root_namespace] namespace_match = { "_make_": sys.modules["%s.make" % root_namespace], "_pass_": sys.modules["%s.ir_pass" % root_namespace], @@ -286,8 +291,8 @@ def _init_function_module(root_namespace): } for name in op_names: - hdl = APIFunctionHandle() - check_call(_LIB.TVMGetAPIFunctionHandle(c_str(name), ctypes.byref(hdl))) + hdl = APIFuncHandle() + check_call(_LIB.TVMGetAPIFuncHandle(c_str(name), ctypes.byref(hdl))) fname = name target_module = module_internal if name.startswith('_') else module_obj for k, v in namespace_match.items(): diff --git a/python/tvm/_ctypes/_runtime_api.py b/python/tvm/_ctypes/_runtime_api.py index 5b3b81904..dc4a64ee6 100644 --- a/python/tvm/_ctypes/_runtime_api.py +++ b/python/tvm/_ctypes/_runtime_api.py @@ -4,16 +4,16 @@ from __future__ import absolute_import as _abs import ctypes +from numbers import Number, Integral import numpy as np from .._base import _LIB -from .._base import c_array, c_str +from .._base import c_array, c_str, string_types from .._base import check_call - +from ._types import TVMValue, TypeCode, TVMType tvm_index_t = ctypes.c_uint32 - class TVMContext(ctypes.Structure): """TVM context strucure.""" _fields_ = [("dev_mask", ctypes.c_int), @@ -72,52 +72,13 @@ def opencl(dev_id=0): return TVMContext(4, dev_id) -class TVMDataType(ctypes.Structure): - """TVM datatype structure""" - _fields_ = [("type_code", ctypes.c_uint8), - ("bits", ctypes.c_uint8), - ("lanes", ctypes.c_uint16)] - CODE2STR = { - 0 : 'int', - 1 : 'uint', - 2 : 'float' - } - def __init__(self, type_str, lanes=1): - super(TVMDataType, self).__init__() - if isinstance(type_str, np.dtype): - type_str = str(type_str) - - if type_str.startswith("int"): - self.type_code = 0 - bits = int(type_str[3:]) - elif type_str.startswith("uint"): - self.type_code = 1 - bits = int(type_str[4:]) - elif type_str.startswith("float"): - self.type_code = 2 - bits = int(type_str[5:]) - else: - raise ValueError("Donot know how to handle type %s" % type_str) - bits = 32 if bits == 0 else bits - if (bits & (bits - 1)) != 0 or bits < 8: - raise ValueError("Donot know how to handle type %s" % type_str) - self.bits = bits - self.lanes = lanes - - def __repr__(self): - x = "%s%d" % (TVMDataType.CODE2STR[self.type_code], self.bits) - if self.lanes != 1: - x += "x%d" % self.lanes - return x - - class TVMArray(ctypes.Structure): - """TVMArg in C API""" + """TVMValue in C API""" _fields_ = [("data", ctypes.c_void_p), ("shape", ctypes.POINTER(tvm_index_t)), ("strides", ctypes.POINTER(tvm_index_t)), ("ndim", tvm_index_t), - ("dtype", TVMDataType), + ("dtype", TVMType), ("ctx", TVMContext)] TVMArrayHandle = ctypes.POINTER(TVMArray) @@ -133,7 +94,7 @@ def numpyasarray(np_data): arr.data = data.ctypes.data_as(ctypes.c_void_p) arr.shape = shape arr.strides = None - arr.dtype = TVMDataType(np.dtype(data.dtype).name) + arr.dtype = TVMType(np.dtype(data.dtype).name) arr.ndim = data.ndim # CPU device arr.ctx = cpu(0) @@ -141,6 +102,7 @@ def numpyasarray(np_data): _ndarray_cls = None +_function_cls = None def empty(shape, dtype="float32", ctx=cpu(0)): @@ -165,7 +127,7 @@ def empty(shape, dtype="float32", ctx=cpu(0)): shape = c_array(tvm_index_t, shape) ndim = tvm_index_t(len(shape)) handle = TVMArrayHandle() - dtype = TVMDataType(dtype) + dtype = TVMType(dtype) check_call(_LIB.TVMArrayAlloc( shape, ndim, dtype, ctx, ctypes.byref(handle))) return _ndarray_cls(handle) @@ -313,6 +275,51 @@ class NDArrayBase(object): return target -def _init_runtime_module(ndarray_class): +class FunctionBase(object): + """A function object at runtim.""" + __slots__ = ["handle"] + # pylint: disable=no-member + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : FunctionHandle + the handle to the underlying function. + """ + self.handle = handle + + def __del__(self): + check_call(_LIB.TVMFuncFree(self.handle)) + + def __call__(self, *args): + num_args = len(args) + tvm_args = (TVMValue * num_args)() + tvm_type_code = (ctypes.c_int * num_args)() + for i, arg in enumerate(args): + if arg is None: + tvm_args[i].v_handle = None + tvm_type_code[i] = TypeCode.NULL + elif isinstance(arg, NDArrayBase): + tvm_args[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) + tvm_type_code[i] = TypeCode.HANDLE + elif isinstance(arg, Integral): + tvm_args[i].v_int64 = arg + tvm_type_code[i] = TypeCode.INT + elif isinstance(arg, Number): + tvm_args[i].v_float64 = arg + tvm_type_code[i] = TypeCode.FLOAT + elif isinstance(arg, string_types): + tvm_args[i].v_str = c_str(arg) + tvm_type_code[i] = TypeCode.STR + else: + raise TypeError("Don't know how to handle type %s" % type(arg)) + check_call(_LIB.TVMFuncCall( + self.handle, tvm_args, tvm_type_code, ctypes.c_int(num_args))) + + +def _init_runtime_module(ndarray_class, function_class): global _ndarray_cls + global _function_cls _ndarray_cls = ndarray_class + _function_cls = function_class diff --git a/python/tvm/_ctypes/_types.py b/python/tvm/_ctypes/_types.py new file mode 100644 index 000000000..29d78981a --- /dev/null +++ b/python/tvm/_ctypes/_types.py @@ -0,0 +1,72 @@ +"""The C Types used in API.""" +# pylint: disable=invalid-name +from __future__ import absolute_import as _abs + +import ctypes +import numpy as np + +class TVMValue(ctypes.Union): + """TVMValue in C API""" + _fields_ = [("v_int64", ctypes.c_int64), + ("v_float64", ctypes.c_double), + ("v_handle", ctypes.c_void_p), + ("v_str", ctypes.c_char_p)] + +class TypeCode(object): + """Type code used in API calls""" + INT = 0 + UINT = 1 + FLOAT = 2 + HANDLE = 3 + NULL = 4 + NODE_HANDLE = 5 + STR = 6 + FUNC_HANDLE = 7 + +def _api_type(code): + """create a type accepted by API""" + t = TVMType() + t.bits = 64 + t.lanes = 1 + t.type_code = code + return t + + +class TVMType(ctypes.Structure): + """TVM datatype structure""" + _fields_ = [("type_code", ctypes.c_uint8), + ("bits", ctypes.c_uint8), + ("lanes", ctypes.c_uint16)] + CODE2STR = { + 0 : 'int', + 1 : 'uint', + 2 : 'float' + } + def __init__(self, type_str, lanes=1): + super(TVMType, self).__init__() + if isinstance(type_str, np.dtype): + type_str = str(type_str) + + if type_str.startswith("int"): + self.type_code = 0 + bits = int(type_str[3:]) + elif type_str.startswith("uint"): + self.type_code = 1 + bits = int(type_str[4:]) + elif type_str.startswith("float"): + self.type_code = 2 + bits = int(type_str[5:]) + else: + raise ValueError("Donot know how to handle type %s" % type_str) + + bits = 32 if bits == 0 else bits + if (bits & (bits - 1)) != 0 or bits < 8: + raise ValueError("Donot know how to handle type %s" % type_str) + self.bits = bits + self.lanes = lanes + + def __repr__(self): + x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) + if self.lanes != 1: + x += "x%d" % self.lanes + return x diff --git a/python/tvm/_function_internal.py b/python/tvm/_function_internal.py deleted file mode 100644 index ff00c73eb..000000000 --- a/python/tvm/_function_internal.py +++ /dev/null @@ -1 +0,0 @@ -"""namespace of internal function""" diff --git a/python/tvm/function.py b/python/tvm/api.py similarity index 90% rename from python/tvm/function.py rename to python/tvm/api.py index 72929da80..e537c9321 100644 --- a/python/tvm/function.py +++ b/python/tvm/api.py @@ -3,8 +3,8 @@ """Functions defined in TVM.""" from __future__ import absolute_import as _abs from numbers import Integral as _Integral -from ._ctypes._api import _init_function_module, convert -from . import _function_internal +from ._ctypes._api import _init_api_module, convert +from . import _api_internal from . import make as _make from . import expr as _expr from . import collections as _collections @@ -20,7 +20,7 @@ def const(value, dtype=None): dtype = 'int32' else: dtype = 'float32' - return _function_internal._const(value, dtype) + return _api_internal._const(value, dtype) def load_json(json_str): @@ -36,7 +36,7 @@ def load_json(json_str): node : Node The loaded tvm node. """ - return _function_internal._load_json(json_str) + return _api_internal._load_json(json_str) def save_json(node): @@ -52,7 +52,7 @@ def save_json(node): json_str : str Saved json string. """ - return _function_internal._save_json(node) + return _api_internal._save_json(node) def Var(name="tindex", dtype=int32): @@ -66,7 +66,7 @@ def Var(name="tindex", dtype=int32): dtype : int The data type """ - return _function_internal._Var(name, dtype) + return _api_internal._Var(name, dtype) def placeholder(shape, dtype=None, name="placeholder"): @@ -90,7 +90,7 @@ def placeholder(shape, dtype=None, name="placeholder"): """ shape = (shape,) if isinstance(shape, _expr.Expr) else shape dtype = float32 if dtype is None else dtype - return _function_internal._Placeholder( + return _api_internal._Placeholder( shape, dtype, name) @@ -128,9 +128,9 @@ def compute(shape, fcompute, name="compute"): dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)] body = fcompute(*[v.var for v in dim_var]) body = convert(body) - op_node = _function_internal._ComputeOp( + op_node = _api_internal._ComputeOp( name, dim_var, body) - return _function_internal._Tensor( + return _api_internal._Tensor( shape, body.dtype, op_node, 0) @@ -168,7 +168,7 @@ def Buffer(shape, dtype=None, if ptr is None: ptr = Var(name, "handle") - return _function_internal._Buffer( + return _api_internal._Buffer( name, ptr, shape, strides, dtype) @@ -202,7 +202,7 @@ def IterVar(dom=None, name=None, thread_tag=''): if name is None: name = thread_tag if thread_tag else name name = name if name else 'iter' - return _function_internal._IterVar(dom, name, thread_tag) + return _api_internal._IterVar(dom, name, thread_tag) def sum(expr, rdom): @@ -263,7 +263,7 @@ def Schedule(ops): """ if not isinstance(ops, (list, _collections.Array)): ops = [ops] - return _function_internal._Schedule(ops) + return _api_internal._Schedule(ops) -_init_function_module("tvm") +_init_api_module("tvm") diff --git a/python/tvm/collections.py b/python/tvm/collections.py index 2e43e2e6b..810b726cd 100644 --- a/python/tvm/collections.py +++ b/python/tvm/collections.py @@ -2,7 +2,7 @@ """Collection structure in the high level DSL.""" from __future__ import absolute_import as _abs from ._ctypes._api import NodeBase, register_node -from . import _function_internal +from . import _api_internal from . import expr as _expr @register_node @@ -11,10 +11,10 @@ class Array(NodeBase): def __getitem__(self, i): if i >= len(self): raise IndexError("array index out ot range") - return _function_internal._ArrayGetItem(self, i) + return _api_internal._ArrayGetItem(self, i) def __len__(self): - return _function_internal._ArraySize(self) + return _api_internal._ArraySize(self) def __repr__(self): return '[' + (','.join(str(x) for x in self)) + ']' @@ -23,18 +23,18 @@ class Array(NodeBase): class Map(NodeBase): """Map container of TVM""" def __getitem__(self, k): - return _function_internal._MapGetItem(self, k) + return _api_internal._MapGetItem(self, k) def __contains__(self, k): - return _function_internal._MapCount(self, k) != 0 + return _api_internal._MapCount(self, k) != 0 def items(self): """Get the items from the map""" - akvs = _function_internal._MapItems(self) + akvs = _api_internal._MapItems(self) return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)] def __len__(self): - return _function_internal._MapSize(self) + return _api_internal._MapSize(self) def __repr__(self): return '{' + (", ".join(str(x[0]) + ": " +str(x[1]) for x in self.items())) + '}' diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index eafc065b1..c3059662a 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -6,7 +6,7 @@ This is a simplified runtime API for quick testing and proptyping. from __future__ import absolute_import as _abs import numpy as _np -from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase +from ._ctypes._runtime_api import TVMContext, TVMType, NDArrayBase, FunctionBase from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync from ._ctypes._runtime_api import _init_runtime_module from ._ctypes._runtime_api import init_opencl @@ -26,6 +26,11 @@ class NDArray(NDArrayBase): pass +class Function(FunctionBase): + """Function class that can executed a generated code.""" + pass + + def array(arr, ctx=cpu(0)): """Create an array from source arr. @@ -49,4 +54,4 @@ def array(arr, ctx=cpu(0)): return ret -_init_runtime_module(NDArray) +_init_runtime_module(NDArray, Function) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index b276c90a1..93767b4c2 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -2,7 +2,7 @@ """Collection structure in the high level DSL.""" from __future__ import absolute_import as _abs from ._ctypes._api import NodeBase, register_node -from . import _function_internal +from . import _api_internal from . import tensor as _tensor @register_node @@ -56,11 +56,11 @@ class Stage(NodeBase): if outer is not None: if outer.thread_tag == '': raise ValueError("split by outer must have special thread_tag") - inner = _function_internal._StageSplitByOuter(self, parent, outer, factor) + inner = _api_internal._StageSplitByOuter(self, parent, outer, factor) else: if factor is None: raise ValueError("either outer or factor need to be provided") - outer, inner = _function_internal._StageSplitByFactor(self, parent, factor) + outer, inner = _api_internal._StageSplitByFactor(self, parent, factor) return outer, inner def fuse(self, inner, outer): @@ -79,7 +79,7 @@ class Stage(NodeBase): inner : IterVar The fused variable of iteration. """ - return _function_internal._StageFuse(self, inner, outer) + return _api_internal._StageFuse(self, inner, outer) def set_scope(self, scope): """Set the thread scope of this stage @@ -89,7 +89,7 @@ class Stage(NodeBase): scope : str The thread scope of this stage """ - return _function_internal._StageSetScope(self, scope) + return _api_internal._StageSetScope(self, scope) def compute_at(self, parent, scope): """Attach the stage at parent's scope @@ -102,7 +102,7 @@ class Stage(NodeBase): scope : IterVar The loop scope t be attached to. """ - _function_internal._StageComputeAt(self, parent, scope) + _api_internal._StageComputeAt(self, parent, scope) def compute_inline(self): """Mark stage as inline @@ -112,7 +112,7 @@ class Stage(NodeBase): parent : Stage The parent stage """ - _function_internal._StageComputeInline(self) + _api_internal._StageComputeInline(self) def compute_root(self): """Attach the stage at parent, and mark it as root @@ -122,7 +122,7 @@ class Stage(NodeBase): parent : Stage The parent stage """ - _function_internal._StageComputeInline(self) + _api_internal._StageComputeInline(self) def reorder(self, *args): """reorder the arguments in the specified order. @@ -132,7 +132,7 @@ class Stage(NodeBase): args : list of IterVar The order to be ordered """ - _function_internal._StageReorder(self, args) + _api_internal._StageReorder(self, args) def tile(self, x_parent, y_parent, x_factor, y_factor): """ Perform tiling on two dimensions @@ -161,6 +161,6 @@ class Stage(NodeBase): p_y_inner : IterVar Inner axis of y dimension """ - x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile( + x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile( self, x_parent, y_parent, x_factor, y_factor) return x_outer, y_outer, x_inner, y_inner diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index fdaec1d33..51767ee36 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -3,7 +3,7 @@ from __future__ import absolute_import as _abs from ._ctypes._api import NodeBase, SliceBase, register_node, convert from . import collections as _collections -from . import _function_internal +from . import _api_internal from . import make as _make from . import expr as _expr @@ -44,12 +44,12 @@ class Tensor(NodeBase): return TensorSlice(self, indices) def __hash__(self): - return _function_internal._TensorHash(self) + return _api_internal._TensorHash(self) def __eq__(self, other): if not isinstance(other, Tensor): return False - return _function_internal._TensorEqual(self, other) + return _api_internal._TensorEqual(self, other) @property def ndim(self): @@ -72,7 +72,7 @@ class Operation(NodeBase): out : Tensor The i-th output. """ - return _function_internal._OpGetOutput(self, index) + return _api_internal._OpGetOutput(self, index) @register_node class ComputeOp(Operation): diff --git a/src/README.md b/src/README.md index 59cf081d8..652de2778 100644 --- a/src/README.md +++ b/src/README.md @@ -4,4 +4,5 @@ - lang The definition of DSL related data structure - schedule The operations on the schedule graph before converting to IR. - pass The optimization pass on the IR structure -- runtime The runtime related codes. \ No newline at end of file +- runtime Minimum runtime related codes. +- jit JIT runtime related code. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index e6ce2a5a9..7b3605942 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -22,7 +22,7 @@ struct TVMAPIThreadLocalEntry { arg_stack.clear(); ret_value.sptr.reset(); } - inline void SetReturn(ArgVariant* ret_val, int* ret_typeid); + inline void SetReturn(TVMValue* ret_val, int* ret_type_code); }; using namespace tvm; @@ -97,11 +97,11 @@ struct APIAttrDir : public AttrVisitor { } }; -int TVMListAPIFunctionNames(int *out_size, +int TVMListAPIFuncNames(int *out_size, const char*** out_array) { API_BEGIN(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - ret->ret_vec_str = dmlc::Registry<APIFunctionReg>::ListAllNames(); + ret->ret_vec_str = dmlc::Registry<APIFuncReg>::ListAllNames(); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); @@ -111,16 +111,16 @@ int TVMListAPIFunctionNames(int *out_size, API_END(); } -int TVMGetAPIFunctionHandle(const char* fname, - APIFunctionHandle* out) { +int TVMGetAPIFuncHandle(const char* fname, + APIFuncHandle* out) { API_BEGIN(); - const APIFunctionReg* reg = dmlc::Registry<APIFunctionReg>::Find(fname); + const APIFuncReg* reg = dmlc::Registry<APIFuncReg>::Find(fname); CHECK(reg != nullptr) << "cannot find function " << fname; - *out = (APIFunctionHandle)reg; + *out = (APIFuncHandle)reg; API_END(); } -int TVMGetAPIFunctionInfo(APIFunctionHandle handle, +int TVMGetAPIFuncInfo(APIFuncHandle handle, const char **real_name, const char **description, int *num_doc_args, @@ -128,7 +128,7 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle, const char ***arg_type_infos, const char ***arg_descriptions, const char **return_type) { - const auto *op = static_cast<const APIFunctionReg *>(handle); + const auto *op = static_cast<const APIFuncReg *>(handle); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); API_BEGIN(); @@ -152,33 +152,37 @@ int TVMGetAPIFunctionInfo(APIFunctionHandle handle, API_END(); } -int TVMAPIPushStack(ArgVariant arg, - int type_id) { +int TVMAPIPushStack(TVMValue arg, + int type_code) { TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); API_BEGIN(); ret->arg_stack.resize(ret->arg_stack.size() + 1); APIVariantValue& v = ret->arg_stack.back(); - v.type_id = static_cast<ArgVariantID>(type_id); - if (type_id == kStr) { - v.str = arg.v_str; - } else if (type_id == kNodeHandle) { - v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); - } else { - v.v_union = arg; + v.type_code = type_code; + switch (type_code) { + case kInt: case kUInt: case kFloat: case kNull: { + v.v_union = arg; break; + } + case kStr: { + v.str = arg.v_str; break; + } + case kNodeHandle: { + v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); break; + } + default: LOG(FATAL) << "TVM API cannot take type " << TVMTypeCode2Str(type_code); } - API_END_HANDLE_ERROR(ret->Clear()); } -int TVMAPIFunctionCall(APIFunctionHandle handle, - ArgVariant* ret_val, - int* ret_typeid) { +int TVMAPIFuncCall(APIFuncHandle handle, + TVMValue* ret_val, + int* ret_type_code) { TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); API_BEGIN(); - const auto *op = static_cast<const APIFunctionReg *>(handle); + const auto *op = static_cast<const APIFuncReg *>(handle); op->body(ret->arg_stack, &(ret->ret_value)); - ret->SetReturn(ret_val, ret_typeid); + ret->SetReturn(ret_val, ret_type_code); ret->arg_stack.clear(); API_END_HANDLE_ERROR(ret->Clear()); } @@ -191,28 +195,28 @@ int TVMNodeFree(NodeHandle handle) { int TVMNodeGetAttr(NodeHandle handle, const char* key, - ArgVariant* ret_val, - int* ret_typeid, + TVMValue* ret_val, + int* ret_type_code, int* ret_success) { TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); API_BEGIN(); - ret->ret_value.type_id = kNull; + ret->ret_value.type_code = kNull; APIAttrGetter getter; getter.skey = key; getter.ret = &(ret->ret_value); TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); if (getter.skey == "type_key") { ret_val->v_str = (*tnode)->type_key(); - *ret_typeid = kStr; + *ret_type_code = kStr; *ret_success = 1; } else { (*tnode)->VisitAttrs(&getter); - if (ret->ret_value.type_id != kNull) { - ret->SetReturn(ret_val, ret_typeid); + if (ret->ret_value.type_code != kNull) { + ret->SetReturn(ret_val, ret_type_code); *ret_success = 1; } else { *ret_success = getter.found_node_ref ? 1 : 0; - *ret_typeid = kNull; + *ret_type_code = kNull; } } API_END_HANDLE_ERROR(ret->Clear()); @@ -238,16 +242,18 @@ int TVMNodeListAttrNames(NodeHandle handle, } -inline void TVMAPIThreadLocalEntry::SetReturn(ArgVariant* ret_val, - int* ret_typeid) { +inline void TVMAPIThreadLocalEntry::SetReturn(TVMValue* ret_val, + int* ret_type_code) { APIVariantValue& rv = ret_value; - *ret_typeid = rv.type_id; - if (rv.type_id == kNodeHandle) { + *ret_type_code = rv.type_code; + if (rv.type_code == kNodeHandle) { if (rv.sptr.get() != nullptr) { ret_val->v_handle = new TVMAPINode(std::move(rv.sptr)); } else { ret_val->v_handle = nullptr; } + } else if (rv.type_code == kFuncHandle) { + ret_val->v_handle = new runtime::PackedFunc::FType(std::move(rv.func)); } else { *ret_val = rv.v_union; } diff --git a/src/c_api/c_api_codegen.cc b/src/c_api/c_api_codegen.cc index 0fa5973a4..e38a90777 100644 --- a/src/c_api/c_api_codegen.cc +++ b/src/c_api/c_api_codegen.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2016 by Contributors - * Implementation of API functions related to IR build - * \file c_api_ir.cc + * Implementation of API functions related to Codegen + * \file c_api_codegen.cc */ #include <tvm/expr.h> #include <tvm/ir.h> @@ -32,5 +32,25 @@ TVM_REGISTER_API(_codegen_SplitHostDevice) *ret = SplitHostDevice(args.at(0)); }); + +// generate a dummy packed function for testing +void DummyHelloFunction(const TVMValue* args, const int* type_code, int num_args) { + LOG(INFO) << num_args << " arguments"; + for (int i = 0; i < num_args; ++i) { + switch (type_code[i]) { + case kNull: LOG(INFO) << i << ":nullptr"; break; + case kFloat: LOG(INFO) << i << ": double=" << args[i].v_float64; break; + case kInt: LOG(INFO) << i << ": long=" << args[i].v_int64; break; + case kHandle: LOG(INFO) << i << ": handle=" << args[i].v_handle; break; + default: LOG(FATAL) << "unhandled type " << TVMTypeCode2Str(type_code[i]); + } + } +} + +TVM_REGISTER_API(_codegen_DummyHelloFunction) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = runtime::PackedFunc(DummyHelloFunction); + }); + } // namespace codegen } // namespace tvm diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index 05a0262cb..f550804ad 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -8,7 +8,7 @@ #include "./c_api_registry.h" namespace dmlc { -DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg); +DMLC_REGISTRY_ENABLE(::tvm::APIFuncReg); } // namespace dmlc namespace tvm { @@ -18,7 +18,7 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_format_str) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); std::ostringstream os; os << args.at(0).operator NodeRef(); *ret = os.str(); @@ -27,7 +27,7 @@ TVM_REGISTER_API(_format_str) TVM_REGISTER_API(_raw_ptr) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); *ret = reinterpret_cast<int64_t>(args.at(0).sptr.get()); }) .add_argument("src", "NodeBase", "the node base"); diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc index b04b535e8..4119d2fb3 100644 --- a/src/c_api/c_api_lang.cc +++ b/src/c_api/c_api_lang.cc @@ -16,9 +16,9 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_const) .set_body([](const ArgStack& args, RetValue *ret) { - if (args.at(0).type_id == kLong) { + if (args.at(0).type_code == kInt) { *ret = make_const(args.at(1), args.at(0).operator int64_t()); - } else if (args.at(0).type_id == kDouble) { + } else if (args.at(0).type_code == kFloat) { *ret = make_const(args.at(1), args.at(0).operator double()); } else { LOG(FATAL) << "only accept int or float"; @@ -31,19 +31,19 @@ TVM_REGISTER_API(_Array) .set_body([](const ArgStack& args, RetValue *ret) { std::vector<std::shared_ptr<Node> > data; for (size_t i = 0; i < args.size(); ++i) { - CHECK(args.at(i).type_id == kNodeHandle) + CHECK(args.at(i).type_code == kNodeHandle) << "need content of array to be NodeBase"; data.push_back(args.at(i).sptr); } auto node = std::make_shared<ArrayNode>(); node->data = std::move(data); - ret->type_id = kNodeHandle; + ret->type_code = kNodeHandle; ret->sptr = node; }); TVM_REGISTER_API(_ArrayGetItem) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); int64_t i = args.at(1); auto& sptr = args.at(0).sptr; CHECK(sptr->is_type<ArrayNode>()); @@ -51,12 +51,12 @@ TVM_REGISTER_API(_ArrayGetItem) CHECK_LT(static_cast<size_t>(i), n->data.size()) << "out of bound of array"; ret->sptr = n->data[i]; - ret->type_id = kNodeHandle; + ret->type_code = kNodeHandle; }); TVM_REGISTER_API(_ArraySize) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); auto& sptr = args.at(0).sptr; CHECK(sptr->is_type<ArrayNode>()); *ret = static_cast<int64_t>( @@ -68,21 +68,21 @@ TVM_REGISTER_API(_Map) CHECK_EQ(args.size() % 2, 0U); MapNode::ContainerType data; for (size_t i = 0; i < args.size(); i += 2) { - CHECK(args.at(i).type_id == kNodeHandle) + CHECK(args.at(i).type_code == kNodeHandle) << "need content of array to be NodeBase"; - CHECK(args.at(i + 1).type_id == kNodeHandle) + CHECK(args.at(i + 1).type_code == kNodeHandle) << "need content of array to be NodeBase"; data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr)); } auto node = std::make_shared<MapNode>(); node->data = std::move(data); - ret->type_id = kNodeHandle; + ret->type_code = kNodeHandle; ret->sptr = node; }); TVM_REGISTER_API(_MapSize) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); auto& sptr = args.at(0).sptr; CHECK(sptr->is_type<MapNode>()); auto* n = static_cast<const MapNode*>(sptr.get()); @@ -91,8 +91,8 @@ TVM_REGISTER_API(_MapSize) TVM_REGISTER_API(_MapGetItem) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - CHECK(args.at(1).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); + CHECK(args.at(1).type_code == kNodeHandle); auto& sptr = args.at(0).sptr; CHECK(sptr->is_type<MapNode>()); auto* n = static_cast<const MapNode*>(sptr.get()); @@ -100,13 +100,13 @@ TVM_REGISTER_API(_MapGetItem) CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; ret->sptr = (*it).second; - ret->type_id = kNodeHandle; + ret->type_code = kNodeHandle; }); TVM_REGISTER_API(_MapCount) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - CHECK(args.at(1).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); + CHECK(args.at(1).type_code == kNodeHandle); auto& sptr = args.at(0).sptr; CHECK(sptr->is_type<MapNode>()); auto* n = static_cast<const MapNode*>(sptr.get()); @@ -115,7 +115,7 @@ TVM_REGISTER_API(_MapCount) TVM_REGISTER_API(_MapItems) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(0).type_code == kNodeHandle); auto& sptr = args.at(0).sptr; CHECK(sptr->is_type<MapNode>()); auto* n = static_cast<const MapNode*>(sptr.get()); @@ -125,7 +125,7 @@ TVM_REGISTER_API(_MapItems) rkvs->data.push_back(kv.second); } ret->sptr = rkvs; - ret->type_id = kNodeHandle; + ret->type_code = kNodeHandle; }); TVM_REGISTER_API(Range) diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 7223ebaee..648c0c84e 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -9,25 +9,25 @@ #include <tvm/base.h> #include <tvm/expr.h> #include <tvm/c_api.h> +#include <tvm/runtime/runtime.h> #include <memory> #include <limits> #include <string> #include <vector> #include "../base/common.h" -using ArgVariant = TVMArg; -using ArgVariantID = TVMArgTypeID; - namespace tvm { -inline const char* TypeId2Str(ArgVariantID type_id) { - switch (type_id) { - case kNull: return "Null"; - case kLong: return "Long"; - case kDouble: return "Double"; - case kStr: return "Str"; +inline const char* TVMTypeCode2Str(int type_code) { + switch (type_code) { + case kInt: return "int"; + case kFloat: return "float"; + case kStr: return "str"; + case kHandle: return "Handle"; + case kNull: return "NULL"; case kNodeHandle: return "NodeHandle"; - default: LOG(FATAL) << "unknown type_id=" << type_id; return ""; + default: LOG(FATAL) << "unknown type_code=" + << static_cast<int>(type_code); return ""; } } @@ -96,72 +96,83 @@ inline std::string NodeTypeName() { class APIVariantValue { public: /*! \brief the type id */ - ArgVariantID type_id{kNull}; + int type_code{kNull}; /*! \brief shared pointer container */ std::shared_ptr<Node> sptr; /*! \brief string container */ std::string str; /*! \brief the variant holder */ - ArgVariant v_union; + TVMValue v_union; + /*! \brief std::function */ + runtime::PackedFunc::FType func; // constructor - APIVariantValue() {} + APIVariantValue() { + } // clear value inline void Clear() { } // assign op inline APIVariantValue& operator=(double value) { - type_id = kDouble; - v_union.v_double = value; + type_code = kFloat; + v_union.v_float64 = value; return *this; } inline APIVariantValue& operator=(std::nullptr_t value) { - type_id = kNull; + type_code = kHandle; + v_union.v_handle = value; return *this; } inline APIVariantValue& operator=(int64_t value) { - type_id = kLong; - v_union.v_long = value; + type_code = kInt; + v_union.v_int64 = value; return *this; } inline APIVariantValue& operator=(bool value) { - type_id = kLong; - v_union.v_long = value; + type_code = kInt; + v_union.v_int64 = value; return *this; } inline APIVariantValue& operator=(std::string value) { - type_id = kStr; + type_code = kStr; str = std::move(value); v_union.v_str = str.c_str(); return *this; } inline APIVariantValue& operator=(const NodeRef& ref) { if (ref.node_.get() == nullptr) { - type_id = kNull; + type_code = kNull; } else { - type_id = kNodeHandle; + type_code = kNodeHandle; this->sptr = ref.node_; } return *this; } + inline APIVariantValue& operator=(const runtime::PackedFunc& f) { + type_code = kFuncHandle; + this->func = f.body(); + return *this; + } inline APIVariantValue& operator=(const Type& value) { return operator=(Type2String(value)); } template<typename T, typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type> inline operator T() const { - if (type_id == kNull) return T(); - CHECK_EQ(type_id, kNodeHandle); + if (type_code == kNull) return T(); + CHECK_EQ(type_code, kNodeHandle); CHECK(NodeTypeChecker<T>::Check(sptr.get())) << "Did not get expected type " << NodeTypeName<T>(); return T(sptr); } inline operator Expr() const { - if (type_id == kNull) return Expr(); - if (type_id == kLong) return Expr(operator int()); - if (type_id == kDouble) { + if (type_code == kNull) { + return Expr(); + } + if (type_code == kInt) return Expr(operator int()); + if (type_code == kFloat) { return Expr(static_cast<float>(operator double())); } - CHECK_EQ(type_id, kNodeHandle); + CHECK_EQ(type_code, kNodeHandle); if (sptr->is_type<IterVarNode>()) { return IterVar(sptr)->var; } else { @@ -171,52 +182,58 @@ class APIVariantValue { } } inline operator double() const { - CHECK_EQ(type_id, kDouble); - return v_union.v_double; + CHECK_EQ(type_code, kFloat); + return v_union.v_float64; } inline operator int64_t() const { - CHECK_EQ(type_id, kLong); - return v_union.v_long; + CHECK_EQ(type_code, kInt); + return v_union.v_int64; } inline operator uint64_t() const { - CHECK_EQ(type_id, kLong); - return v_union.v_long; + CHECK_EQ(type_code, kInt); + return v_union.v_int64; } inline operator int() const { - CHECK_EQ(type_id, kLong); - CHECK_LE(v_union.v_long, + CHECK_EQ(type_code, kInt); + CHECK_LE(v_union.v_int64, std::numeric_limits<int>::max()); - return v_union.v_long; + return v_union.v_int64; } inline operator bool() const { - CHECK_EQ(type_id, kLong) - << "expect boolean(int) but get " << TypeId2Str(type_id); - return v_union.v_long != 0; + CHECK_EQ(type_code, kInt) + << "expect boolean(int) but get " + << TVMTypeCode2Str(type_code); + return v_union.v_int64 != 0; } inline operator std::string() const { - CHECK_EQ(type_id, kStr) - << "expect Str but get " << TypeId2Str(type_id); + CHECK_EQ(type_code, kStr) + << "expect Str but get " + << TVMTypeCode2Str(type_code); return str; } inline operator Type() const { return String2Type(operator std::string()); } + inline operator runtime::PackedFunc() const { + CHECK_EQ(type_code, kFuncHandle); + return runtime::PackedFunc(func); + } }; // common defintiion of API function. -using APIFunction = std::function< +using APIFunc = std::function< void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>; /*! * \brief Registry entry for DataIterator factory functions. */ -struct APIFunctionReg - : public dmlc::FunctionRegEntryBase<APIFunctionReg, - APIFunction> { +struct APIFuncReg + : public dmlc::FunctionRegEntryBase<APIFuncReg, + APIFunc> { }; -#define TVM_REGISTER_API(TypeName) \ - DMLC_REGISTRY_REGISTER(::tvm::APIFunctionReg, APIFunctionReg, TypeName) \ +#define TVM_REGISTER_API(TypeName) \ + DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \ } // namespace tvm diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 4630e9990..d9392d232 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -7,7 +7,6 @@ #define TVM_CODEGEN_CODEGEN_C_H_ #include <tvm/ir.h> -#include <tvm/ir_visitor.h> #include <tvm/module.h> #include <string> #include <unordered_map> diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index e68790b58..40f4a16e1 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -3,7 +3,8 @@ * \file c_runtime_api.cc * \brief Device specific implementations */ -#include <tvm/c_runtime_api.h> +#include <tvm/runtime/c_runtime_api.h> +#include <tvm/runtime/runtime.h> #include <algorithm> #include "./runtime_base.h" #include "./device_api.h" @@ -34,7 +35,7 @@ inline void TVMArrayFree_(TVMArray* arr) { delete arr; } -inline void VerifyType(TVMDataType dtype) { +inline void VerifyType(TVMType dtype) { CHECK_GE(dtype.lanes, 1U); if (dtype.type_code == kFloat) { CHECK_EQ(dtype.bits % 32U, 0U); @@ -98,7 +99,7 @@ int TVMContextEnabled(TVMContext ctx, int TVMArrayAlloc(const tvm_index_t* shape, tvm_index_t ndim, - TVMDataType dtype, + TVMType dtype, TVMContext ctx, TVMArrayHandle* out) { TVMArray* arr = nullptr; @@ -166,3 +167,19 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { }); API_END(); } + +int TVMFuncFree(TVMFunctionHandle func) { + API_BEGIN(); + delete static_cast<PackedFunc::FType*>(func); + API_END(); +} + +int TVMFuncCall(TVMFunctionHandle func, + TVMValue* args, + int* arg_type_codes, + int num_args) { + API_BEGIN(); + (*static_cast<const PackedFunc::FType*>(func))( + args, arg_type_codes, num_args); + API_END(); +} diff --git a/src/runtime/device_api.h b/src/runtime/device_api.h index c2b163624..3ef1a7c0e 100644 --- a/src/runtime/device_api.h +++ b/src/runtime/device_api.h @@ -7,7 +7,7 @@ #define TVM_RUNTIME_DEVICE_API_H_ #include <tvm/base.h> -#include <tvm/c_runtime_api.h> +#include <tvm/runtime/c_runtime_api.h> namespace tvm { namespace runtime { diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h index 1a3233342..31fb8ede2 100644 --- a/src/runtime/runtime_base.h +++ b/src/runtime/runtime_base.h @@ -6,7 +6,7 @@ #ifndef TVM_RUNTIME_RUNTIME_BASE_H_ #define TVM_RUNTIME_RUNTIME_BASE_H_ -#include <tvm/c_runtime_api.h> +#include <tvm/runtime/c_runtime_api.h> #include <stdexcept> /*! \brief macro to guard beginning and end section of all functions */ diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc new file mode 100644 index 000000000..0bc31ef56 --- /dev/null +++ b/tests/cpp/packed_func_test.cc @@ -0,0 +1,26 @@ +#include <dmlc/logging.h> +#include <gtest/gtest.h> +#include <tvm/runtime/runtime.h> + +TEST(PackedFunc, Basic) { + using namespace tvm::runtime; + int x = 0; + void* handle = &x; + TVMArray a; + + PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) { + CHECK(num_args == 3); + CHECK(args[0].v_float64 == 1.0); + CHECK(type_codes[0] == kFloat); + CHECK(args[1].v_handle == &a); + CHECK(type_codes[1] == kHandle); + CHECK(args[2].v_handle == &x); + CHECK(type_codes[2] == kHandle); + })(1.0, &a, handle); +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/test_codegen_cuda.py b/tests/python/test_codegen_cuda.py index dc20dda36..1c23122ca 100644 --- a/tests/python/test_codegen_cuda.py +++ b/tests/python/test_codegen_cuda.py @@ -15,11 +15,11 @@ def mock_test_add(): thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") _, x = s[C].split(C.op.axis[0], factor=num_thread, outer=grid_x) _, x = s[C].split(x, outer=thread_x) + # compile to IR bounds = tvm.schedule.InferBound(s) stmt = tvm.ir_pass.ScheduleOps(s, bounds) - Ab = tvm.Buffer(A.shape, A.dtype, name='A') Bb = tvm.Buffer(B.shape, B.dtype, name='B') Cb = tvm.Buffer(C.shape, C.dtype, name='C') diff --git a/tests/python/test_runtime_function.py b/tests/python/test_runtime_function.py new file mode 100644 index 000000000..3c16c7f31 --- /dev/null +++ b/tests/python/test_runtime_function.py @@ -0,0 +1,17 @@ +import tvm +import numpy as np + + + +def test_function(): + ctx = tvm.cpu(0) + x = np.random.randint(0, 10, size=(3, 4)) + x = np.array(x) + y = tvm.nd.array(x, ctx=ctx) + + f = tvm.codegen.DummyHelloFunction() + f(y, 10) + + +if __name__ == "__main__": + test_function() -- GitLab