Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
56ab0adb
Unverified
Commit
56ab0adb
authored
Aug 23, 2018
by
Tianqi Chen
Committed by
GitHub
Aug 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME][PYTHON] Switch to use __new__ for constructing node. (#1644)
parent
b95b5958
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
30 additions
and
25 deletions
+30
-25
python/tvm/_ffi/_ctypes/node.py
+7
-11
python/tvm/_ffi/_cython/base.pxi
+2
-2
python/tvm/_ffi/_cython/node.pxi
+6
-6
python/tvm/_ffi/node.py
+8
-1
python/tvm/target.py
+7
-5
No files found.
python/tvm/_ffi/_ctypes/node.py
View file @
56ab0adb
...
@@ -24,7 +24,13 @@ def _return_node(x):
...
@@ -24,7 +24,13 @@ def _return_node(x):
handle
=
NodeHandle
(
handle
)
handle
=
NodeHandle
(
handle
)
tindex
=
ctypes
.
c_int
()
tindex
=
ctypes
.
c_int
()
check_call
(
_LIB
.
TVMNodeGetTypeIndex
(
handle
,
ctypes
.
byref
(
tindex
)))
check_call
(
_LIB
.
TVMNodeGetTypeIndex
(
handle
,
ctypes
.
byref
(
tindex
)))
return
NODE_TYPE
.
get
(
tindex
.
value
,
NodeBase
)(
handle
)
cls
=
NODE_TYPE
.
get
(
tindex
.
value
,
NodeBase
)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
node
=
cls
.
__new__
(
cls
)
node
.
handle
=
handle
return
node
RETURN_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_return_node
RETURN_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_return_node
C_TO_PY_ARG_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_wrap_arg_func
(
C_TO_PY_ARG_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_wrap_arg_func
(
...
@@ -34,16 +40,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
...
@@ -34,16 +40,6 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
class
NodeBase
(
object
):
class
NodeBase
(
object
):
__slots__
=
[
"handle"
]
__slots__
=
[
"handle"
]
# pylint: disable=no-member
# pylint: disable=no-member
def
__init__
(
self
,
handle
):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self
.
handle
=
handle
def
__del__
(
self
):
def
__del__
(
self
):
if
_LIB
is
not
None
:
if
_LIB
is
not
None
:
check_call
(
_LIB
.
TVMNodeFree
(
self
.
handle
))
check_call
(
_LIB
.
TVMNodeFree
(
self
.
handle
))
...
...
python/tvm/_ffi/_cython/base.pxi
View file @
56ab0adb
...
@@ -106,8 +106,8 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
...
@@ -106,8 +106,8 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
cdef extern from "tvm/c_dsl_api.h":
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
int TVMNodeFree(NodeHandle handle)
TVMNodeTypeKey2Index(const char* type_key,
int
TVMNodeTypeKey2Index(const char* type_key,
int* out_index)
int* out_index)
int TVMNodeGetTypeIndex(NodeHandle handle,
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index)
int* out_index)
int TVMNodeGetAttr(NodeHandle handle,
int TVMNodeGetAttr(NodeHandle handle,
...
...
python/tvm/_ffi/_cython/node.pxi
View file @
56ab0adb
from ... import _api_internal
from ..base import string_types
from ..base import string_types
from ..node_generic import _set_class_node_base
from ..node_generic import _set_class_node_base
...
@@ -10,6 +11,7 @@ def _register_node(int index, object cls):
...
@@ -10,6 +11,7 @@ def _register_node(int index, object cls):
NODE_TYPE.append(None)
NODE_TYPE.append(None)
NODE_TYPE[index] = cls
NODE_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle):
cdef inline object make_ret_node(void* chandle):
global NODE_TYPE
global NODE_TYPE
cdef int tindex
cdef int tindex
...
@@ -20,14 +22,15 @@ cdef inline object make_ret_node(void* chandle):
...
@@ -20,14 +22,15 @@ cdef inline object make_ret_node(void* chandle):
if tindex < len(node_type):
if tindex < len(node_type):
cls = node_type[tindex]
cls = node_type[tindex]
if cls is not None:
if cls is not None:
obj = cls
(None
)
obj = cls
.__new__(cls
)
else:
else:
obj = NodeBase
(Non
e)
obj = NodeBase
.__new__(NodeBas
e)
else:
else:
obj = NodeBase
(Non
e)
obj = NodeBase
.__new__(NodeBas
e)
(<NodeBase>obj).chandle = chandle
(<NodeBase>obj).chandle = chandle
return obj
return obj
cdef class NodeBase:
cdef class NodeBase:
cdef void* chandle
cdef void* chandle
...
@@ -49,9 +52,6 @@ cdef class NodeBase:
...
@@ -49,9 +52,6 @@ cdef class NodeBase:
def __set__(self, value):
def __set__(self, value):
self._set_handle(value)
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
def __dealloc__(self):
def __dealloc__(self):
CALL(TVMNodeFree(self.chandle))
CALL(TVMNodeFree(self.chandle))
...
...
python/tvm/_ffi/node.py
View file @
56ab0adb
...
@@ -21,6 +21,12 @@ except IMPORT_EXCEPT:
...
@@ -21,6 +21,12 @@ except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
# pylint: disable=wrong-import-position
from
._ctypes.node
import
_register_node
,
NodeBase
as
_NodeBase
from
._ctypes.node
import
_register_node
,
NodeBase
as
_NodeBase
def
_new_object
(
cls
):
"""Helper function for pickle"""
return
cls
.
__new__
(
cls
)
class
NodeBase
(
_NodeBase
):
class
NodeBase
(
_NodeBase
):
"""NodeBase is the base class of all TVM language AST object."""
"""NodeBase is the base class of all TVM language AST object."""
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -46,7 +52,8 @@ class NodeBase(_NodeBase):
...
@@ -46,7 +52,8 @@ class NodeBase(_NodeBase):
return
not
self
.
__eq__
(
other
)
return
not
self
.
__eq__
(
other
)
def
__reduce__
(
self
):
def
__reduce__
(
self
):
return
(
type
(
self
),
(
None
,),
self
.
__getstate__
())
cls
=
type
(
self
)
return
(
_new_object
,
(
cls
,
),
self
.
__getstate__
())
def
__getstate__
(
self
):
def
__getstate__
(
self
):
handle
=
self
.
handle
handle
=
self
.
handle
...
...
python/tvm/target.py
View file @
56ab0adb
...
@@ -79,11 +79,13 @@ class Target(NodeBase):
...
@@ -79,11 +79,13 @@ class Target(NodeBase):
- :any:`tvm.target.mali` create Mali target
- :any:`tvm.target.mali` create Mali target
- :any:`tvm.target.intel_graphics` create Intel Graphics target
- :any:`tvm.target.intel_graphics` create Intel Graphics target
"""
"""
def
__init__
(
self
,
handle
):
def
__new__
(
cls
):
super
(
Target
,
self
)
.
__init__
(
handle
)
# Always override new to enable class
self
.
_keys
=
None
obj
=
NodeBase
.
__new__
(
cls
)
self
.
_options
=
None
obj
.
_keys
=
None
self
.
_libs
=
None
obj
.
_options
=
None
obj
.
_libs
=
None
return
obj
@property
@property
def
keys
(
self
):
def
keys
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment