"""Namespace of IR node builder make functions.

This namespace is intended for developers. Although you will not see
declarations for most of these functions in this file, they are
automatically exported from the C++ side via PackedFunc.

Each API is a PackedFunc that is called with positional arguments.
You can use these make functions to build IR nodes.
"""
from ._ffi.function import _init_api
from ._ffi.runtime_ctypes import TVMType
from . import stmt as _stmt


def range_by_min_extent(min_value, extent):
    """Build a Range from its minimum value and its extent.

    The resulting range covers the half-open interval
    [min_value, min_value + extent).

    Parameters
    ----------
    min_value : Expr
        The first value contained in the range.

    extent : Expr
        The number of values covered by the range.

    Returns
    -------
    rng : Range
        The range produced by the backend constructor.
    """
    rng = _range_by_min_extent(min_value, extent)
    return rng

def static_cast(dtype, expr):
    """Convert ``expr`` to the data type ``dtype``.

    Three outcomes are possible:

    * the element type (type code and bit width) and lane count already
      match: ``expr`` is returned unchanged;
    * the element type matches, ``expr`` is scalar and ``dtype`` is a
      vector type: a Broadcast to ``dtype``'s lane count is built;
    * otherwise: a Cast node is built.

    Parameters
    ----------
    dtype : str
        The data type to convert to.

    expr : Expr
        The expression being converted.

    Returns
    -------
    casted : Expr
        ``expr`` itself, a Broadcast, or a Cast, as described above.
    """
    target_type = TVMType(dtype)
    src_type = TVMType(expr.dtype)

    same_element_type = (target_type.type_code == src_type.type_code
                         and src_type.bits == target_type.bits)
    if not same_element_type:
        return Cast(dtype, expr)

    if src_type.lanes == target_type.lanes:
        # Nothing to do: the types are already identical.
        return expr
    if src_type.lanes == 1 and target_type.lanes > 1:
        # Scalar -> vector of the same element type: replicate the value.
        return Broadcast(expr, target_type.lanes)
    # Differing multi-lane counts fall back to a plain Cast.
    return Cast(dtype, expr)


def node(type_key, **kwargs):
    """Build a DSL node from its type key and field values.

    Parameters
    ----------
    type_key : str
        The type key that identifies which kind of node to build.

    **kwargs : dict
        The node's fields, passed as keyword arguments.

    Returns
    -------
    node : Node
        The node constructed by the backend.

    Example
    -------
    The following code constructs a IntImm object

    .. code-block:: python

       x = tvm.make.node("IntImm", dtype="int32", value=10)
       assert isinstance(x, tvm.expr.IntImm)
       assert x.value == 10
    """
    # The backend constructor is called positionally with a flat list:
    # type_key, name0, value0, name1, value1, ...
    flat_args = [type_key]
    for field_pair in kwargs.items():
        flat_args.extend(field_pair)
    return _Node(*flat_args)


def stmt_seq(*args):
    """Make a sequence of statements.

    Parameters
    ----------
    args : Stmt or Expr
        Statements to be combined as a sequence. Any argument that is not
        a Stmt is wrapped in an Evaluate node first.

    Returns
    -------
    stmt : Stmt
        The combined statement. With several arguments it is a left-nested
        chain ``Block(Block(a, b), c)``; with one argument it is that
        statement; with no arguments it is ``Evaluate(0)``.
    """
    ret = None
    for value in args:
        if not isinstance(value, _stmt.Stmt):
            value = Evaluate(value)
        ret = value if ret is None else Block(ret, value)
    # Test for "no statements collected" explicitly rather than relying on
    # the truthiness of an IR node object.
    return ret if ret is not None else Evaluate(0)


def stmt_list(stmt):
    """Flatten nested statements into a list of statements.

    Block nodes are split into their ``first`` and ``rest`` parts, and
    ProducerConsumer nodes are replaced by their ``body``. Both are expanded
    repeatedly. Any other statement is kept as a leaf. Leaves appear in the
    same left-to-right order as in the original tree.

    Parameters
    ----------
    stmt : Stmt
        The statement to flatten (typically a Block).

    Returns
    -------
    stmt_list : list of Stmt
         The unpacked list of statements.
    """
    # Iterative depth-first traversal with an explicit stack. stmt_seq builds
    # left-nested Block chains, so recursing on ``first`` would make the call
    # depth grow with the number of statements. Recursion also makes repeated
    # list concatenation quadratic.
    result = []
    pending = [stmt]
    while pending:
        current = pending.pop()
        if isinstance(current, _stmt.Block):
            # Push ``rest`` before ``first`` so ``first`` is expanded first,
            # preserving source order.
            pending.append(current.rest)
            pending.append(current.first)
        elif isinstance(current, _stmt.ProducerConsumer):
            pending.append(current.body)
        else:
            result.append(current)
    return result


# Register the C++-side PackedFuncs for the "tvm.make" namespace into this
# module. Names used above but not defined in this file (Cast, Broadcast,
# Evaluate, Block, _Node, _range_by_min_extent) presumably come from this
# call — confirm against _init_api's implementation.
_init_api("tvm.make")