"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
from __future__ import absolute_import as _abs

from numbers import Integral as _Integral

from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag

int8 = "int8"
int32 = "int32"
float32 = "float32"
handle = "handle"


def min_value(dtype):
    """minimum value of dtype

    Parameters
    ----------
    dtype : str
        The data type.

    Returns
    -------
    value : tvm.Expr
        The minimum value of dtype.
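
    Example
    -------
    A minimal usage sketch; the dtype string below is illustrative.

    .. code-block:: python

        lowest = tvm.min_value("float32")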
    """
    return _api_internal._min_value(dtype)


def max_value(dtype):
    """maximum value of dtype

    Parameters
    ----------
    dtype : str
        The data type.

    Returns
    -------
    value : tvm.Expr
        The maximum value of dtype.
    """
    return _api_internal._max_value(dtype)


def const(value, dtype):
    """construct a constant

    Parameters
    ----------
    value : number
        The content of the constant number.

    dtype : str
        The data type.

    Returns
    -------
    const_val: tvm.Expr
        The result expression.
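
    Example
    -------
    A minimal usage sketch; the values and dtypes below are illustrative.

    .. code-block:: python

        one = tvm.const(1, dtype="int32")
        half = tvm.const(0.5, dtype="float32")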
    """
    return _api_internal._const(value, dtype)


def get_env_func(name):
    """Get an EnvFunc by a global name.

    Parameters
    ----------
    name: str
        The name of the global function.

    Returns
    -------
    env_func : EnvFunc
        The result env function.

    Note
    ----
    EnvFunc is a Node wrapper around a
    global function that can be serialized via its name.
    This can be used to serialize function fields in the language.
    """
    return _api_internal._EnvFuncGet(name)


def convert(value):
    """Convert value to TVM node or function.

    Parameters
    ----------
    value : python value

    Returns
    -------
    tvm_val : Node or Function
        Converted value in TVM
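
    Example
    -------
    A minimal usage sketch; the list below is illustrative.

    .. code-block:: python

        arr = tvm.convert([1, 2, 3])  # becomes a TVM Array node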
    """
    if isinstance(value, (Function, NodeBase)):
        return value

    if callable(value):
        return _convert_tvm_func(value)

    return _convert_to_node(value)


def load_json(json_str):
    """Load tvm object from json_str.

    Parameters
    ----------
    json_str : str
        The json string

    Returns
    -------
    node : Node
        The loaded tvm node.
    """
    return _api_internal._load_json(json_str)


def save_json(node):
    """Save tvm object as json string.

    Parameters
    ----------
    node : Node
        A TVM Node object to be saved.

    Returns
    -------
    json_str : str
        Saved json string.
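
    Example
    -------
    A minimal round-trip sketch; the node below is illustrative.

    .. code-block:: python

        node = tvm.convert([1, 2, 3])
        json_str = tvm.save_json(node)
        same_node = tvm.load_json(json_str)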
    """
    return _api_internal._save_json(node)


def var(name="tindex", dtype=int32):
    """Create a new variable with specified name and dtype

    Parameters
    ----------
    name : str
        The name

    dtype : str
        The data type

    Returns
    -------
    var : Var
        The result symbolic variable.
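
    Example
    -------
    A minimal usage sketch; the names below are illustrative.

    .. code-block:: python

        n = tvm.var("n")                   # int32 by default
        x = tvm.var("x", dtype="float32")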
    """
    return _api_internal._Var(name, dtype)


def any(*args):
    """Create a new experssion of the union of all conditions in the arguments

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression
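
    Example
    -------
    A minimal usage sketch; the variable and bounds below are illustrative.

    .. code-block:: python

        i = tvm.var("i")
        cond = tvm.any(i < 2, i > 10)  # logical OR of the conditions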
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make._OpOr(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make._OpOr(ret, args[i])
    return ret


def all(*args):
    """Create a new experssion of the intersection of all conditions in the
      arguments

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression
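
    Example
    -------
    A minimal usage sketch; the variable and bounds below are illustrative.

    .. code-block:: python

        i = tvm.var("i")
        cond = tvm.all(i > 0, i < 10)  # logical AND of the conditions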
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make._OpAnd(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make._OpAnd(ret, args[i])
    return ret


def placeholder(shape, dtype=None, name="placeholder"):
    """Construct an empty tensor object.

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    dtype: str, optional
        The data type of the tensor

    name: str, optional
        The name hint of the tensor

    Returns
    -------
    tensor: Tensor
        The created tensor
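
    Example
    -------
    A minimal usage sketch; the shape and names below are illustrative.

    .. code-block:: python

        m = tvm.var("m")
        n = tvm.var("n")
        A = tvm.placeholder((m, n), name="A")  # float32 by default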
    """
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
    dtype = float32 if dtype is None else dtype
    return _api_internal._Placeholder(
        shape, dtype, name)


def compute(shape, fcompute, name="compute", tag="", attrs=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    fcompute: lambda function of indices-> value
        Specifies the input source expression

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor
        The created tensor
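
    Example
    -------
    A minimal usage sketch; the shapes and names below are illustrative.

    .. code-block:: python

        m = tvm.var("m")
        n = tvm.var("n")
        A = tvm.placeholder((m, n), name="A")
        B = tvm.compute((m, n), lambda i, j: A[i, j] * 2, name="B")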
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
    # for python3
    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
    ndim = len(shape)
    code = fcompute.__code__

    out_ndim = ndim
    if code.co_argcount == 0:
        arg_names = ["i%d" % i for i in range(ndim)]
    else:
        arg_names = code.co_varnames[:code.co_argcount]
        out_ndim = code.co_argcount

    if out_ndim != len(arg_names):
        raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
    body = fcompute(*[v.var for v in dim_var])

    if isinstance(body, _tensor.TensorIntrinCall):
        for i, s in enumerate(shape[out_ndim:]):
            var_name = "ax" + str(i)
            dim_var.append(_IterVar((0, s), var_name, 4))
        op_node = _api_internal._TensorComputeOp(name,
                                                 tag,
                                                 dim_var,
                                                 body.reduce_axis,
                                                 out_ndim,
                                                 body.intrin,
                                                 body.tensors,
                                                 body.regions)
    else:
        if not isinstance(body, (list, tuple)):
            body = [body]
        body = convert(body)
        op_node = _api_internal._ComputeOp(
            name, tag, attrs, dim_var, body)

    num = op_node.num_outputs
    outputs = tuple(op_node.output(i) for i in range(num))
    return outputs[0] if num == 1 else outputs


def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
    """Construct new tensors by scanning over axis.

    Parameters
    ----------
    init: Tensor or list of Tensor
        The initial condition of the first init.shape[0] timestamps.

    update: Tensor or list of Tensor
        The update rule of the scan given by symbolic tensor.

    state_placeholder: Tensor or list of Tensor
        The placeholder variables used by update.

    inputs: Tensor or list of Tensor, optional
        The list of inputs to the scan. This is not required, but can
        help the compiler detect the scan body faster.

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.

    Example
    -------
    .. code-block:: python

      # The following code is equivalent to numpy.cumsum
      m = tvm.var("m")
      n = tvm.var("n")
      X = tvm.placeholder((m, n), name="X")
      s_state = tvm.placeholder((m, n))
      s_init = tvm.compute((1, n), lambda _, i: X[0, i])
      s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
      res = tvm.scan(s_init, s_update, s_state, X)
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    if isinstance(init, _tensor.Tensor):
        init = [init]
    if isinstance(update, _tensor.Tensor):
        update = [update]
    if isinstance(state_placeholder, _tensor.Tensor):
        state_placeholder = [state_placeholder]
    if isinstance(inputs, _tensor.Tensor):
        inputs = [inputs]
    if inputs is None:
        inputs = []
    if len(init) != len(update) or len(init) != len(state_placeholder):
        raise ValueError("init, update, state_placeholder must have same length")
    axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
    op = _api_internal._ScanOp(name, tag, attrs,
                               axis, init, update,
                               state_placeholder, inputs)
    res = [op.output(i) for i in range(len(update))]
    return res[0] if len(res) == 1 else res


def extern(shape,
           inputs,
           fcompute,
           name="extern",
           dtype=None,
           in_buffers=None,
           out_buffers=None,
           tag="",
           attrs=None):
    """Compute several tensor via extern function.

    Parameters
    ----------
    shape: tuple or list of tuples.
        The shape of the outputs.

    inputs: list of Tensor
        The inputs

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.
        See the following note for the function signature of fcompute.

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each input
             - **outs** (list of :any:`Buffer`) - Placeholder for each output

             **Returns**

             - **stmt** (:any:`Stmt`) - The statement that carries out array computation.

    name: str, optional
        The name hint of the tensor

    dtype: str or list of str, optional
        The data types of outputs,
        by default dtype will be the same as inputs.

    in_buffers: Buffer or list of Buffer, optional
        Input buffers.

    out_buffers: Buffer or list of Buffers, optional
        Output buffers.


    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.

    Example
    -------
    In the code below, C is generated by calling external PackedFunc
    `tvm.contrib.cblas.matmul`

    .. code-block:: python

        A = tvm.placeholder((n, l), name='A')
        B = tvm.placeholder((l, m), name='B')
        C = tvm.extern((n, m), [A, B],
                       lambda ins, outs: tvm.call_packed(
                          "tvm.contrib.cblas.matmul",
                            ins[0], ins[1], outs[0], 0, 0), name="C")
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
    if in_buffers is not None:
        in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
        if len(inputs) != len(in_buffers):
            raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
                               % (len(inputs), len(in_buffers)))
    if out_buffers is not None:
        out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
        if len(shape) != len(out_buffers):
            raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
                               % (len(shape), len(out_buffers)))
    input_placeholders = in_buffers or []
    output_placeholders = out_buffers or []
    types = set()
    for t in inputs:
        if not isinstance(t, _tensor.Tensor):
            raise ValueError("expect inputs to be tensor")
        if in_buffers is None:
            input_placeholders.append(
                decl_buffer(t.shape, t.dtype, t.op.name))
        types.add(t.dtype)

    if dtype is None:
        if len(types) != 1:
            raise ValueError("Cannot infer output type, please provide dtype argument")
        inferred_type = types.pop()
        dtype = [inferred_type for _ in shape]
    if isinstance(dtype, str):
        dtype = [dtype]

    if out_buffers is None:
        for shp, dt in zip(shape, dtype):
            output_placeholders.append(decl_buffer(shp, dt, name))
    body = fcompute(input_placeholders, output_placeholders)
    if isinstance(body, _expr.Expr):
        body = _make.Evaluate(body)

    op = _api_internal._ExternOp(name, tag, attrs,
                                 inputs, input_placeholders,
                                 output_placeholders, body)
    res = [op.output(i) for i in range(len(output_placeholders))]
    return res[0] if len(res) == 1 else res


def decl_buffer(shape,
                dtype=None,
                name="buffer",
                data=None,
                strides=None,
                elem_offset=None,
                scope="",
                data_alignment=-1,
                offset_factor=0):
    """Declare a new symbolic buffer.

    Normally a buffer is created automatically during lower and build.
    This is only needed if users want to specify their own buffer layout.

    See the note below for a detailed discussion on the usage of buffer.

    Parameters
    ----------
    shape : tuple of Expr
        The shape of the buffer.

    dtype : str, optional
        The data type of the buffer.

    name : str, optional
        The name of the buffer.

    data : Var, optional
        The data pointer in the buffer.

    strides: array of Expr
        The stride of the buffer.

    elem_offset: Expr, optional
        The beginning offset of the array relative to data,
        in terms of number of elements of dtype.

    scope: str, optional
        The storage scope of the buffer, if not global.
        If scope equals empty string, it means it is global memory.

    data_alignment: int, optional
        The alignment of data pointer in bytes.
        If -1 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, optional
        The factor of the elem_offset field; when set,
        elem_offset is required to be a multiple of offset_factor.
        If 0 is passed, the alignment will be set to 1.
        If non-zero is passed, a Var will be created for elem_offset if elem_offset is None.

    Returns
    -------
    buffer : Buffer
        The created buffer

    Note
    ----
    The Buffer data structure reflects the DLTensor structure in dlpack.
    While the DLTensor data structure is very general, it is usually helpful
    to create a function that only handles a specific case of the data structure
    so that the compiled function can benefit from it.

    If the user passes strides and elem_offset as None
    when constructing the function, then the function will be specialized
    for a DLTensor that is compact and aligned.
    If the user passes a fully generic symbolic array for the strides,
    then the resulting function becomes fully generic.
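
    Example
    -------
    A minimal usage sketch; the tensor and buffer names below are illustrative.
    The declared buffer can later be bound to a tensor at build time
    (for example through the ``binds`` argument of ``tvm.build``).

    .. code-block:: python

        m = tvm.var("m")
        A = tvm.placeholder((m,), name="A")
        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", offset_factor=1)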
    """
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    dtype = float32 if dtype is None else dtype
    strides = () if strides is None else strides
    if offset_factor != 0 and elem_offset is None:
        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
        elem_offset = var('%s_elem_offset' % name, shape_dtype)
    if data is None:
        data = var(name, "handle")
    return _api_internal._Buffer(
        data, dtype, shape, strides, elem_offset, name, scope,
        data_alignment, offset_factor)

def layout(layout_str):
    """Create a layout node from a string.

    Parameters
    ----------
    layout_str : str
        A layout representation is composed of upper cases, lower cases and numbers,
        where upper case indicates a primal axis and
        the corresponding lower case with factor size indicates the subordinate axis.
        For example, NCHW16c can describe a 5-D tensor of
        [batch_size, channel, height, width, channel_block].
        Here subordinate axis channel_block=16 is the factor size of
        the primal axis C (channel).

    Returns
    -------
    layout : Layout
        The created layout
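
    Example
    -------
    A minimal usage sketch; the layout string below is illustrative.

    .. code-block:: python

        lyt = tvm.layout("NCHW16c")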
    """
    return _api_internal._Layout(layout_str)

def bijective_layout(src_layout, dst_layout):
    """Create a bijective layout mapping.

    Parameters
    ----------
    src_layout : str or Layout
        The source layout.

    dst_layout : str or Layout
        The destination layout.

    Returns
    -------
    bijective_layout : BijectiveLayout
        The created bijective layout
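
    Example
    -------
    A minimal usage sketch; the layout strings below are illustrative.

    .. code-block:: python

        bl = tvm.bijective_layout("NCHW", "NCHW16c")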
    """
    if isinstance(src_layout, str):
        src_layout = layout(src_layout)
    if isinstance(dst_layout, str):
        dst_layout = layout(dst_layout)
    return _api_internal._BijectiveLayout(src_layout, dst_layout)

def _IterVar(dom, name, iter_type, thread_tag=''):
    """Internal function to create IterVar

    Parameters
    ----------
    dom : Range
        The domain of iteration.

    name : str
        The name of iteration variable.

    iter_type : int
        The type of iteration.

    thread_tag : str
        The thread tag of the iteration variable.

    Returns
    -------
    iter_var : IterVar
       The result itervar
    """
    if dom is not None:
        if isinstance(dom, (list, tuple)):
            if len(dom) != 2:
                raise TypeError("need to be list of ranges")
            dom = Range(dom[0], dom[1])

        if not isinstance(dom, _container.Range):
            raise TypeError("dom need to be Range")
    name = name if name else 'iter'
    v = var(name)
    return _api_internal._IterVar(dom, v, iter_type, thread_tag)


def thread_axis(dom=None, tag='', name=''):
    """Create a new IterVar to represent thread index.

    Parameters
    ----------
    dom : Range or str
        The domain of iteration.
        When str is passed, dom is set to None and the str is used as tag.

    tag : str, optional
        The thread tag

    name : str, optional
        The name of the var.

    Returns
    -------
    axis : IterVar
        The thread itervar.
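
    Example
    -------
    A minimal usage sketch; the thread tags below are illustrative.

    .. code-block:: python

        bx = tvm.thread_axis("blockIdx.x")
        tx = tvm.thread_axis("threadIdx.x")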
    """
    if isinstance(dom, string_types):
        tag, dom = dom, None
    if not tag:
        raise ValueError("tag must be given as Positional or keyword argument")
    name = name if name else tag
    return _IterVar(dom, name, 1, tag)


def reduce_axis(dom, name="rv"):
    """Create a new IterVar for reduction.

    Parameters
    ----------
    dom : Range
        The domain of iteration.

    name : str
        The name of the variable.

    Returns
    -------
    axis : IterVar
        An iteration variable representing the value.
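
    Example
    -------
    A minimal usage sketch; the shape and names below are illustrative.
    ``tvm.sum`` is the commutative reducer defined at the bottom of this module.

    .. code-block:: python

        n = tvm.var("n")
        A = tvm.placeholder((n,), name="A")
        k = tvm.reduce_axis((0, n), name="k")
        B = tvm.compute((1,), lambda _: tvm.sum(A[k], axis=k), name="B")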
    """
    return _IterVar(dom, name, 2)


def comm_reducer(fcombine, fidentity, name="reduce"):
    """Create a commutative reducer for reduction.

    Parameters
    ----------
    fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Expr as input and returns an Expr.

    fidentity : function(str -> Expr)
        A function which takes a type string as input and returns a const Expr.

    Returns
    -------
    reducer : function
        A function which creates a reduce expression over axis.
        There are two ways to use it:

        1. accept (expr, axis, where) to produce a Reduce Expr on the
           specified axis;
        2. simply use it with multiple Exprs.

    Example
    -------
    .. code-block:: python

        n = tvm.var('n')
        m = tvm.var('m')
        mysum = tvm.comm_reducer(lambda x, y: x+y,
            lambda t: tvm.const(0, dtype=t), name="mysum")
        A = tvm.placeholder((n, m), name='A')
        k = tvm.reduce_axis((0, m), name='k')
        B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
    """
    def _reduce_directly(*args):
        num = len(args)
        # handle the case where `where` is None
        if num == 3 and args[2] is None:
            num = 2
        res = args[0]
        for i in range(num-1):
            res = fcombine(res, args[i+1])
        return res

    def _make_reduce(expr, axis, where=None):
        code = fcombine.__code__
        assert fcombine.__code__.co_argcount == 2
        expr = convert(expr)
        if isinstance(expr, _container.Array):
            size = len(expr)
            larr = []
            rarr = []
            dtypes = []
            for i in range(size):
                dtype = expr[i].dtype
                dtypes.append(dtype)
                lname = code.co_varnames[0] + '_' + str(i)
                larr.append(var(lname, dtype))
                rname = code.co_varnames[1] + '_' + str(i)
                rarr.append(var(rname, dtype))
            lhs = convert(larr)
            rhs = convert(rarr)
            result = fcombine(lhs, rhs)
            id_elem = fidentity(*dtypes)
        else:
            assert isinstance(expr, _expr.Expr)
            size = 1
            dtype = expr.dtype
            lvar = var(code.co_varnames[0], dtype)
            rvar = var(code.co_varnames[1], dtype)
            result = [fcombine(lvar, rvar)]
            id_elem = [fidentity(dtype)]
            lhs = convert([lvar])
            rhs = convert([rvar])
            expr = convert([expr])
        result = convert(result)
        id_elem = convert(id_elem)
        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
        axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
        if where is None:
            where = convert(True)
        outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i)
                        for i in range(size))
        return outputs[0] if size == 1 else outputs

    # pylint: disable=keyword-arg-before-vararg
    def reducer(expr, axis, where=None, *args):
        if isinstance(axis, (_schedule.IterVar, list, tuple)):
            assert not args
            return _make_reduce(expr, axis, where)
        if where is None:
            assert not args
            return _reduce_directly(expr, axis)
        return _reduce_directly(expr, axis, where, *args)

    doc_str = """Create a {0} expression over axis.

              Parameters
              ----------
              expr : Expr
                  The source expression.
              axis : IterVar
                  The reduction IterVar axis
              where : optional, Expr
                  Filtering predicate of the reduction.
              Returns
              -------
              value : Expr
                  The result value.

              Example
              -------
              .. code-block:: python

                m = tvm.var("m")
                n = tvm.var("n")
                A = tvm.placeholder((m, n), name="A")
                k = tvm.reduce_axis((0, n), name="k")

                # there are two ways to use this {0} reducer:
                # mode 1, accept (expr, axis, where) to produce a Reduce Expr
                B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")

                # mode 2, simply use it with multiple Exprs:
                {0}_res = tvm.{0}(m, n)
              """
    reducer.__doc__ = doc_str.format(name)
    return reducer


_init_api("tvm.api")
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')