module.py 5.69 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
18
"""A global module storing everything needed to interpret or compile a Relay program."""
19
from .base import register_relay_node, RelayNode
20
from .._ffi import base as _base
21
from . import _make
22
from . import _module
23
from . import expr as _expr
24
from . import ty as _ty
25 26

@register_relay_node
class Module(RelayNode):
    """The global Relay module containing a collection of functions and type definitions.

    Each global function is identified by a unique tvm.relay.GlobalVar.
    tvm.relay.GlobalVar and Module are necessary in order to enable
    recursion in functions while avoiding cyclic references inside the function.

    Parameters
    ----------
    functions : dict, optional
        Map of global var to Function. String keys are wrapped in a
        GlobalVar of the same name.

    type_definitions : dict, optional
        Map of global type var to Type. String keys are wrapped in a
        GlobalTypeVar of the same name.
    """
    def __init__(self, functions=None, type_definitions=None):
        if functions is None:
            functions = {}
        elif isinstance(functions, dict):
            functions = Module._convert_keys(
                functions, _expr.GlobalVar,
                "Expect functions to be Dict[GlobalVar, Function]")

        if type_definitions is None:
            type_definitions = {}
        elif isinstance(type_definitions, dict):
            type_definitions = Module._convert_keys(
                type_definitions, _ty.GlobalTypeVar,
                "Expect type_definitions to be Dict[GlobalTypeVar, Type]")

        self.__init_handle_by_constructor__(_make.Module, functions, type_definitions)

    @staticmethod
    def _convert_keys(mapping, key_type, error_msg):
        """Return a copy of ``mapping`` whose string keys are wrapped in ``key_type``.

        Parameters
        ----------
        mapping: dict
            The user-supplied mapping; values are copied through unchanged.

        key_type: type
            The required key class (e.g. GlobalVar or GlobalTypeVar). It is
            also called with a string key to build the wrapped key.

        error_msg: str
            Message of the TypeError raised for an unsupported key.

        Returns
        -------
        converted: dict
            A new dict keyed by ``key_type`` instances.

        Raises
        ------
        TypeError if a key is neither a string nor an instance of ``key_type``.
        """
        converted = {}
        for key, value in mapping.items():
            if isinstance(key, _base.string_types):
                key = key_type(key)
            if not isinstance(key, key_type):
                raise TypeError(error_msg)
            converted[key] = value
        return converted

    def __setitem__(self, var, val):
        """Add a mapping to the module.

        Parameters
        ----------
        var: Union[str, GlobalVar, GlobalTypeVar]
            The global variable (or its name).

        val: Union[Function, Type]
            The value.
        """
        return self._add(var, val)

    def _add(self, var, val, update=False):
        """Add (or, with ``update=True``, pass the update flag for) a definition.

        Expressions are stored as functions: a GlobalVar value is eta-expanded
        and any other non-Function expression is wrapped in a
        zero-parameter Function. Any other value must be a Type and is
        added as a type definition.
        """
        if isinstance(val, _expr.Expr):
            if isinstance(var, _base.string_types):
                var = _expr.GlobalVar(var)

            # TODO(@jroesch): Port this logic to C++.
            if not isinstance(val, _expr.Function):
                if isinstance(val, _expr.GlobalVar):
                    # `ir_pass` was referenced here but never imported, so this
                    # branch raised NameError. It is imported at function scope;
                    # presumably a top-level import would risk a circular import
                    # between relay submodules — NOTE(review): confirm.
                    from . import ir_pass
                    val = ir_pass.eta_expand(val, self)
                else:
                    val = _expr.Function([], val)

            _make.Module_Add(self, var, val, update)
        else:
            assert isinstance(val, _ty.Type), "Expect val to be an Expr or a Type"
            if isinstance(var, _base.string_types):
                var = _ty.GlobalTypeVar(var)
            _module.Module_AddDef(self, var, val)

    def __getitem__(self, var):
        """Lookup a global definition by name or by variable.

        Parameters
        ----------
        var: Union[str, GlobalVar, GlobalTypeVar]
            The name or global variable. Strings and GlobalVars look up
            functions; anything else is looked up as a type definition.

        Returns
        -------
        val: Union[Function, Type]
            The definition referenced by :code:`var` (either a function or type).
        """
        if isinstance(var, _base.string_types):
            return _module.Module_Lookup_str(self, var)
        elif isinstance(var, _expr.GlobalVar):
            return _module.Module_Lookup(self, var)
        else:
            return _module.Module_LookupDef(self, var)

    def update(self, other):
        """Insert functions in another Module to current one.

        Parameters
        ----------
        other: Union[Module, dict]
            The module to merge into the current Module. A dict is first
            converted with :code:`Module(other)`, i.e. treated as a map of
            global var (or name) to Function.
        """
        if isinstance(other, dict):
            other = Module(other)
        return _module.Module_Update(self, other)

    def get_global_var(self, name):
        """Get a global variable in the module by name.

        Parameters
        ----------
        name: str
            The name of the global variable.

        Returns
        -------
        global_var: GlobalVar
            The global variable mapped to :code:`name`.

        Raises
        ------
        tvm.TVMError if we cannot find corresponding global var.
        """
        return _module.Module_GetGlobalVar(self, name)

    def get_global_type_var(self, name):
        """Get a global type variable in the module by name.

        Parameters
        ----------
        name: str
            The name of the global type variable.

        Returns
        -------
        global_type_var: GlobalTypeVar
            The global type variable mapped to :code:`name`.

        Raises
        ------
        tvm.TVMError if we cannot find corresponding global type var.
        """
        return _module.Module_GetGlobalTypeVar(self, name)

    @staticmethod
    def from_expr(expr):
        """Construct a Module from a single expression.

        Delegates to the ``Module_FromExpr`` FFI function; how the expression
        is wrapped or named inside the new module is decided there
        (NOTE(review): document once confirmed against the C++ side).

        Parameters
        ----------
        expr: Expr
            The expression to build the module from.

        Returns
        -------
        mod: Module
            The newly created module.
        """
        return _module.Module_FromExpr(expr)