# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language."""
from __future__ import absolute_import as _abs

from .._ffi.object import register_object as _register_tvm_node
from .._ffi.object import Object

from . import _make
from . import _expr
from . import _base


# Re-bind the imported ``Object`` at module level. Presumably this marks it as
# an intentional public re-export of this module (e.g. for linters) — confirm.
Object = Object


def register_relay_node(type_key=None):
    """Register a Relay node type.

    Supported usages::

        @register_relay_node            # key is "relay." + cls.__name__
        class Foo(RelayNode): ...

        @register_relay_node()          # same as above
        class Foo(RelayNode): ...

        @register_relay_node("relay.Foo")
        class Foo(RelayNode): ...

    Parameters
    ----------
    type_key : str or cls, optional
        The type key of the node, or the class to register directly.
        When a class is given, its key is ``"relay." + cls.__name__``.
        When omitted (None), a decorator is returned that derives the key
        from the decorated class's name.

    Returns
    -------
    cls or callable
        When ``type_key`` is a class, the result of applying the TVM
        registration decorator to it (presumably the class itself).
        When ``type_key`` is a string, the decorator returned by
        ``_register_tvm_node``. When ``type_key`` is None, a decorator
        (this function itself) to be applied to the class.
    """
    if type_key is None:
        # Called as ``@register_relay_node()``: the class is not known yet,
        # so hand back a decorator that will receive it.
        return register_relay_node
    if not isinstance(type_key, str):
        return _register_tvm_node(
            "relay." + type_key.__name__)(type_key)
    return _register_tvm_node(type_key)


def register_relay_attr_node(type_key=None):
    """Register a Relay attribute node.

    Supported usages::

        @register_relay_attr_node       # key is "relay.attrs." + cls.__name__
        class FooAttrs(Object): ...

        @register_relay_attr_node()     # same as above
        class FooAttrs(Object): ...

        @register_relay_attr_node("relay.attrs.FooAttrs")
        class FooAttrs(Object): ...

    Parameters
    ----------
    type_key : str or cls, optional
        The type key of the node, or the class to register directly.
        When a class is given, its key is ``"relay.attrs." + cls.__name__``.
        When omitted (None), a decorator is returned that derives the key
        from the decorated class's name.

    Returns
    -------
    cls or callable
        When ``type_key`` is a class, the result of applying the TVM
        registration decorator to it (presumably the class itself).
        When ``type_key`` is a string, the decorator returned by
        ``_register_tvm_node``. When ``type_key`` is None, a decorator
        (this function itself) to be applied to the class.
    """
    if type_key is None:
        # Called as ``@register_relay_attr_node()``: the class is not known
        # yet, so hand back a decorator that will receive it.
        return register_relay_attr_node
    if not isinstance(type_key, str):
        return _register_tvm_node(
            "relay.attrs." + type_key.__name__)(type_key)
    return _register_tvm_node(type_key)


class RelayNode(Object):
    """Base class of all Relay nodes."""

    def astext(self, show_meta_data=True, annotate=None):
        """Get the text format of the expression.

        Parameters
        ----------
        show_meta_data : bool
            Whether to include meta data section in the text
            if there is meta data.

        annotate: Optional[relay.Expr->str]
            Optional annotate function to provide additional
            information in the comment block.

        Note
        ----
        The meta data section is necessary to fully parse the text format.
        However, it can contain dumps that are big (e.g. constant weights),
        so it can be helpful to skip printing the meta data section.

        Returns
        -------
        text : str
            The text format of the expression.
        """
        return _expr.AsText(self, show_meta_data, annotate)

    def set_span(self, span):
        """Attach a source span to this node.

        Delegates to ``_base.set_span``; the exact effect on the node is
        defined there.

        Parameters
        ----------
        span : Span
            The source location to associate with this node (presumably a
            :class:`Span` — confirm against ``_base.set_span``).
        """
        _base.set_span(self, span)

    def __str__(self):
        """Return the text format of the node without the meta data section."""
        return self.astext(show_meta_data=False)


@register_relay_node
class Span(RelayNode):
    """Specifies a location in a source program.

    Parameters
    ----------
    source : SourceName
        The source the location refers to (presumably a
        :class:`SourceName` — confirm against ``_make.Span``).
    lineno : int
        The line number of the location (base, 0 or 1, not confirmed here).
    col_offset : int
        The column offset of the location within the line.
    """

    def __init__(self, source, lineno, col_offset):
        self.__init_handle_by_constructor__(
            _make.Span, source, lineno, col_offset)


@register_relay_node
class SourceName(RelayNode):
    """An identifier for a source location.

    Parameters
    ----------
    name : str
        The name identifying the source (presumably a file name or similar
        label — confirm against ``_make.SourceName``).
    """

    def __init__(self, name):
        self.__init_handle_by_constructor__(_make.SourceName, name)


@register_relay_node
class Id(Object):
    """Unique identifier (name) used in Var.

    Guaranteed to be stable across all passes.
    """

    def __init__(self):
        # Id objects are not meant to be built from Python; direct
        # construction is rejected explicitly.
        raise RuntimeError("Cannot directly construct Id")