api.py 21.1 KB
Newer Older
1
"""Functions defined in TVM."""
2
# pylint: disable=invalid-name,unused-import,redefined-builtin
3
from __future__ import absolute_import as _abs
4

5
from numbers import Integral as _Integral
6

7
from ._ffi.base import string_types
8 9 10
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function
11
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
12
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
13
from ._ffi.runtime_ctypes import TVMType
14
from . import _api_internal
tqchen committed
15 16
from . import make as _make
from . import expr as _expr
17
from . import tensor as _tensor
18
from . import schedule as _schedule
19
from . import container as _container
20
from . import tag as _tag
21

22
int8 = "int8"
23 24 25
int32 = "int32"
float32 = "float32"
handle = "handle"
26

ziheng committed
27 28

def min_value(dtype):
29
    """minimum value of dtype"""
ziheng committed
30 31 32 33
    return _api_internal._min_value(dtype)


def max_value(dtype):
34
    """maximum value of dtype"""
ziheng committed
35 36 37
    return _api_internal._max_value(dtype)


38
def const(value, dtype=None):
tqchen committed
39
    """construct a constant"""
40 41 42 43 44
    if dtype is None:
        if isinstance(value, _Integral):
            dtype = 'int32'
        else:
            dtype = 'float32'
45
    return _api_internal._const(value, dtype)
46 47


48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
def convert(value):
    """Convert value to TVM node or function.

    Parameters
    ----------
    value : python value

    Returns
    -------
    tvm_val : Node or Function
        Converted value in TVM
    """
    if isinstance(value, (Function, NodeBase)):
        return value

    if callable(value):
        return _convert_tvm_func(value)
65 66

    return _convert_to_node(value)
67 68


69 70 71 72 73 74 75 76 77 78 79 80 81
def load_json(json_str):
    """Load tvm object from json_str.

    Parameters
    ----------
    json_str : str
        The json string

    Returns
    -------
    node : Node
        The loaded tvm node.
    """
82
    return _api_internal._load_json(json_str)
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97


def save_json(node):
    """Load tvm object as json string.

    Parameters
    ----------
    node : Node
        A TVM Node object to be saved.

    Returns
    -------
    json_str : str
        Saved json string.
    """
98
    return _api_internal._save_json(node)
99 100


101
def var(name="tindex", dtype=int32):
tqchen committed
102 103 104 105 106 107 108 109 110
    """Create a new variable with specified name and dtype

    Parameters
    ----------
    name : str
        The name

    dtype : int
        The data type
111 112 113 114 115

    Returns
    -------
    var : Var
        The result symbolic variable.
tqchen committed
116
    """
117
    return _api_internal._Var(name, dtype)
tqchen committed
118 119


120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
def any(*args):
    """Create a new experssion of the union of all conditions in the arguments

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make.Or(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make.Or(ret, args[i])
    return ret


def all(*args):
    """Create a new experssion of the intersection of all conditions in the
      arguments

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make.And(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make.And(ret, args[i])
    return ret


167
def placeholder(shape, dtype=None, name="placeholder"):
tqchen committed
168
    """Construct an empty tensor object.
169 170 171

    Parameters
    ----------
tqchen committed
172 173 174
    shape: Tuple of Expr
        The shape of the tensor

tqchen committed
175 176 177 178 179 180 181 182
    dtype: str, optional
        The data type of the tensor

    name: str, optional
        The name hint of the tensor

    Returns
    -------
183
    tensor: Tensor
tqchen committed
184 185
        The created tensor
    """
186
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
tqchen committed
187
    dtype = float32 if dtype is None else dtype
188
    return _api_internal._Placeholder(
189
        shape, dtype, name)
tqchen committed
190 191


192
def compute(shape, fcompute, name="compute", tag=""):
tqchen committed
193 194 195 196 197 198 199 200 201
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

202
    fcompute: lambda function of indices-> value
tqchen committed
203
        Specifies the input source expression
204

tqchen committed
205 206 207 208 209
    name: str, optional
        The name hint of the tensor

    Returns
    -------
210
    tensor: Tensor
tqchen committed
211
        The created tensor
212
    """
213 214 215 216
    if _tag.TagScope.current is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.current.tag
217
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
tqchen committed
218
    ndim = len(shape)
219
    code = fcompute.__code__
220

221
    if fcompute.__code__.co_argcount == 0:
222
        arg_names = ["i%d" % i for i in range(ndim)]
223 224 225
    else:
        arg_names = code.co_varnames[:code.co_argcount]

tqchen committed
226
    if ndim != len(arg_names):
227
        raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
228

229
    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
230
    body = fcompute(*[v.var for v in dim_var])
231 232
    if not isinstance(body, (list, tuple)):
        body = [body]
tqchen committed
233
    body = convert(body)
234
    op_node = _api_internal._ComputeOp(
235
        name, tag, dim_var, body)
236 237 238
    num = op_node.num_outputs
    outputs = tuple(op_node.output(i) for i in range(num))
    return outputs[0] if num == 1 else outputs
239 240


241
def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
242 243 244 245 246 247 248 249 250 251 252 253 254
    """Construct new tensors by scanning over axis.

    Parameters
    ----------
    init: Tensor or list of Tensor
        The initial condition of first init.shape[0] timestamps

    update: Tensor or list of Tensor
        The update rule of the scan given by symbolic tensor.

    state_placeholder: Tensor or list of Tensor
        The placeholder variables used by update.

255 256 257 258
    inputs: Tensor or list of Tensor, optional
        The list of inputs to the scan. This is not required, but can
        be useful for the compiler to detect scan body faster.

259 260 261 262 263
    name: str, optional
        The name hint of the tensor

    Returns
    -------
264 265
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors it it contains multiple outputs.
266 267 268

    Example
    -------
269 270 271
    .. code-block:: python

      # The following code is equivalent to numpy.cumsum
272 273
      m = tvm.var("m")
      n = tvm.var("n")
274 275 276 277 278
      X = tvm.placeholder((m, n), name="X")
      s_state = tvm.placeholder((m, n))
      s_init = tvm.compute((1, n), lambda _, i: X[0, i])
      s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
      res = tvm.scan(s_init, s_update, s_state, X)
279
    """
280 281 282 283
    if _tag.TagScope.current is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.current.tag
284 285 286 287 288 289
    if isinstance(init, _tensor.Tensor):
        init = [init]
    if isinstance(update, _tensor.Tensor):
        update = [update]
    if isinstance(state_placeholder, _tensor.Tensor):
        state_placeholder = [state_placeholder]
290 291 292 293
    if isinstance(inputs, _tensor.Tensor):
        inputs = [inputs]
    if inputs is None:
        inputs = []
294 295
    if len(init) != len(update) or len(init) != len(state_placeholder):
        raise ValueError("init, update, state_placeholder must have same length")
296
    axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
297 298
    op = _api_internal._ScanOp(name, tag, axis, init, update,
                               state_placeholder, inputs)
299
    res = [op.output(i) for i in range(len(update))]
300 301
    return res[0] if len(res) == 1 else res

302

303 304 305 306 307 308 309 310
def extern(shape,
           inputs,
           fcompute,
           name="extern",
           dtype=None,
           in_buffers=None,
           out_buffers=None,
           tag=""):
311 312 313 314
    """Compute several tensor via extern function.

    Parameters
    ----------
315
    shape: tuple or list of tuples.
316 317 318 319 320 321 322
        The shape of the outputs.

    inputs: list of Tensor
        The inputs

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.
323 324 325 326 327 328 329 330 331 332 333
        See the following note for function signature of fcompute

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each inputs
             - **outs** (list of :any:`Buffer`) - Placeholder for each outputs

             **Returns**

             - **stmt** (:any:`Stmt`) - The statement that carries out array computation.
334 335 336 337 338 339 340 341

    name: str, optional
        The name hint of the tensor

    dtype: str or list of str, optional
        The data types of outputs,
        by default dtype will be same as inputs.

342 343 344 345 346 347
    in_buffers: Buffer or list of Buffer, optional
        Input buffers.

    out_buffers: Buffer or list of Buffers, optional
        Output buffers.

348 349 350 351
    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors it it contains multiple outputs.
352 353 354 355 356 357 358 359 360 361 362 363 364 365

    Example
    -------
    In the code below, C is generated by calling external PackedFunc
    `tvm.contrib.cblas.matmul`

    .. code-block:: python

        A = tvm.placeholder((n, l), name='A')
        B = tvm.placeholder((l, m), name='B')
        C = tvm.extern((n, m), [A, B],
                       lambda ins, outs: tvm.call_packed(
                          "tvm.contrib.cblas.matmul",
                            ins[0], ins[1], outs[0], 0, 0), name="C")
366
    """
367 368 369 370
    if _tag.TagScope.current is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.current.tag
371 372
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
373 374 375 376 377 378 379 380 381 382 383 384
    if in_buffers is not None:
        in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
        if len(inputs) != len(in_buffers):
            raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
                               % (len(inputs), len(in_buffers)))
    if out_buffers is not None:
        out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
        if len(shape) != len(out_buffers):
            raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
                               % (len(shape), len(out_buffers)))
    input_placeholders = in_buffers or []
    output_placeholders = out_buffers or []
385 386 387 388
    types = set()
    for t in inputs:
        if not isinstance(t, _tensor.Tensor):
            raise ValueError("expect inputs to be tensor")
389 390 391
        if in_buffers is None:
            input_placeholders.append(
                decl_buffer(t.shape, t.dtype, t.op.name))
392 393 394 395 396 397 398
        types.add(t.dtype)

    if dtype is None:
        if len(types) != 1:
            raise ValueError("Cannot infer output type, please provide dtype argument")
        infered_type = types.pop()
        dtype = [infered_type for _ in shape]
399 400
    if isinstance(dtype, str):
        dtype = [dtype]
401

402 403 404
    if out_buffers is None:
        for shp, dt in zip(shape, dtype):
            output_placeholders.append(decl_buffer(shp, dt, name))
405 406 407 408
    body = fcompute(input_placeholders, output_placeholders)
    if isinstance(body, _expr.Expr):
        body = _make.Evaluate(body)

409 410
    op = _api_internal._ExternOp(name, tag, inputs, input_placeholders,
                                 output_placeholders, body)
411 412 413 414
    res = [op.output(i) for i in range(len(output_placeholders))]
    return res[0] if len(res) == 1 else res


415 416
def decl_buffer(shape,
                dtype=None,
417 418 419
                name="buffer",
                data=None,
                strides=None,
420
                elem_offset=None,
421
                scope="",
422
                data_alignment=-1,
423
                offset_factor=0):
424 425 426 427
    """Decleare a new symbolic buffer.

    Normally buffer is created automatically during lower and build.
    This is only needed if user want to specify their own buffer layout.
428

429 430
    See the note below for detailed discussion on usage of buffer.

431 432 433 434 435 436 437 438 439 440 441
    Parameters
    ----------
    shape : tuple of Expr
        The shape of the buffer.

    dtype : str, optional
        The data type of the buffer.

    name : str, optional
        The name of the buffer.

442
    data : Var, optional
443 444 445 446 447
        The data pointer in the buffer.

    strides: array of Expr
        The stride of the buffer.

448 449 450
    elem_offset: Expr, optional
        The beginning offset of the array to data.
        In terms of number of elements of dtype.
451

452 453 454 455
    scope: str, optional
        The storage scope of the buffer, if not global.
        If scope equals empty string, it means it is global memory.

456 457
    data_alignment: int, optional
        The alignment of data pointer in bytes.
458
        If -1 is passed, the alignment will be set to TVM's internal default.
459 460 461 462 463 464

    offset_factor: int, optional
        The factor of elem_offset field, when set,
        elem_offset is required to be multiple of offset_factor.
        If 0 is pssed, the alignment will be set to 1.
        if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
465

466 467 468 469
    Returns
    -------
    buffer : Buffer
        The created buffer
470 471 472 473 474 475 476 477

    Note
    ----
    Buffer data structure reflects the DLTensor structure in dlpack.
    While DLTensor data structure is very general, it is usually helpful
    to create function that only handles specific case of data structure
    and make compiled function benefit from it.

478
    If user pass strides and elem_offset is passed as None
479 480 481 482
    when constructing the function, then the function will be specialized
    for the DLTensor that is compact and aligned.
    If user pass a fully generic symbolic array to the strides,
    then the resulting function becomes fully generic.
483
    """
484
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
485 486
    dtype = float32 if dtype is None else dtype
    strides = () if strides is None else strides
487 488
    if offset_factor != 0 and elem_offset is None:
        elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
489
    if data is None:
490
        data = var(name, "handle")
491
    return _api_internal._Buffer(
492 493
        data, dtype, shape, strides, elem_offset, name, scope,
        data_alignment, offset_factor)
tqchen committed
494

495

496 497
def _IterVar(dom, name, iter_type, thread_tag=''):
    """Internal function to create IterVar
tqchen committed
498 499 500

    Parameters
    ----------
tqchen committed
501
    dom : Range
502
        The domain of iteration.
tqchen committed
503 504

    name : str
505 506 507 508
        The name of iteration variable.

    iter_type : int
        The type of iteration.
tqchen committed
509 510 511

    thread_tag : str
        The thread tag of the iteration variable.
tqchen committed
512 513 514

    Returns
    -------
tqchen committed
515 516
    iter_var : IterVar
       The result itervar
tqchen committed
517
    """
518 519 520
    if dom is not None:
        if isinstance(dom, (list, tuple)):
            if len(dom) != 2:
521
                raise TypeError("need to be list of ranges")
522 523
            dom = Range(dom[0], dom[1])

524
        if not isinstance(dom, _container.Range):
525
            raise TypeError("dom need to be Range")
526
    name = name if name else 'iter'
527 528
    v = var(name)
    return _api_internal._IterVar(dom, v, iter_type, thread_tag)
529 530


531
def thread_axis(dom=None, tag='', name=''):
532 533 534 535
    """Create a new IterVar to represent thread index.

    Parameters
    ----------
536 537 538
    dom : Range or str
        The domain of iteration
        When str is passed, dom is set to None and str is used as tag
539

540
    tag : str, optional
541 542 543 544
        The thread tag

    name : str, optional
        The name of the var.
545 546 547 548 549

    Returns
    -------
    axis : IterVar
        The thread itervar.
550
    """
551
    if isinstance(dom, string_types):
552
        tag, dom = dom, None
553
    if not tag:
554
        raise ValueError("tag must be given as Positional or keyword argument")
555 556 557 558 559 560 561 562 563 564 565 566 567 568
    name = name if name else tag
    return _IterVar(dom, name, 1, tag)


def reduce_axis(dom, name="rv"):
    """Create a new IterVar for reduction.

    Parameters
    ----------
    dom : Range
        The domain of iteration.

    name : str
        The name of the variable.
569 570 571 572 573

    Returns
    -------
    axis : IterVar
        An iteration variable representing the value.
574 575
    """
    return _IterVar(dom, name, 2)
tqchen committed
576

ziheng committed
577

Yizhi Liu committed
578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594
def select(cond, t, f):
    """Construct a select branch
    Parameters
    ----------
    cond : Expr
        The condition
    t : Expr
        The result expression if cond is true.
    f : Expr
        The result expression if cond is false.

    Returns
    -------
    node : Node
        The tvm.expr.Select node
    """
    return _make.Select(convert(cond), convert(t), convert(f))
tqchen committed
595

ziheng committed
596 597
def comm_reducer(fcombine, fidentity, name="reduce"):
    """Create a commutative reducer for reduction.
tqchen committed
598 599 600

    Parameters
    ----------
ziheng committed
601 602
    fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Expr as input to return a Expr.
603

ziheng committed
604 605
    fidentity : function(str -> Expr)
        A function which takes a type string as input to return a const Expr.
606 607 608

    Returns
    -------
ziheng committed
609
    reducer : function
610
        A function which creates a reduce expression over axis.
611
        There are two ways to use it:
612 613 614 615

        1. accept (expr, axis, where) to produce an Reduce Expr on
           specified axis;
        2. simply use it with multiple Exprs.
616

ziheng committed
617
    Example
618
    -------
ziheng committed
619
    .. code-block:: python
620

ziheng committed
621 622 623 624 625 626 627
        n = tvm.var('n')
        m = tvm.var('m')
        mysum = tvm.comm_reducer(lambda x, y: x+y,
            lambda t: tvm.const(0, dtype=t), name="mysum")
        A = tvm.placeholder((n, m), name='A')
        k = tvm.reduce_axis((0, m), name='k')
        B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
tqchen committed
628
    """
ziheng committed
629 630 631 632 633 634 635 636 637 638 639 640 641
    def _reduce_directly(*args):
        num = len(args)
        # process `where` is None
        if num == 3 and args[2] is None:
            num = 2
        res = args[0]
        for i in range(num-1):
            res = fcombine(res, args[i+1])
        return res

    def _make_reduce(expr, axis, where=None):
        code = fcombine.__code__
        assert fcombine.__code__.co_argcount == 2
642
        expr = convert(expr)
643
        if isinstance(expr, _container.Array):
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669
            size = len(expr)
            larr = []
            rarr = []
            dtypes = []
            for i in range(size):
                dtype = expr[i].dtype
                dtypes.append(dtype)
                lname = code.co_varnames[0] + '_' + str(i)
                larr.append(var(lname, dtype))
                rname = code.co_varnames[1] + '_' + str(i)
                rarr.append(var(rname, dtype))
            lhs = convert(larr)
            rhs = convert(rarr)
            result = fcombine(lhs, rhs)
            id_elem = fidentity(*dtypes)
        else:
            assert isinstance(expr, _expr.Expr)
            size = 1
            dtype = expr.dtype
            lvar = var(code.co_varnames[0], dtype)
            rvar = var(code.co_varnames[1], dtype)
            result = [fcombine(lvar, rvar)]
            id_elem = [fidentity(dtype)]
            lhs = convert([lvar])
            rhs = convert([rvar])
            expr = convert([expr])
ziheng committed
670
        result = convert(result)
671 672
        id_elem = convert(id_elem)
        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
673
        axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
674 675 676 677 678
        if where is None:
            where = convert(True)
        outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
                        for i in range(size))
        return outputs[0] if size == 1 else outputs
ziheng committed
679

680
    # pylint: disable=keyword-arg-before-vararg
ziheng committed
681
    def reducer(expr, axis, where=None, *args):
682
        if isinstance(axis, (_schedule.IterVar, list, tuple)):
683
            assert not args
ziheng committed
684
            return _make_reduce(expr, axis, where)
685 686 687 688
        if where is None:
            assert not args
            return _reduce_directly(expr, axis)
        return _reduce_directly(expr, axis, where, *args)
ziheng committed
689 690

    doc_str = """Create a {0} expression over axis.
691

ziheng committed
692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707
              Parameters
              ----------
              expr : Expr
                  The source expression.
              axis : IterVar
                  The reduction IterVar axis
              where : optional, Expr
                  Filtering predicate of the reduction.
              Returns
              -------
              value : Expr
                  The result value.

              Example
              -------
              .. code-block:: python
708

ziheng committed
709 710 711 712 713 714 715
                m = tvm.var("m")
                n = tvm.var("n")
                A = tvm.placeholder((m, n), name="A")
                k = tvm.reduce_axis((0, n), name="k")

                # there are two way to use this {0} reducer:
                # mode 1, accept (expr, axis, where) to produce an Reduce Expr
716
                B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")
ziheng committed
717 718

                # mode 2, simply use it with multiple Exprs:
719
                {0}_res = tvm.{0}(m, n)
ziheng committed
720 721 722
              """
    reducer.__doc__ = doc_str.format(name)
    return reducer
723

tqchen committed
724

725
_init_api("tvm.api")
ziheng committed
726 727 728 729
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make.Max(x, y), min_value, name='max')