# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of IR node builder make functions.

This namespace is intended for developers. Although no declarations are
visible here, the functions are automatically exported from the C++ side
via PackedFunc.

Each API is a PackedFunc that can be called with positional arguments.
You can use these make functions to build IR nodes.
"""
from __future__ import absolute_import as _abs

from ._ffi.function import _init_api
from ._ffi.runtime_ctypes import TVMType


def range_by_min_extent(min_value, extent):
    """Construct a Range from a minimum value and an extent.

    The resulting range covers the half-open interval
    ``[min_value, min_value + extent)``.

    Parameters
    ----------
    min_value : Expr
        The minimum value of the range.

    extent : Expr
        The extent of the range.

    Returns
    -------
    rng : Range
        The constructed range.
    """
    # _range_by_min_extent is not defined in this file; it is presumably
    # injected into the module namespace by _init_api("tvm.make").
    return _range_by_min_extent(min_value, extent)


def static_cast(dtype, expr):
    """Cast ``expr`` to ``dtype``.

    The result depends on how the source and target types relate:

    * If ``expr`` already has exactly ``dtype`` (same type code, bits and
      lanes), ``expr`` is returned unchanged.
    * If ``expr`` is a scalar and ``dtype`` is a vector of the same element
      type, a Broadcast is generated.
    * Otherwise a Cast is generated.

    Parameters
    ----------
    dtype : str
        The target data type.

    expr : Expr
        The expression to be cast.

    Returns
    -------
    result : Expr
        The cast expression (or ``expr`` itself if no cast is needed).
    """
    target_type = TVMType(dtype)
    src_type = TVMType(expr.dtype)
    same_element_type = (target_type.type_code == src_type.type_code
                         and target_type.bits == src_type.bits)
    if same_element_type:
        if src_type.lanes == target_type.lanes:
            # Already the requested type: no new node is needed.
            return expr
        if src_type.lanes == 1 and target_type.lanes > 1:
            # Scalar -> vector of the same element type: replicate the scalar
            # across all lanes rather than emitting a Cast.
            return Broadcast(expr, target_type.lanes)
    return Cast(dtype, expr)


def node(type_key, **kwargs):
    """Make a new DSL node by its type key and fields.

    Parameters
    ----------
    type_key : str
        The type key of the node.

    **kwargs : dict
        The fields of the node.

    Returns
    -------
    node : Node
        The corresponding DSL Node.

    Note
    ----
    If the created node is an instance of AttrsNode, then
    the creator function will also run bound checks and
    default value setup as supported by Attrs.

    Example
    -------
    The following code constructs an IntImm object

    .. code-block:: python

       x = tvm.make.node("IntImm", dtype="int32", value=10)
       assert isinstance(x, tvm.expr.IntImm)
       assert x.value == 10
    """
    # _Node is called with one flat positional argument list:
    # type_key, name_0, value_0, name_1, value_1, ...
    args = [type_key]
    for field_name, value in kwargs.items():
        args.extend((field_name, value))
    return _Node(*args)


# Export the C++ PackedFuncs registered for "tvm.make" into this module's
# namespace. The helpers above use names not defined in this file
# (Cast, Broadcast, _Node, _range_by_min_extent); they are presumably
# supplied by this call -- confirm against _ffi.function._init_api.
_init_api("tvm.make")