node.py 2.76 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18
"""Node namespace"""
# pylint: disable=unused-import
19 20 21
from __future__ import absolute_import

import ctypes
22
import sys
23
from .. import _api_internal
24
from .object import Object, register_object, _set_class_node
25
from .node_generic import NodeGeneric, convert_to_node, const
26

27 28 29 30 31 32

def _new_object(cls):
    """Helper function for pickle"""
    return cls.__new__(cls)


33
class NodeBase(Object):
    """NodeBase is the base class of all TVM language AST object.

    Printing, ``dir()``, field access, hashing, equality and pickling are all
    delegated to the native node through ``_api_internal``.
    """

    def __repr__(self):
        """Return the text produced by ``_api_internal._format_str``."""
        return _api_internal._format_str(self)

    def __dir__(self):
        """List the attribute names exposed by the underlying node."""
        fnames = _api_internal._NodeListAttrNames(self)
        # The returned callable reports the number of names when called with
        # -1, and the i-th name when called with i.
        size = fnames(-1)
        return [fnames(i) for i in range(size)]

    def __getattr__(self, name):
        """Look up a node field that is not a regular Python attribute.

        Raises
        ------
        AttributeError
            If ``handle`` has never been assigned, or if the node has no
            field called ``name``.
        """
        # __getattr__ only runs when normal lookup fails, so reaching it for
        # ``handle`` means the handle was never set (e.g. an instance made by
        # _new_object before __setstate__ runs). The FFI lookup presumably
        # needs that very handle, so fail fast with a clear error instead of
        # forwarding the request.
        if name == "handle":
            raise AttributeError("handle is not set")
        try:
            return _api_internal._NodeGetAttr(self, name)
        except AttributeError:
            raise AttributeError(
                "%s has no attribute %s" % (str(type(self)), name))

    def __hash__(self):
        """Hash by ``_api_internal._raw_ptr`` (presumably the address of the
        underlying native object), consistent with ``__eq__``."""
        return _api_internal._raw_ptr(self)

    def __eq__(self, other):
        """Reference equality; delegates to ``same_as``."""
        return self.same_as(other)

    def __ne__(self, other):
        # Spelled out because Python 2 does not derive __ne__ from __eq__.
        return not self.__eq__(other)

    def __reduce__(self):
        """Pickle support: rebuild with ``_new_object`` and then restore the
        state returned by ``__getstate__``."""
        cls = type(self)
        return (_new_object, (cls, ), self.__getstate__())

    def __getstate__(self):
        """Return ``{'handle': <output of _api_internal._save_json>}``, or
        ``{'handle': None}`` when this object's handle is None."""
        handle = self.handle
        if handle is not None:
            return {'handle': _api_internal._save_json(self)}
        return {'handle': None}

    def __setstate__(self, state):
        """Restore from the dict produced by ``__getstate__``."""
        # pylint: disable=assigning-non-slot
        handle = state['handle']
        if handle is not None:
            json_str = handle
            other = _api_internal._load_json(json_str)
            # Take over the freshly loaded handle and detach it from the
            # temporary, presumably so that releasing ``other`` does not free
            # the handle now owned by ``self``.
            self.handle = other.handle
            other.handle = None
        else:
            self.handle = None

    def same_as(self, other):
        """Check object identity equality.

        Returns False for anything that is not a NodeBase; otherwise compares
        the ``_raw_ptr`` values of both objects.
        """
        if not isinstance(other, NodeBase):
            return False
        return self.__hash__() == other.__hash__()

86

87 88 89
# Hand NodeBase to the object layer (presumably installs it as the default
# Python class for node objects -- confirm against .object._set_class_node).
_set_class_node(NodeBase)

# ``register_node`` is simply another name for ``register_object``.
# pylint: disable=invalid-name
register_node = register_object