"""Node namespace: Python base class and registration helper for TVM nodes."""
# pylint: disable=unused-import
from __future__ import absolute_import

import ctypes
import sys

from .. import _api_internal
from .node_generic import NodeGeneric, convert_to_node, const
from .base import _LIB, check_call, c_str, py_str, _FFI_MODE

# Select the FFI backend that provides `_register_node` and `_NodeBase`.
# The Cython extension is tried first unless _FFI_MODE == "ctypes"; a failed
# Cython import (ImportError) then falls back to the ctypes backend.
# When _FFI_MODE == "cython" only RuntimeError is caught, so an ImportError
# from a missing Cython build propagates instead of silently falling back.
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
    # pylint: disable=wrong-import-position
    if _FFI_MODE == "ctypes":
        # Forced ctypes mode: jump straight to the fallback branch below.
        raise ImportError()
    if sys.version_info >= (3, 0):
        from ._cy3.core import _register_node, NodeBase as _NodeBase
    else:
        from ._cy2.core import _register_node, NodeBase as _NodeBase
except IMPORT_EXCEPT:
    # pylint: disable=wrong-import-position
    from ._ctypes.node import _register_node, NodeBase as _NodeBase


def _new_object(cls):
    """Helper function for pickle"""
    return cls.__new__(cls)


30
class NodeBase(_NodeBase):
    """NodeBase is the base class of all TVM language AST object.

    Equality is identity of the underlying node (see :meth:`same_as`), and
    hashing uses the same key, so equal nodes hash equally. Instances are
    picklable: the state is produced by ``_api_internal._save_json`` and
    restored with ``_api_internal._load_json``.
    """

    def __repr__(self):
        """Return the runtime-formatted string for this node."""
        return _api_internal._format_str(self)

    def __dir__(self):
        """Return the attribute names the runtime reports for this node."""
        plist = ctypes.POINTER(ctypes.c_char_p)()
        size = ctypes.c_uint()
        check_call(_LIB.TVMNodeListAttrNames(
            self.handle, ctypes.byref(size), ctypes.byref(plist)))
        # plist is filled as a C array of `size` char* entries; decode each.
        return [py_str(plist[i]) for i in range(size.value)]

    def __hash__(self):
        # Same key that same_as() compares, keeping __hash__ and __eq__
        # consistent.
        return _api_internal._raw_ptr(self)

    def __eq__(self, other):
        return self.same_as(other)

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__, so define it explicitly.
        return not self.__eq__(other)

    def __reduce__(self):
        """Pickle support: rebuild via _new_object, then __setstate__."""
        cls = type(self)
        return (_new_object, (cls, ), self.__getstate__())

    def __getstate__(self):
        """Return ``{'handle': <saved json>}``, or ``{'handle': None}``."""
        handle = self.handle
        if handle is not None:
            return {'handle': _api_internal._save_json(self)}
        return {'handle': None}

    def __setstate__(self, state):
        """Restore from the dict produced by :meth:`__getstate__`."""
        # pylint: disable=assigning-non-slot
        handle = state['handle']
        if handle is not None:
            json_str = handle
            other = _api_internal._load_json(json_str)
            # Take over the freshly loaded handle and detach it from `other`,
            # presumably so `other`'s finalizer does not release it — confirm
            # against _NodeBase.
            self.handle = other.handle
            other.handle = None
        else:
            self.handle = None

    def same_as(self, other):
        """check object identity equality

        Returns False for any object that is not a NodeBase.
        """
        if not isinstance(other, NodeBase):
            return False
        return self.__hash__() == other.__hash__()

def register_node(type_key=None):
    """register node type

    Supported forms::

        @register_node              # type key is the class __name__
        class Foo(NodeBase): ...

        @register_node("my.Foo")    # explicit type key
        class Foo(NodeBase): ...

        @register_node()            # same as the bare form
        class Foo(NodeBase): ...

    Parameters
    ----------
    type_key : str or cls, optional
        The type key of the node. When a class is passed (bare decorator
        use) or the argument is omitted, the class's ``__name__`` is used.

    Returns
    -------
    The class itself when a class is passed, otherwise a decorator.
    """

    def register(cls):
        """internal register function"""
        node_name = type_key if isinstance(type_key, str) else cls.__name__
        tindex = ctypes.c_int()
        ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex))
        # Best-effort: on a non-zero return (lookup failed, presumably the key
        # is unknown to the runtime) the class is returned unregistered.
        if ret == 0:
            _register_node(tindex.value, cls)
        return cls

    # A str or no argument means we were called to produce a decorator;
    # otherwise type_key is the class being decorated directly.
    if type_key is None or isinstance(type_key, str):
        return register
    return register(type_key)