Unverified Commit 56ab0adb by Tianqi Chen Committed by GitHub

[RUNTIME][PYTHON] Switch to use __new__ for constructing node. (#1644)

parent b95b5958
......@@ -24,7 +24,13 @@ def _return_node(x):
handle = NodeHandle(handle)
tindex = ctypes.c_int()
check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex)))
return NODE_TYPE.get(tindex.value, NodeBase)(handle)
cls = NODE_TYPE.get(tindex.value, NodeBase)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
node = cls.__new__(cls)
node.handle = handle
return node
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
......@@ -34,16 +40,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
class NodeBase(object):
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __del__(self):
if _LIB is not None:
check_call(_LIB.TVMNodeFree(self.handle))
......
......@@ -106,8 +106,8 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
TVMNodeTypeKey2Index(const char* type_key,
int* out_index)
int TVMNodeTypeKey2Index(const char* type_key,
int* out_index)
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index)
int TVMNodeGetAttr(NodeHandle handle,
......
from ... import _api_internal
from ..base import string_types
from ..node_generic import _set_class_node_base
......@@ -10,6 +11,7 @@ def _register_node(int index, object cls):
NODE_TYPE.append(None)
NODE_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle):
global NODE_TYPE
cdef int tindex
......@@ -20,14 +22,15 @@ cdef inline object make_ret_node(void* chandle):
if tindex < len(node_type):
cls = node_type[tindex]
if cls is not None:
obj = cls(None)
obj = cls.__new__(cls)
else:
obj = NodeBase(None)
obj = NodeBase.__new__(NodeBase)
else:
obj = NodeBase(None)
obj = NodeBase.__new__(NodeBase)
(<NodeBase>obj).chandle = chandle
return obj
cdef class NodeBase:
cdef void* chandle
......@@ -49,9 +52,6 @@ cdef class NodeBase:
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
def __dealloc__(self):
CALL(TVMNodeFree(self.chandle))
......
......@@ -21,6 +21,12 @@ except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.node import _register_node, NodeBase as _NodeBase
def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)
class NodeBase(_NodeBase):
"""NodeBase is the base class of all TVM language AST object."""
def __repr__(self):
......@@ -46,7 +52,8 @@ class NodeBase(_NodeBase):
return not self.__eq__(other)
def __reduce__(self):
return (type(self), (None,), self.__getstate__())
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
def __getstate__(self):
handle = self.handle
......
......@@ -79,11 +79,13 @@ class Target(NodeBase):
- :any:`tvm.target.mali` create Mali target
- :any:`tvm.target.intel_graphics` create Intel Graphics target
"""
def __init__(self, handle):
super(Target, self).__init__(handle)
self._keys = None
self._options = None
self._libs = None
def __new__(cls):
# Always override new to enable class
obj = NodeBase.__new__(cls)
obj._keys = None
obj._options = None
obj._libs = None
return obj
@property
def keys(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment