module.py 7.16 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
18
"""A global module storing everything needed to interpret or compile a Relay program."""
19
import os
20
from .base import register_relay_node, RelayNode
21
from .. import register_func
22
from .._ffi import base as _base
23
from . import _make
24
from . import _module
25
from . import expr as _expr
26
from . import ty as _ty
27

28 29 30 31 32 33 34
# Directory named "std" located next to this file (symlinks resolved).
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")


@register_func("tvm.relay.std_path")
def _std_path():
    """Return the path of the Relay ``std`` directory.

    Registered through ``register_func`` under the name
    ``tvm.relay.std_path``.

    Returns
    -------
    path: str
        The ``std`` directory that sits next to this file.
    """
    # Read-only access to the module-level constant: no ``global`` needed.
    return __STD_PATH__

35
def _normalize_var_map(mapping, var_type, error_msg):
    """Prepare a user-supplied definition map for the Module constructor.

    ``None`` becomes an empty dict. A dict is copied, and each string key is
    wrapped as ``var_type(key)``. Any other container is returned unchanged.

    Raises
    ------
    TypeError(error_msg) if a dict key is neither a string nor a ``var_type``.
    """
    if mapping is None:
        return {}
    if not isinstance(mapping, dict):
        # Non-dict containers are forwarded untouched to the constructor.
        return mapping
    mapped = {}
    for key, value in mapping.items():
        if isinstance(key, _base.string_types):
            key = var_type(key)
        if not isinstance(key, var_type):
            raise TypeError(error_msg)
        mapped[key] = value
    return mapped


@register_relay_node
class Module(RelayNode):
    """The global Relay module containing a collection of functions and type definitions.

    Each global function is identified by a unique tvm.relay.GlobalVar, and
    each global type definition by a tvm.relay.GlobalTypeVar.
    tvm.relay.GlobalVar and Module are necessary to enable recursion in
    functions while avoiding cyclic references inside the function itself.

    Parameters
    ----------
    functions: Optional[dict].
        Map of global var (or its name as a string) to Function.

    type_definitions: Optional[dict].
        Map of global type var (or its name as a string) to Type.
    """
    def __init__(self, functions=None, type_definitions=None):
        functions = _normalize_var_map(
            functions, _expr.GlobalVar,
            "Expect functions to be Dict[GlobalVar, Function]")
        type_definitions = _normalize_var_map(
            type_definitions, _ty.GlobalTypeVar,
            "Expect type_definitions to be Dict[GlobalTypeVar, Type]")
        self.__init_handle_by_constructor__(_make.Module, functions, type_definitions)

    def __setitem__(self, var, val):
        """Add a mapping to the module.

        Parameters
        ---------
        var: Union[String, GlobalVar, GlobalTypeVar]
            The global variable (or its name).

        val: Union[Function, Type]
            The value.
        """
        return self._add(var, val)

    def _add(self, var, val, update=False):
        """Add ``val`` to the module under ``var``.

        Expressions go through ``Module_Add`` with ``update`` forwarded. For
        a string name, the module's existing GlobalVar of that name is reused
        when present; otherwise a new GlobalVar is created.

        Types go through ``Module_AddDef``. For a string name, a new
        GlobalTypeVar is always created, and ``update`` is not forwarded.

        Raises
        ------
        TypeError if ``val`` is neither an Expr nor a Type.
        """
        if isinstance(val, _expr.Expr):
            if isinstance(var, _base.string_types):
                if _module.Module_ContainGlobalVar(self, var):
                    var = _module.Module_GetGlobalVar(self, var)
                else:
                    var = _expr.GlobalVar(var)
            _module.Module_Add(self, var, val, update)
            return
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(val, _ty.Type):
            raise TypeError(
                "Expect val to be an Expr or a Type, but got {}".format(type(val)))
        if isinstance(var, _base.string_types):
            var = _ty.GlobalTypeVar(var)
        _module.Module_AddDef(self, var, val)

    def __getitem__(self, var):
        """Lookup a global definition by name or by variable.

        Parameters
        ----------
        var: Union[String, GlobalVar, GlobalTypeVar]
            The name or global variable.

        Returns
        -------
        val: Union[Function, Type]
            The definition referenced by :code:`var` (either a function or type).
        """
        if isinstance(var, _base.string_types):
            return _module.Module_Lookup_str(self, var)
        elif isinstance(var, _expr.GlobalVar):
            return _module.Module_Lookup(self, var)
        else:
            return _module.Module_LookupDef(self, var)

    def update(self, other):
        """Insert functions in another Module to current one.

        Parameters
        ----------
        other: Union[Module, dict]
            The module to merge into the current Module. A dict is first
            converted with ``Module(other)``.
        """
        if isinstance(other, dict):
            other = Module(other)
        return _module.Module_Update(self, other)

    def get_global_var(self, name):
        """Get a global variable in the module by name.

        Parameters
        ----------
        name: str
            The name of the global variable.

        Returns
        -------
        global_var: GlobalVar
            The global variable mapped to :code:`name`.

        Raises
        ------
        tvm.TVMError if we cannot find corresponding global var.
        """
        return _module.Module_GetGlobalVar(self, name)

    def get_global_type_var(self, name):
        """Get a global type variable in the module by name.

        Parameters
        ----------
        name: str
            The name of the global type variable.

        Returns
        -------
        global_type_var: GlobalTypeVar
            The global type variable mapped to :code:`name`.

        Raises
        ------
        tvm.TVMError if we cannot find corresponding global type var.
        """
        return _module.Module_GetGlobalTypeVar(self, name)

    def get_constructor(self, tag):
        """Look up an ADT constructor by tag.

        Parameters
        ----------
        tag: int
            The tag for a constructor.

        Returns
        -------
        constructor: Constructor
           The constructor associated with the given tag.

        Raises
        ------
        tvm.TVMError if the corresponding constructor cannot be found.
        """
        return _module.Module_LookupTag(self, tag)

    @staticmethod
    def from_expr(expr, functions=None, type_defs=None):
        """Construct a module from a standalone expression.

        Parameters
        ----------
        expr: Expr
            The starting expression
        functions: Optional[dict]
            Map of global vars to function definitions
        type_defs: Optional[dict]
            Map of global type vars to type definitions


        Returns
        -------
        mod: Module
            A module containing the passed definitions,
            where expr is set as the entry point
            (wrapped in a function if necessary)
        """
        funcs = functions if functions is not None else {}
        defs = type_defs if type_defs is not None else {}
        return _module.Module_FromExpr(expr, funcs, defs)

    def _import(self, file_to_import):
        """Import definitions from ``file_to_import`` into this module.

        Thin wrapper over ``_module.Module_Import``. Presumably the file is
        parsed and its definitions merged here; confirm against the C++ side.
        """
        return _module.Module_Import(self, file_to_import)

    def import_from_std(self, file_to_import):
        """Import a file from the Relay standard library into this module.

        Thin wrapper over ``_module.Module_ImportFromStd``. The path is
        presumably resolved against the directory returned by
        ``tvm.relay.std_path``; confirm against the C++ side.
        """
        return _module.Module_ImportFromStd(self, file_to_import)