base.py 3.41 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21
# pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language."""
from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
22
from . import _expr
23
from . import _base
24 25 26 27

# Self-assignment of the imported ``NodeBase``: the object is unchanged, the
# name is simply (re)bound explicitly at module level.
# NOTE(review): presumably kept so the import counts as used / is exposed as
# part of this module's namespace -- confirm before removing.
NodeBase = NodeBase

def register_relay_node(type_key=None):
    """Register a Relay node type.

    Supports three spellings::

        @register_relay_node            # key is "relay." + class name
        @register_relay_node()          # same key as the bare form
        @register_relay_node("my.key")  # explicit type key

    Parameters
    ----------
    type_key : str or cls, optional
        The type key of the node. A string is handed to the underlying
        registration function unchanged. A class is registered under
        ``"relay." + cls.__name__``. When omitted (``None``), a decorator
        is returned that registers the class it is applied to under that
        same default key.

    Returns
    -------
    result : object
        For a string or ``None``, a decorator to apply to the node class;
        for a class, whatever the underlying registration decorator
        returns for that class.
    """
    if type_key is None:
        # Called with empty parentheses. This used to fall through to
        # ``None.__name__`` and fail with AttributeError; defer the
        # registration until the class is supplied instead.
        def _register_with_default_key(cls):
            return _register_tvm_node("relay." + cls.__name__)(cls)
        return _register_with_default_key
    if isinstance(type_key, str):
        return _register_tvm_node(type_key)
    # Bare decorator usage: ``type_key`` is the class being registered.
    return _register_tvm_node("relay." + type_key.__name__)(type_key)


41
def register_relay_attr_node(type_key=None):
    """Register a Relay attribute node.

    Supports three spellings::

        @register_relay_attr_node            # key is "relay.attrs." + class name
        @register_relay_attr_node()          # same key as the bare form
        @register_relay_attr_node("my.key")  # explicit type key

    Parameters
    ----------
    type_key : str or cls, optional
        The type key of the node. A string is handed to the underlying
        registration function unchanged. A class is registered under
        ``"relay.attrs." + cls.__name__``. When omitted (``None``), a
        decorator is returned that registers the class it is applied to
        under that same default key.

    Returns
    -------
    result : object
        For a string or ``None``, a decorator to apply to the attribute
        class; for a class, whatever the underlying registration decorator
        returns for that class.
    """
    if type_key is None:
        # Called with empty parentheses. This used to fall through to
        # ``None.__name__`` and fail with AttributeError; defer the
        # registration until the class is supplied instead.
        def _register_with_default_key(cls):
            return _register_tvm_node("relay.attrs." + cls.__name__)(cls)
        return _register_with_default_key
    if isinstance(type_key, str):
        return _register_tvm_node(type_key)
    # Bare decorator usage: ``type_key`` is the class being registered.
    return _register_tvm_node("relay.attrs." + type_key.__name__)(type_key)


55
class RelayNode(NodeBase):
    """Common base class shared by every Relay node."""

    def astext(self, show_meta_data=True, annotate=None):
        """Render this node in the Relay text format.

        Parameters
        ----------
        show_meta_data : bool
            If there is meta data, whether the meta data section is
            included in the returned text.

        annotate: Optional[relay.Expr->str]
            Optional annotate function that supplies additional
            information for the comment block.

        Note
        ----
        The text format can only be fully parsed when the meta data
        section is present. That section may contain large dumps
        (e.g. constant weights), so skipping it can be helpful.

        Returns
        -------
        text : str
            The text format of the expression.
        """
        text = _expr.AsText(self, show_meta_data, annotate)
        return text

    def set_span(self, span):
        """Set the span of this node via ``_base.set_span``.

        Parameters
        ----------
        span : object
            The span to set; presumably a ``Span`` -- confirm with callers.
        """
        _base.set_span(self, span)

    def __str__(self):
        """Return the text format with the meta data section omitted."""
        return self.astext(show_meta_data=False)

89

90
@register_relay_node
class Span(RelayNode):
    """A location inside a source program."""

    def __init__(self, source, lineno, col_offset):
        """Build the node through the ``_make.Span`` constructor.

        Parameters
        ----------
        source : object
            Forwarded as the first constructor argument; presumably a
            ``SourceName`` -- confirm against ``_make.Span``.

        lineno : object
            Forwarded as the second constructor argument.

        col_offset : object
            Forwarded as the third constructor argument.
        """
        constructor = _make.Span
        self.__init_handle_by_constructor__(
            constructor, source, lineno, col_offset)
96

97 98 99 100 101 102
@register_relay_node
class SourceName(RelayNode):
    """An identifier for a source location."""

    def __init__(self, name):
        """Build the node through the ``_make.SourceName`` constructor.

        Parameters
        ----------
        name : object
            Forwarded unchanged as the only constructor argument.
        """
        constructor = _make.SourceName
        self.__init_handle_by_constructor__(constructor, name)
103 104 105

@register_relay_node
class Id(NodeBase):
    """Unique identifier (name) used in Var.

    Guaranteed to be stable across all passes.
    """

    def __init__(self):
        """Always raise: an ``Id`` cannot be constructed directly.

        Raises
        ------
        RuntimeError
            On every call.
        """
        message = "Cannot directly construct Id"
        raise RuntimeError(message)