diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/node.py index 01244519532b1d10957ca9d9342ea9a60d20bbd7..925aa93f8f968d45d6be235e50b959bc89bd59c2 100644 --- a/python/tvm/_ffi/_ctypes/node.py +++ b/python/tvm/_ffi/_ctypes/node.py @@ -24,7 +24,13 @@ def _return_node(x): handle = NodeHandle(handle) tindex = ctypes.c_int() check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex))) - return NODE_TYPE.get(tindex.value, NodeBase)(handle) + cls = NODE_TYPE.get(tindex.value, NodeBase) + # Avoid calling __init__ of cls, instead directly call __new__ + # This allows child class to implement their own __init__ + node = cls.__new__(cls) + node.handle = handle + return node + RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( @@ -34,16 +40,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( class NodeBase(object): __slots__ = ["handle"] # pylint: disable=no-member - def __init__(self, handle): - """Initialize the function with handle - - Parameters - ---------- - handle : SymbolHandle - the handle to the underlying C++ Symbol - """ - self.handle = handle - def __del__(self): if _LIB is not None: check_call(_LIB.TVMNodeFree(self.handle)) diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 00173c431bb7b539f03547c2624d2ab0e4e560cf..ac5532835c477634bdc22b866d512252328c3156 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -106,8 +106,8 @@ cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/c_dsl_api.h": int TVMNodeFree(NodeHandle handle) - TVMNodeTypeKey2Index(const char* type_key, - int* out_index) + int TVMNodeTypeKey2Index(const char* type_key, + int* out_index) int TVMNodeGetTypeIndex(NodeHandle handle, int* out_index) int TVMNodeGetAttr(NodeHandle handle, diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/node.pxi index a563af5237f90e94bb82cec2c076a31dfca52b41..1ced48878803d7e09fbd3fb6b068b1417a56facf 100644 --- a/python/tvm/_ffi/_cython/node.pxi +++ b/python/tvm/_ffi/_cython/node.pxi @@ -1,3 +1,4 @@ +from ... import _api_internal from ..base import string_types from ..node_generic import _set_class_node_base @@ -10,6 +11,7 @@ def _register_node(int index, object cls): NODE_TYPE.append(None) NODE_TYPE[index] = cls + cdef inline object make_ret_node(void* chandle): global NODE_TYPE cdef int tindex @@ -20,14 +22,15 @@ cdef inline object make_ret_node(void* chandle): if tindex < len(node_type): cls = node_type[tindex] if cls is not None: - obj = cls(None) + obj = cls.__new__(cls) else: - obj = NodeBase(None) + obj = NodeBase.__new__(NodeBase) else: - obj = NodeBase(None) + obj = NodeBase.__new__(NodeBase) (<NodeBase>obj).chandle = chandle return obj + cdef class NodeBase: cdef void* chandle @@ -49,9 +52,6 @@ cdef class NodeBase: def __set__(self, value): self._set_handle(value) - def __init__(self, handle): - self._set_handle(handle) - def __dealloc__(self): CALL(TVMNodeFree(self.chandle)) diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index d9e7397ae71fcc588381b4c7830fb273df4a8261..98ece19f77f2048e582e5ff13f473784112c0582 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -21,6 +21,12 @@ except IMPORT_EXCEPT: # pylint: disable=wrong-import-position from ._ctypes.node import _register_node, NodeBase as _NodeBase + +def _new_object(cls): + """Helper function for pickle""" + return cls.__new__(cls) + + class NodeBase(_NodeBase): """NodeBase is the base class of all TVM language AST object.""" def __repr__(self): @@ -46,7 +52,8 @@ class NodeBase(_NodeBase): return not self.__eq__(other) def __reduce__(self): - return (type(self), (None,), self.__getstate__()) + cls = type(self) + return (_new_object, (cls, ), self.__getstate__()) def __getstate__(self): handle = self.handle diff --git a/python/tvm/target.py b/python/tvm/target.py index 40f9e099b3a6c0e5b21b9f540136ca80de827403..07200058a021a963264ca7324c0e6b8d3be86e42 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -79,11 +79,13 @@ class Target(NodeBase): - :any:`tvm.target.mali` create Mali target - :any:`tvm.target.intel_graphics` create Intel Graphics target """ - def __init__(self, handle): - super(Target, self).__init__(handle) - self._keys = None - self._options = None - self._libs = None + def __new__(cls): + # Always override new to enable class + obj = NodeBase.__new__(cls) + obj._keys = None + obj._options = None + obj._libs = None + return obj @property def keys(self):