# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""
This file contains the set of analysis functions for Relay. Each one is a
thin Python wrapper around a backend implementation exposed through
``_ffi_api``, so that Relay expressions, types and modules can be inspected
from Python.
"""
from tvm.ir import IRModule

from . import _ffi_api
from .feature import Feature


def post_order_visit(expr, fvisit):
    """Apply ``fvisit`` to every node of ``expr`` in post-order DFS.

    Children are reached before their parents, and each node is handed
    to ``fvisit`` only once.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to traverse.

    fvisit : function
        Callback invoked with each visited node.
    """
    visit = _ffi_api.post_order_visit
    return visit(expr, fvisit)


def well_formed(expr):
    """Check that ``expr`` is well formed, i.e. each Var is bound only once.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to check.

    Returns
    -------
    well_form : bool
        Whether the input expression is well formed.
    """
    result = _ffi_api.well_formed(expr)
    return result


def check_kind(t, mod=None):
    """Check that a type is well kinded and return its kind.

    For example, a type is ill kinded if it is a tensor of tensors, or a
    tuple type whose fields are shapes.

    Parameters
    ----------
    t : tvm.relay.Type
        The type to check.

    mod : Optional[tvm.IRModule]
        The global module. When omitted, it is not forwarded to the backend.

    Returns
    -------
    kind : tvm.relay.TypeKind
        The kind of ``t``.

    Examples
    --------
    .. code:: python

        tp = relay.TypeVar("tp1", relay.TypeKind.Type)
        assert check_kind(relay.TupleType([tp])) == relay.TypeKind.Type
    """
    # Only forward ``mod`` when the caller supplied one; otherwise the
    # backend is called with the type alone.
    if mod is None:
        return _ffi_api.check_kind(t)
    return _ffi_api.check_kind(t, mod)


def check_constant(expr):
    """Tell whether ``expr`` is a constant expression.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to inspect.

    Returns
    -------
    result : bool
        Whether the expression is constant.
    """
    is_const = _ffi_api.check_constant(expr)
    return is_const


def free_vars(expr):
    """Collect the free Vars of ``expr`` in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to scan.

    Returns
    -------
    free : List[tvm.relay.Var]
        The free variables, listed in post-DFS order.

    Note
    ----
    The post-DFS ordering is useful for neural networks: it usually means
    that the weights of earlier layers come first.
    """
    collected = _ffi_api.free_vars(expr)
    return collected


def bound_vars(expr):
    """Collect the bound Vars of ``expr`` in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to scan.

    Returns
    -------
    bound : List[tvm.relay.Var]
        The bound variables, listed in post-DFS order.
    """
    collected = _ffi_api.bound_vars(expr)
    return collected


def all_vars(expr):
    """Collect every Var of ``expr`` (free and bound) in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to scan.

    Returns
    -------
    variables : List[tvm.relay.Var]
        All variables, listed in post-DFS order.
    """
    collected = _ffi_api.all_vars(expr)
    return collected


def free_type_vars(expr, mod=None):
    """Collect the free type variables of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr, tvm.relay.Type]
        The expression or type to scan.

    mod : Optional[tvm.IRModule]
        The global module; a fresh empty ``IRModule`` is used when omitted.

    Returns
    -------
    free : List[tvm.relay.TypeVar]
        The free type variables, in post-DFS order.
    """
    if mod is None:
        mod = IRModule()
    return _ffi_api.free_type_vars(expr, mod)


def bound_type_vars(expr, mod=None):
    """Collect the bound type variables of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr, tvm.relay.Type]
        The expression or type to scan.

    mod : Optional[tvm.IRModule]
        The global module; a fresh empty ``IRModule`` is used when omitted.

    Returns
    -------
    bound : List[tvm.relay.TypeVar]
        The bound type variables, in post-DFS order.
    """
    if mod is None:
        mod = IRModule()
    return _ffi_api.bound_type_vars(expr, mod)


def all_type_vars(expr, mod=None):
    """Collect every type variable (free and bound) of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr, tvm.relay.Type]
        The expression or type to scan.

    mod : Optional[tvm.IRModule]
        The global module; a fresh empty ``IRModule`` is used when omitted.

    Returns
    -------
    type_vars : List[tvm.relay.TypeVar]
        All type variables, in post-DFS order.
    """
    if mod is None:
        mod = IRModule()
    return _ffi_api.all_type_vars(expr, mod)


def collect_device_info(expr):
    """Collect the device allocation map for the given expression.

    The device ids are propagated from the ``device_copy`` operators.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        A dictionary mapping tvm.relay.Expr to device type.
    """
    return _ffi_api.CollectDeviceInfo(expr)


def collect_device_annotation_ops(expr):
    """Gather the device annotation ops that appear in ``expr``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to scan.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        Maps each annotation expression to its device type.
    """
    annotations = _ffi_api.CollectDeviceAnnotationOps(expr)
    return annotations


def get_total_mac_number(expr):
    """Count the multiply-accumulate operations (MACs) of a model.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression describing the model.

    Returns
    -------
    result : int64
        The total number of MACs in the model.
    """
    mac_count = _ffi_api.GetTotalMacNumber(expr)
    return mac_count


def unmatched_cases(match, mod=None):
    """Find the cases that a match expression does not catch, if any.

    Parameters
    ----------
    match : tvm.relay.Match
        The match expression.

    mod : Optional[tvm.IRModule]
        The module (defaults to an empty module).

    Returns
    -------
    missing_patterns : List[tvm.relay.Pattern]
        Patterns that the match expression does not catch.
    """
    # Materialize the documented empty-module default here, the same way the
    # *_type_vars helpers do, instead of forwarding None to the backend.
    use_mod = mod if mod is not None else IRModule()
    return _ffi_api.unmatched_cases(match, use_mod)


def detect_feature(a, b=None):
    """Detect the features used in a Relay program.

    Parameters
    ----------
    a : Union[tvm.relay.Expr, tvm.IRModule]
        The input expression or module.

    b : Optional[Union[tvm.relay.Expr, tvm.IRModule]]
        The input expression or module.
        The two arguments cannot both be expressions or both be modules.

    Returns
    -------
    features : Set[Feature]
        Features used in the program.
    """
    # The backend call takes the expression first and the module second;
    # if the module was passed first, swap the two arguments.
    expr, module = (b, a) if isinstance(a, IRModule) else (a, b)
    raw_features = _ffi_api.detect_feature(expr, module)
    return {Feature(int(code)) for code in raw_features}


def extract_fused_functions(mod):
    """Run the ExtractFusedFunctions pass and return the fused primitive functions.

    The ExtractFusedFunctions pass invokes SimplifyInference, FuseOps(3),
    and ExtractFusedFunctions in that order, producing a module that contains
    only fused primitive functions.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to run the pass on.

    Returns
    -------
    ret : Dict[int, tvm.relay.function.Function]
        A plain dictionary holding every function of the module produced by
        the pass, keyed the same way as that module's ``functions`` map.
        NOTE(review): the keys appear to identify each fused function by its
        hash; confirm their exact type.
    """
    ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
    # Copy the module's function map into a plain Python dict.
    return dict(ret_mod.functions.items())