from __future__ import absolute_import as _abs
import inspect as _inspect
from numbers import Number as _Number, Integral as _Integral

from ._ctypes._api import _init_function_module, convert
from . import _function_internal
from . import make as _make
from . import expr as _expr
from . import collections as _collections

# Canonical dtype names; `Var` defaults to `int32` and `placeholder` to `float32`.
int32, float32 = "int32", "float32"

def const(value, dtype=None):
    """Construct a constant.

    Parameters
    ----------
    value : int or float
        The constant value.

    dtype : str, optional
        The data type of the constant. When omitted, integral values
        (including ``bool``, which is an ``int`` subclass) use ``int32``
        and every other value uses ``float32``.

    Returns
    -------
    node : Node
        The constant built by the backend ``_const`` function
        (presumably an ``Expr``).
    """
    if dtype is None:
        # Reuse the module-level dtype names rather than repeating literals.
        dtype = int32 if isinstance(value, _Integral) else float32
    return _function_internal._const(value, dtype)


def load_json(json_str):
    """Rebuild a TVM node from its JSON text.

    Parameters
    ----------
    json_str : str
        JSON text describing the node, typically the output of
        ``save_json``.

    Returns
    -------
    node : Node
        The node reconstructed by the backend ``_load_json`` function.
    """
    loader = _function_internal._load_json
    return loader(json_str)


def save_json(node):
    """Save a TVM node as a JSON string.

    Parameters
    ----------
    node : Node
        A TVM Node object to be saved.

    Returns
    -------
    json_str : str
        The JSON string produced by the backend ``_save_json`` function.
    """
    return _function_internal._save_json(node)


def Var(name="tindex", dtype=int32):
    """Create a new variable with the specified name and dtype.

    Parameters
    ----------
    name : str, optional
        The name hint of the variable.

    dtype : str, optional
        The data type string of the variable, ``"int32"`` by default.

    Returns
    -------
    var : Node
        The variable created by the backend ``_Var`` constructor.
    """
    return _function_internal._Var(name, dtype)


def placeholder(shape, dtype=None, name="placeholder"):
    """Construct an empty tensor object.

    Parameters
    ----------
    shape : Tuple of Expr, or Expr
        The shape of the tensor. A single ``Expr`` is treated as a
        one-dimensional shape, the same way ``compute`` treats it.

    dtype : str, optional
        The data type of the tensor; ``float32`` when omitted or None.

    name : str, optional
        The name hint of the tensor.

    Returns
    -------
    tensor : tensor.Tensor
        The created tensor.
    """
    if isinstance(shape, _expr.Expr):
        shape = (shape, )
    dtype = float32 if dtype is None else dtype
    return _function_internal._Placeholder(shape, dtype, name)


def compute(shape, fcompute, name="compute"):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape : Tuple of Expr, or Expr
        The shape of the tensor. A single ``Expr`` is treated as a
        one-dimensional shape.

    fcompute : lambda function of indices -> value
        Specifies the input source expression. It receives one index
        variable per dimension, either through named positional
        parameters or through ``*indices``, which absorbs every dimension
        not covered by the named parameters.

    name : str, optional
        The name hint of the tensor.

    Returns
    -------
    tensor : tensor.Tensor
        The created tensor.

    Raises
    ------
    ValueError
        If the positional parameters of fcompute do not match the number
        of dimensions in ``shape``.
    """
    if isinstance(shape, _expr.Expr):
        shape = (shape, )

    ndim = len(shape)
    code = fcompute.__code__
    # co_varnames lists the parameters first and then any local variables,
    # so only the first co_argcount entries are positional parameters.
    arg_names = list(code.co_varnames[:code.co_argcount])
    if code.co_flags & _inspect.CO_VARARGS:
        # ``*indices`` takes the remaining dimensions; synthesize names for
        # their iteration variables.
        arg_names += ["i%d" % i for i in range(len(arg_names), ndim)]
    if ndim != len(arg_names):
        raise ValueError("fcompute do not match dimension")

    dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
    body = convert(fcompute(*[v.var for v in dim_var]))
    op_node = _function_internal._ComputeOp(name, dim_var, body)
    return _function_internal._Tensor(shape, name, body.dtype, op_node, 0)


def IterVar(dom, name='iter', thread_tag=''):
    """Create an iteration variable.

    Parameters
    ----------
    dom : Range, or list/tuple of two Expr
        The domain of iteration. A two-element list or tuple is converted
        with ``Range(dom[0], dom[1])``.

    name : str
        The name of iteration variable.

    thread_tag : str
        The thread tag of the iteration variable.

    Returns
    -------
    iter_var : IterVar
        The result itervar

    Raises
    ------
    ValueError
        If a list/tuple ``dom`` does not have exactly two elements, or if
        ``dom`` is not a Range after conversion.
    """
    if isinstance(dom, (list, tuple)):
        if len(dom) != 2:
            raise ValueError(
                "dom must be a list/tuple of two elements, got %d" % len(dom))
        # NOTE(review): ``Range`` is not imported in this file; presumably it
        # is injected into the module namespace by _init_function_module.
        dom = Range(dom[0], dom[1])

    if not isinstance(dom, _collections.Range):
        raise ValueError("dom need to be Range")

    return _function_internal._IterVar(dom, name, thread_tag)


def _reduce(op_name, expr, rdom):
    """Build a ``_make.Reduce`` node combining ``expr`` with ``op_name`` over ``rdom``.

    A single reduction domain is wrapped into a list; a list or tuple of
    domains is passed on as a list.
    """
    rdom = list(rdom) if isinstance(rdom, (list, tuple)) else [rdom]
    return _make.Reduce(op_name, expr, rdom)


def sum(expr, rdom):
    """Create a sum expression over rdom

    Parameters
    ----------
    expr : Expr
        The source expression.

    rdom : RDomain, or list/tuple of RDomain
        The reduction domain.

    Returns
    -------
    reduce : Node
        The reduction node created by ``_make.Reduce`` with "Add".
    """
    return _reduce("Add", expr, rdom)


def min(expr, rdom):
    """Create a min expression over rdom

    Parameters
    ----------
    expr : Expr
        The source expression.

    rdom : RDomain, or list/tuple of RDomain
        The reduction domain.

    Returns
    -------
    reduce : Node
        The reduction node created by ``_make.Reduce`` with "Min".
    """
    return _reduce("Min", expr, rdom)


def max(expr, rdom):
    """Create a max expression over rdom

    Parameters
    ----------
    expr : Expr
        The source expression.

    rdom : RDomain, or list/tuple of RDomain
        The reduction domain.

    Returns
    -------
    reduce : Node
        The reduction node created by ``_make.Reduce`` with "Max".
    """
    return _reduce("Max", expr, rdom)


def Schedule(ops):
    """Create a schedule for a list of operations.

    Parameters
    ----------
    ops : Operation, or list/tuple/Array of Operation
        The operations to schedule. A single operation is wrapped into a
        list, and a tuple is converted to a list.

    Returns
    -------
    sch : Node
        The schedule created by the backend ``_Schedule`` constructor.
    """
    if isinstance(ops, tuple):
        ops = list(ops)
    elif not isinstance(ops, (list, _collections.Array)):
        ops = [ops]
    return _function_internal._Schedule(ops)


# Runs once at import time, after every definition above.
# NOTE(review): _init_function_module comes from _ctypes._api; presumably it
# registers backend functions into the "tvm" namespace (e.g. ``Range``, which
# IterVar uses but this file never imports) -- confirm in _ctypes._api.
_init_function_module("tvm")