module.py 7.23 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19
"""IRModule that holds the functions and type definitions."""
from tvm._ffi.base import string_types
import tvm._ffi
20

21 22 23 24
from .base import Node
from . import expr as _expr
from . import type as _ty
from . import _ffi_api
25 26


27 28 29
@tvm._ffi.register_object("relay.Module")
class IRModule(Node):
    """IRModule that holds functions and type definitions.

    IRModule is the basic unit for all IR transformations across the stack.

    Parameters
    ----------
    functions: Optional[dict]
        Map of global var to BaseFunc. Keys given as strings are
        converted to GlobalVar.

    type_definitions: Optional[dict]
        Map of global type var to type definition. Keys given as strings
        are converted to GlobalTypeVar.
    """
    def __init__(self, functions=None, type_definitions=None):
        if functions is None:
            functions = {}
        elif isinstance(functions, dict):
            # Normalize the keys: plain names become GlobalVar, anything
            # else that is not already a GlobalVar is rejected.
            mapped_funcs = {}
            for k, v in functions.items():
                if isinstance(k, string_types):
                    k = _expr.GlobalVar(k)
                if not isinstance(k, _expr.GlobalVar):
                    raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
                mapped_funcs[k] = v
            functions = mapped_funcs
        # Non-dict, non-None values are forwarded to the FFI constructor as-is.
        if type_definitions is None:
            type_definitions = {}
        elif isinstance(type_definitions, dict):
            # Same normalization as above, for GlobalTypeVar keys.
            mapped_type_defs = {}
            for k, v in type_definitions.items():
                if isinstance(k, string_types):
                    k = _ty.GlobalTypeVar(k)
                if not isinstance(k, _ty.GlobalTypeVar):
                    raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
                mapped_type_defs[k] = v
            type_definitions = mapped_type_defs
        self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

    def __setitem__(self, var, val):
        """Add a mapping to the module.

        Parameters
        ----------
        var: Union[str, GlobalVar, GlobalTypeVar]
            The global variable, or its name.

        val: Union[Function, Type]
            The value.
        """
        return self._add(var, val)

    def _add(self, var, val, update=False):
        """Add a function or a type definition to the module.

        Parameters
        ----------
        var: Union[str, GlobalVar, GlobalTypeVar]
            The global (type) variable to bind, or its name.

        val: Union[RelayExpr, Type]
            The definition. A RelayExpr is added through ``Module_Add``;
            anything else must be a Type and is added through
            ``Module_AddDef``.

        update: bool
            Forwarded unchanged to the underlying FFI call.
        """
        if isinstance(val, _expr.RelayExpr):
            if isinstance(var, string_types):
                # Reuse the module's existing GlobalVar for this name when
                # there is one; otherwise create a new GlobalVar.
                if _ffi_api.Module_ContainGlobalVar(self, var):
                    var = _ffi_api.Module_GetGlobalVar(self, var)
                else:
                    var = _expr.GlobalVar(var)
            _ffi_api.Module_Add(self, var, val, update)
        else:
            assert isinstance(val, _ty.Type), \
                "Expect val to be a RelayExpr or a Type, but got %s" % type(val)
            if isinstance(var, string_types):
                var = _ty.GlobalTypeVar(var)
            _ffi_api.Module_AddDef(self, var, val, update)

    def __getitem__(self, var):
        """Lookup a global definition by name or by variable.

        Parameters
        ----------
        var: Union[String, GlobalVar, GlobalTypeVar]
            The name or global variable.

        Returns
        -------
        val: Union[Function, Type]
            The definition referenced by :code:`var` (either a function or type).
        """
        if isinstance(var, string_types):
            return _ffi_api.Module_Lookup_str(self, var)
        if isinstance(var, _expr.GlobalVar):
            return _ffi_api.Module_Lookup(self, var)
        # Everything else is looked up as a type definition.
        return _ffi_api.Module_LookupDef(self, var)

    def update(self, other):
        """Insert functions in another Module to current one.

        Parameters
        ----------
        other: Union[IRModule, dict]
            The module to merge into the current Module. A dict is first
            wrapped into a new IRModule, so it follows the same key rules
            as the ``functions`` argument of the constructor.
        """
        if isinstance(other, dict):
            # The class is named IRModule; the previous reference to the
            # undefined name ``Module`` raised NameError on this path.
            other = IRModule(other)
        return _ffi_api.Module_Update(self, other)

    def get_global_var(self, name):
        """Get a global variable in the module by name.

        Parameters
        ----------
        name: str
            The name of the global variable.

        Returns
        -------
        global_var: GlobalVar
            The global variable mapped to :code:`name`.

        Raises
        ------
        tvm.error.TVMError if we cannot find corresponding global var.
        """
        return _ffi_api.Module_GetGlobalVar(self, name)

    def get_global_vars(self):
        """Collect all global vars defined in this module.

        Returns
        -------
        global_vars: Array[GlobalVar]
            An array of global vars.
        """
        return _ffi_api.Module_GetGlobalVars(self)

    def get_global_type_vars(self):
        """Collect all global type vars defined in this module.

        Returns
        -------
        global_type_vars: Array[GlobalTypeVar]
            An array of global type vars.
        """
        return _ffi_api.Module_GetGlobalTypeVars(self)

    def get_global_type_var(self, name):
        """Get a global type variable in the module by name.

        Parameters
        ----------
        name: str
            The name of the global type variable.

        Returns
        -------
        global_type_var: GlobalTypeVar
            The global type variable mapped to :code:`name`.

        Raises
        ------
        tvm.error.TVMError if we cannot find corresponding global type var.
        """
        return _ffi_api.Module_GetGlobalTypeVar(self, name)

    def get_constructor(self, tag):
        """Look up an ADT constructor by tag.

        Parameters
        ----------
        tag: int
            The tag for a constructor.

        Returns
        -------
        constructor: Constructor
           The constructor associated with the given tag.

        Raises
        ------
        tvm.error.TVMError if the corresponding constructor cannot be found.
        """
        return _ffi_api.Module_LookupTag(self, tag)

    @staticmethod
    def from_expr(expr, functions=None, type_defs=None):
        """Construct a module from a standalone expression.

        Parameters
        ----------
        expr: RelayExpr
            The starting expression

        functions: Optional[dict]
            Map of global vars to function definitions

        type_defs: Optional[dict]
            Map of global type vars to type definitions

        Returns
        -------
        mod: IRModule
            A module containing the passed definitions,
            where expr is set as the entry point
            (wrapped in a function if necessary)
        """
        # Omitted maps are passed to the FFI as empty dicts.
        funcs = functions if functions is not None else {}
        defs = type_defs if type_defs is not None else {}
        return _ffi_api.Module_FromExpr(expr, funcs, defs)

    def _import(self, file_to_import):
        """Forward to the FFI ``Module_Import`` call for this module.

        Parameters
        ----------
        file_to_import: str
            Passed through unchanged to ``Module_Import``.
            NOTE(review): presumably a path to a source file whose
            definitions are imported into this module; confirm against
            the C++ implementation.
        """
        return _ffi_api.Module_Import(self, file_to_import)

    def import_from_std(self, file_to_import):
        """Forward to the FFI ``Module_ImportFromStd`` call for this module.

        Parameters
        ----------
        file_to_import: str
            Passed through unchanged to ``Module_ImportFromStd``.
            NOTE(review): presumably a file name resolved relative to the
            standard library directory; confirm against the C++
            implementation.
        """
        return _ffi_api.Module_ImportFromStd(self, file_to_import)