"""Node namespace: the NodeBase wrapper class and the register_node decorator."""
# pylint: disable=unused-import
from __future__ import absolute_import

import ctypes
import sys
from .. import _api_internal
from .node_generic import NodeGeneric, convert_to_node, const
from .base import _LIB, check_call, c_str, py_str, _FFI_MODE

# Exception type that triggers the pure-ctypes fallback in the except clause
# below.  When the cython backend is explicitly requested (_FFI_MODE ==
# "cython") only RuntimeError is caught, so an ImportError from a missing
# compiled extension propagates instead of silently falling back to ctypes.
# For any other _FFI_MODE value an ImportError selects the ctypes backend.
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
    # pylint: disable=wrong-import-position
    if _FFI_MODE == "ctypes":
        # ctypes explicitly requested: raise ImportError (which IMPORT_EXCEPT
        # catches in this mode) to jump straight to the fallback import.
        raise ImportError()
    # Load the cython extension built for the running major Python version.
    if sys.version_info >= (3, 0):
        from ._cy3.core import _register_node, NodeBase as _NodeBase
    else:
        from ._cy2.core import _register_node, NodeBase as _NodeBase
except IMPORT_EXCEPT:
    # Fallback: the ctypes implementation of the same two names.
    # pylint: disable=wrong-import-position
    from ._ctypes.node import _register_node, NodeBase as _NodeBase


class NodeBase(_NodeBase):
    """NodeBase is the base class of all TVM language AST objects."""

    def __repr__(self):
        """Return the text rendering produced by ``_api_internal._format_str``."""
        return _api_internal._format_str(self)

    def __dir__(self):
        """List the attribute names exposed by the underlying node.

        The names are obtained from the C library through
        ``TVMNodeListAttrNames`` and decoded to ``str`` with ``py_str``.
        """
        plist = ctypes.POINTER(ctypes.c_char_p)()
        size = ctypes.c_uint()
        check_call(_LIB.TVMNodeListAttrNames(
            self.handle, ctypes.byref(size), ctypes.byref(plist)))
        return [py_str(plist[i]) for i in range(size.value)]

    def __hash__(self):
        """Hash by the raw pointer of the node (``_api_internal._raw_ptr``)."""
        return _api_internal._raw_ptr(self)

    def __eq__(self, other):
        """Reference equality; delegates to :meth:`same_as`."""
        return self.same_as(other)

    def __ne__(self, other):
        # Defined explicitly because Python 2 does not derive __ne__ from
        # __eq__ (this module still supports Python 2 via the _cy2 backend).
        return not self.__eq__(other)

    def __reduce__(self):
        # Pickle protocol: rebuild with ``type(self)(None)`` and then hand the
        # state from __getstate__ to __setstate__.
        return (type(self), (None,), self.__getstate__())

    def __getstate__(self):
        """Return ``{'handle': json_str}``, or ``{'handle': None}`` for a null node.

        The JSON string is produced by ``_api_internal._save_json``.
        """
        if self.handle is not None:
            return {'handle': _api_internal._save_json(self)}
        return {'handle': None}

    def __setstate__(self, state):
        """Restore the node from a state dict produced by :meth:`__getstate__`.

        Parameters
        ----------
        state : dict
            Dict whose ``'handle'`` entry is a JSON string or ``None``.
        """
        # pylint: disable=assigning-non-slot
        json_str = state['handle']
        if json_str is None:
            self.handle = None
            return
        other = _api_internal._load_json(json_str)
        # Take over the freshly loaded handle and clear it on the temporary
        # object so only ``self`` refers to it afterwards (presumably the base
        # class releases ``handle`` when an object is destroyed -- confirm in
        # _NodeBase).
        self.handle = other.handle
        other.handle = None

    def same_as(self, other):
        """Check object identity equality.

        Parameters
        ----------
        other : object
            Object to compare against.

        Returns
        -------
        bool
            False if ``other`` is not a NodeBase. Otherwise, True when both
            objects hash equal, i.e. ``_raw_ptr`` returns the same value.
        """
        if not isinstance(other, NodeBase):
            return False
        return self.__hash__() == other.__hash__()


def register_node(type_key=None):
    """Register a Python class for a node type.

    Supported forms::

        @register_node               # type key = decorated class's __name__
        @register_node()             # same as the bare form
        @register_node("type.key")   # explicit type key

    Parameters
    ----------
    type_key : str or cls, optional
        The type key of the node. When the decorator is used without
        arguments this is the decorated class itself. ``None`` means "use the
        decorated class's ``__name__``".

    Returns
    -------
    cls or callable
        The decorated class for the bare form; otherwise a decorator that
        registers and returns its class.
    """
    def register(cls):
        """internal register function"""
        # An explicit string key wins; otherwise (bare form or None) the
        # class name is used, matching the original bare-decorator behavior.
        node_name = type_key if isinstance(type_key, str) else cls.__name__
        tindex = ctypes.c_int()
        ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex))
        # A non-zero return means the lookup failed (presumably the key is not
        # known to the loaded library). Registration is then skipped on a
        # best-effort basis and the class is still returned unchanged.
        if ret == 0:
            _register_node(tindex.value, cls)
        return cls

    if type_key is None or isinstance(type_key, str):
        return register
    return register(type_key)