Unverified Commit 56ab0adb by Tianqi Chen Committed by GitHub

[RUNTIME][PYTHON] Switch to use __new__ for constructing node. (#1644)

parent b95b5958
...@@ -24,7 +24,13 @@ def _return_node(x): ...@@ -24,7 +24,13 @@ def _return_node(x):
handle = NodeHandle(handle) handle = NodeHandle(handle)
tindex = ctypes.c_int() tindex = ctypes.c_int()
check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex))) check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex)))
return NODE_TYPE.get(tindex.value, NodeBase)(handle) cls = NODE_TYPE.get(tindex.value, NodeBase)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
node = cls.__new__(cls)
node.handle = handle
return node
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
...@@ -34,16 +40,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( ...@@ -34,16 +40,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
class NodeBase(object): class NodeBase(object):
__slots__ = ["handle"] __slots__ = ["handle"]
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __del__(self): def __del__(self):
if _LIB is not None: if _LIB is not None:
check_call(_LIB.TVMNodeFree(self.handle)) check_call(_LIB.TVMNodeFree(self.handle))
......
...@@ -106,8 +106,8 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -106,8 +106,8 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
cdef extern from "tvm/c_dsl_api.h": cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle) int TVMNodeFree(NodeHandle handle)
TVMNodeTypeKey2Index(const char* type_key, int TVMNodeTypeKey2Index(const char* type_key,
int* out_index) int* out_index)
int TVMNodeGetTypeIndex(NodeHandle handle, int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index) int* out_index)
int TVMNodeGetAttr(NodeHandle handle, int TVMNodeGetAttr(NodeHandle handle,
......
from ... import _api_internal
from ..base import string_types from ..base import string_types
from ..node_generic import _set_class_node_base from ..node_generic import _set_class_node_base
...@@ -10,6 +11,7 @@ def _register_node(int index, object cls): ...@@ -10,6 +11,7 @@ def _register_node(int index, object cls):
NODE_TYPE.append(None) NODE_TYPE.append(None)
NODE_TYPE[index] = cls NODE_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle): cdef inline object make_ret_node(void* chandle):
global NODE_TYPE global NODE_TYPE
cdef int tindex cdef int tindex
...@@ -20,14 +22,15 @@ cdef inline object make_ret_node(void* chandle): ...@@ -20,14 +22,15 @@ cdef inline object make_ret_node(void* chandle):
if tindex < len(node_type): if tindex < len(node_type):
cls = node_type[tindex] cls = node_type[tindex]
if cls is not None: if cls is not None:
obj = cls(None) obj = cls.__new__(cls)
else: else:
obj = NodeBase(None) obj = NodeBase.__new__(NodeBase)
else: else:
obj = NodeBase(None) obj = NodeBase.__new__(NodeBase)
(<NodeBase>obj).chandle = chandle (<NodeBase>obj).chandle = chandle
return obj return obj
cdef class NodeBase: cdef class NodeBase:
cdef void* chandle cdef void* chandle
...@@ -49,9 +52,6 @@ cdef class NodeBase: ...@@ -49,9 +52,6 @@ cdef class NodeBase:
def __set__(self, value): def __set__(self, value):
self._set_handle(value) self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
def __dealloc__(self): def __dealloc__(self):
CALL(TVMNodeFree(self.chandle)) CALL(TVMNodeFree(self.chandle))
......
...@@ -21,6 +21,12 @@ except IMPORT_EXCEPT: ...@@ -21,6 +21,12 @@ except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.node import _register_node, NodeBase as _NodeBase from ._ctypes.node import _register_node, NodeBase as _NodeBase
def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)
class NodeBase(_NodeBase): class NodeBase(_NodeBase):
"""NodeBase is the base class of all TVM language AST object.""" """NodeBase is the base class of all TVM language AST object."""
def __repr__(self): def __repr__(self):
...@@ -46,7 +52,8 @@ class NodeBase(_NodeBase): ...@@ -46,7 +52,8 @@ class NodeBase(_NodeBase):
return not self.__eq__(other) return not self.__eq__(other)
def __reduce__(self): def __reduce__(self):
return (type(self), (None,), self.__getstate__()) cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
def __getstate__(self): def __getstate__(self):
handle = self.handle handle = self.handle
......
...@@ -79,11 +79,13 @@ class Target(NodeBase): ...@@ -79,11 +79,13 @@ class Target(NodeBase):
- :any:`tvm.target.mali` create Mali target - :any:`tvm.target.mali` create Mali target
- :any:`tvm.target.intel_graphics` create Intel Graphics target - :any:`tvm.target.intel_graphics` create Intel Graphics target
""" """
def __init__(self, handle): def __new__(cls):
super(Target, self).__init__(handle) # Always override new to enable class
self._keys = None obj = NodeBase.__new__(cls)
self._options = None obj._keys = None
self._libs = None obj._options = None
obj._libs = None
return obj
@property @property
def keys(self): def keys(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment