"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
from __future__ import absolute_import as _abs

from numbers import Integral as _Integral

from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag

int8 = "int8"
int32 = "int32"
float32 = "float32"
handle = "handle"


def min_value(dtype):
    """minimum value of dtype"""
    return _api_internal._min_value(dtype)


def max_value(dtype):
    """maximum value of dtype"""
    return _api_internal._max_value(dtype)
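
# Illustrative usage: tvm.min_value("float32") and tvm.max_value("int32")
# return constant expressions holding the dtype's minimum/maximum value.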


def const(value, dtype=None):
    """construct a constant"""
    if dtype is None:
        if isinstance(value, _Integral):
            dtype = 'int32'
        else:
            dtype = 'float32'
    return _api_internal._const(value, dtype)


def get_env_func(name):
    """Get an EnvFunc by a global name.

    Parameters
    ----------
    name: str
        The name of the global function.

    Returns
    -------
    env_func : EnvFunc
        The result env function.

    Note
    ----
    EnvFunc is a Node wrapper around a
    global function that can be serialized via its name.
    This can be used to serialize a function field in the language.
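
    Example
    -------
    A small sketch; ``my_func`` is a hypothetical user-registered
    global function:

    .. code-block:: python

        @tvm.register_func("my_func")
        def my_func(x):
            return x

        env_f = tvm.get_env_func("my_func")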
    """
    return _api_internal._EnvFuncGet(name)


def convert(value):
    """Convert value to TVM node or function.

    Parameters
    ----------
    value : python value

    Returns
    -------
    tvm_val : Node or Function
        Converted value in TVM
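
    Example
    -------
    Illustrative conversions:

    .. code-block:: python

        arr = tvm.convert([1, 2, 3])    # -> tvm Array node
        fun = tvm.convert(lambda x: x)  # -> packed Function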
    """
    if isinstance(value, (Function, NodeBase)):
        return value

    if callable(value):
        return _convert_tvm_func(value)

    return _convert_to_node(value)


def load_json(json_str):
    """Load tvm object from json_str.

    Parameters
    ----------
    json_str : str
        The json string

    Returns
    -------
    node : Node
        The loaded tvm node.
    """
    return _api_internal._load_json(json_str)


def save_json(node):
    """Load tvm object as json string.

    Parameters
    ----------
    node : Node
        A TVM Node object to be saved.

    Returns
    -------
    json_str : str
        Saved json string.
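
    Example
    -------
    ``save_json`` pairs with :any:`load_json` to round-trip a node:

    .. code-block:: python

        x = tvm.var("x")
        json_str = tvm.save_json(x)
        y = tvm.load_json(json_str)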
    """
    return _api_internal._save_json(node)


def var(name="tindex", dtype=int32):
    """Create a new variable with specified name and dtype

    Parameters
    ----------
    name : str
        The name

    dtype : str
        The data type

    Returns
    -------
    var : Var
        The result symbolic variable.
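
    Example
    -------
    .. code-block:: python

        n = tvm.var("n")                 # int32 by default
        i = tvm.var("i", dtype="int64")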
    """
    return _api_internal._Var(name, dtype)


def any(*args):
    """Create a new experssion of the union of all conditions in the arguments

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression
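
    Example
    -------
    .. code-block:: python

        i = tvm.var("i")
        cond = tvm.any(i < 2, i > 8)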
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make._OpOr(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make._OpOr(ret, args[i])
    return ret


def all(*args):
    """Create a new experssion of the intersection of all conditions in the
      arguments

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions

    Returns
    -------
    expr: Expr
        Expression
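
    Example
    -------
    .. code-block:: python

        i = tvm.var("i")
        cond = tvm.all(i > 2, i < 8)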
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make._OpAnd(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make._OpAnd(ret, args[i])
    return ret


def placeholder(shape, dtype=None, name="placeholder"):
    """Construct an empty tensor object.

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    dtype: str, optional
        The data type of the tensor

    name: str, optional
        The name hint of the tensor

    Returns
    -------
    tensor: Tensor
        The created tensor
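
    Example
    -------
    .. code-block:: python

        m = tvm.var("m")
        n = tvm.var("n")
        A = tvm.placeholder((m, n), name="A")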
    """
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
    dtype = float32 if dtype is None else dtype
    return _api_internal._Placeholder(
        shape, dtype, name)


def compute(shape, fcompute, name="compute", tag="", attrs=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    fcompute: lambda function of indices -> value
        Specifies the input source expression

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor
        The created tensor
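
    Example
    -------
    .. code-block:: python

        m = tvm.var("m")
        A = tvm.placeholder((m,), name="A")
        B = tvm.compute((m,), lambda i: A[i] * 2.0, name="B")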
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
    # for python3
    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
    ndim = len(shape)
    code = fcompute.__code__

    out_ndim = ndim
    if code.co_argcount == 0:
        arg_names = ["i%d" % i for i in range(ndim)]
    else:
        arg_names = code.co_varnames[:code.co_argcount]
        out_ndim = code.co_argcount

    if out_ndim != len(arg_names):
        raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
    body = fcompute(*[v.var for v in dim_var])

    if isinstance(body, _tensor.TensorIntrinCall):
        for i, s in enumerate(shape[out_ndim:]):
            var_name = "ax" + str(i)
            dim_var.append(_IterVar((0, s), var_name, 4))
        op_node = _api_internal._TensorComputeOp(name,
                                                 tag,
                                                 dim_var,
                                                 body.reduce_axis,
                                                 out_ndim,
                                                 body.intrin,
                                                 body.tensors,
                                                 body.regions)
    else:
        if not isinstance(body, (list, tuple)):
            body = [body]
        body = convert(body)
        op_node = _api_internal._ComputeOp(
            name, tag, attrs, dim_var, body)

    num = op_node.num_outputs
    outputs = tuple(op_node.output(i) for i in range(num))
    return outputs[0] if num == 1 else outputs


def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
    """Construct new tensors by scanning over axis.

    Parameters
    ----------
    init: Tensor or list of Tensor
        The initial condition of the first init.shape[0] timestamps

    update: Tensor or list of Tensor
        The update rule of the scan given by symbolic tensor.

    state_placeholder: Tensor or list of Tensor
        The placeholder variables used by update.

    inputs: Tensor or list of Tensor, optional
        The list of inputs to the scan. This is not required, but can
        be useful for the compiler to detect scan body faster.

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.

    Example
    -------
    .. code-block:: python

      # The following code is equivalent to numpy.cumsum
      m = tvm.var("m")
      n = tvm.var("n")
      X = tvm.placeholder((m, n), name="X")
      s_state = tvm.placeholder((m, n))
      s_init = tvm.compute((1, n), lambda _, i: X[0, i])
      s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
      res = tvm.scan(s_init, s_update, s_state, X)
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    if isinstance(init, _tensor.Tensor):
        init = [init]
    if isinstance(update, _tensor.Tensor):
        update = [update]
    if isinstance(state_placeholder, _tensor.Tensor):
        state_placeholder = [state_placeholder]
    if isinstance(inputs, _tensor.Tensor):
        inputs = [inputs]
    if inputs is None:
        inputs = []
    if len(init) != len(update) or len(init) != len(state_placeholder):
        raise ValueError("init, update, state_placeholder must have same length")
    axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
    op = _api_internal._ScanOp(name, tag, attrs,
                               axis, init, update,
                               state_placeholder, inputs)
    res = [op.output(i) for i in range(len(update))]
    return res[0] if len(res) == 1 else res


def extern(shape,
           inputs,
           fcompute,
           name="extern",
           dtype=None,
           in_buffers=None,
           out_buffers=None,
           tag="",
           attrs=None):
    """Compute several tensor via extern function.

    Parameters
    ----------
    shape: tuple or list of tuples.
        The shape of the outputs.

    inputs: list of Tensor
        The inputs

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.
        See the following note for function signature of fcompute

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each input
             - **outs** (list of :any:`Buffer`) - Placeholder for each output

             **Returns**

             - **stmt** (:any:`Stmt`) - The statement that carries out array computation.

    name: str, optional
        The name hint of the tensor

    dtype: str or list of str, optional
        The data types of outputs,
        by default dtype will be the same as inputs.

    in_buffers: Buffer or list of Buffer, optional
        Input buffers.

    out_buffers: Buffer or list of Buffers, optional
        Output buffers.

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.

    Example
    -------
    In the code below, C is generated by calling external PackedFunc
    `tvm.contrib.cblas.matmul`

    .. code-block:: python

        A = tvm.placeholder((n, l), name='A')
        B = tvm.placeholder((l, m), name='B')
        C = tvm.extern((n, m), [A, B],
                       lambda ins, outs: tvm.call_packed(
                          "tvm.contrib.cblas.matmul",
                            ins[0], ins[1], outs[0], 0, 0), name="C")
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
    if in_buffers is not None:
        in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
        if len(inputs) != len(in_buffers):
            raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
                               % (len(inputs), len(in_buffers)))
    if out_buffers is not None:
        out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
        if len(shape) != len(out_buffers):
            raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
                               % (len(shape), len(out_buffers)))
    input_placeholders = in_buffers or []
    output_placeholders = out_buffers or []
    types = set()
    for t in inputs:
        if not isinstance(t, _tensor.Tensor):
            raise ValueError("expect inputs to be tensor")
        if in_buffers is None:
            input_placeholders.append(
                decl_buffer(t.shape, t.dtype, t.op.name))
        types.add(t.dtype)

    if dtype is None:
        if len(types) != 1:
            raise ValueError("Cannot infer output type, please provide dtype argument")
        inferred_type = types.pop()
        dtype = [inferred_type for _ in shape]
    if isinstance(dtype, str):
        dtype = [dtype]

    if out_buffers is None:
        for shp, dt in zip(shape, dtype):
            output_placeholders.append(decl_buffer(shp, dt, name))
    body = fcompute(input_placeholders, output_placeholders)
    if isinstance(body, _expr.Expr):
        body = _make.Evaluate(body)

    op = _api_internal._ExternOp(name, tag, attrs,
                                 inputs, input_placeholders,
                                 output_placeholders, body)
    res = [op.output(i) for i in range(len(output_placeholders))]
    return res[0] if len(res) == 1 else res


def decl_buffer(shape,
                dtype=None,
                name="buffer",
                data=None,
                strides=None,
                elem_offset=None,
                scope="",
                data_alignment=-1,
                offset_factor=0):
    """Decleare a new symbolic buffer.

    Normally buffers are created automatically during lower and build.
    This is only needed if users want to specify their own buffer layout.

    See the note below for detailed discussion on usage of buffer.

    Parameters
    ----------
    shape : tuple of Expr
        The shape of the buffer.

    dtype : str, optional
        The data type of the buffer.

    name : str, optional
        The name of the buffer.

    data : Var, optional
        The data pointer in the buffer.

    strides: array of Expr
        The stride of the buffer.

    elem_offset: Expr, optional
        The beginning offset of the array to data.
        In terms of number of elements of dtype.

    scope: str, optional
        The storage scope of the buffer, if not global.
        If scope equals empty string, it means it is global memory.

    data_alignment: int, optional
        The alignment of data pointer in bytes.
        If -1 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, optional
        The factor of elem_offset field, when set,
        elem_offset is required to be multiple of offset_factor.
        If 0 is passed, the alignment will be set to 1.
        If a non-zero value is passed, we will create a Var for elem_offset
        when elem_offset is None.

    Returns
    -------
    buffer : Buffer
        The created buffer

    Note
    ----
    Buffer data structure reflects the DLTensor structure in dlpack.
    While the DLTensor data structure is very general, it is usually helpful
    to create a function that only handles a specific case of data structure
    and make the compiled function benefit from it.

    If the user passes strides and elem_offset as None
    when constructing the function, then the function will be specialized
    for the DLTensor that is compact and aligned.
    If the user passes a fully generic symbolic array as the strides,
    then the resulting function becomes fully generic.
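
    Example
    -------
    A sketch of declaring a custom buffer; passing it through the ``binds``
    argument of ``tvm.lower``/``tvm.build`` is one typical use:

    .. code-block:: python

        m, n = tvm.var("m"), tvm.var("n")
        A = tvm.placeholder((m, n), name="A")
        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", data_alignment=64)
        # e.g. tvm.build(s, [A, B], binds={A: Ab})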
    """
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    dtype = float32 if dtype is None else dtype
    strides = () if strides is None else strides
    if offset_factor != 0 and elem_offset is None:
        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
        elem_offset = var('%s_elem_offset' % name, shape_dtype)
    if data is None:
        data = var(name, "handle")
    return _api_internal._Buffer(
        data, dtype, shape, strides, elem_offset, name, scope,
        data_alignment, offset_factor)

def _IterVar(dom, name, iter_type, thread_tag=''):
    """Internal function to create IterVar

    Parameters
    ----------
tqchen committed
    dom : Range
        The domain of iteration.

    name : str
        The name of iteration variable.

    iter_type : int
        The type of iteration.

    thread_tag : str
        The thread tag of the iteration variable.
tqchen committed
575 576 577

    Returns
    -------
    iter_var : IterVar
       The result itervar
    """
    if dom is not None:
        if isinstance(dom, (list, tuple)):
            if len(dom) != 2:
                raise TypeError("need to be list of ranges")
            dom = Range(dom[0], dom[1])

        if not isinstance(dom, _container.Range):
            raise TypeError("dom need to be Range")
    name = name if name else 'iter'
    v = var(name)
    return _api_internal._IterVar(dom, v, iter_type, thread_tag)


def thread_axis(dom=None, tag='', name=''):
    """Create a new IterVar to represent thread index.

    Parameters
    ----------
    dom : Range or str
        The domain of iteration
        When str is passed, dom is set to None and str is used as tag

    tag : str, optional
        The thread tag

    name : str, optional
        The name of the var.

    Returns
    -------
    axis : IterVar
        The thread itervar.
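
    Example
    -------
    A typical GPU binding sketch (``s``, ``B``, ``xo``, ``xi`` are assumed
    to come from an existing schedule):

    .. code-block:: python

        bx = tvm.thread_axis("blockIdx.x")
        tx = tvm.thread_axis("threadIdx.x")
        s[B].bind(xo, bx)
        s[B].bind(xi, tx)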
    """
    if isinstance(dom, string_types):
        tag, dom = dom, None
    if not tag:
        raise ValueError("tag must be given as Positional or keyword argument")
    name = name if name else tag
    return _IterVar(dom, name, 1, tag)


def reduce_axis(dom, name="rv"):
    """Create a new IterVar for reduction.

    Parameters
    ----------
    dom : Range
        The domain of iteration.

    name : str
        The name of the variable.

    Returns
    -------
    axis : IterVar
        An iteration variable representing the value.
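
    Example
    -------
    .. code-block:: python

        m = tvm.var("m")
        n = tvm.var("n")
        A = tvm.placeholder((m, n), name="A")
        k = tvm.reduce_axis((0, n), name="k")
        B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name="B")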
    """
    return _IterVar(dom, name, 2)


def select(cond, t, f):
    """Construct a select branch.

    Parameters
    ----------
    cond : Expr
        The condition

    t : Expr
        The result expression if cond is true.

    f : Expr
        The result expression if cond is false.

    Returns
    -------
    node : Node
        The tvm.expr.Select node
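
    Example
    -------
    .. code-block:: python

        n = tvm.var("n")
        A = tvm.placeholder((n,), name="A")
        B = tvm.compute(
            (n,), lambda i: tvm.select(A[i] > 0, A[i], tvm.const(0, A.dtype)),
            name="B")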
    """
    return _expr.Select(convert(cond), convert(t), convert(f))


def comm_reducer(fcombine, fidentity, name="reduce"):
    """Create a commutative reducer for reduction.

    Parameters
    ----------
    fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Exprs as input and returns an Expr.

    fidentity : function(str -> Expr)
        A function which takes a type string as input and returns a const Expr.

    Returns
    -------
    reducer : function
        A function which creates a reduce expression over axis.
        There are two ways to use it:

        1. accept (expr, axis, where) to produce a Reduce Expr on
           the specified axis;
        2. simply use it with multiple Exprs.

    Example
    -------
    .. code-block:: python

        n = tvm.var('n')
        m = tvm.var('m')
        mysum = tvm.comm_reducer(lambda x, y: x+y,
            lambda t: tvm.const(0, dtype=t), name="mysum")
        A = tvm.placeholder((n, m), name='A')
        k = tvm.reduce_axis((0, m), name='k')
        B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
    """
    def _reduce_directly(*args):
        num = len(args)
        # handle the case where `where` is None
        if num == 3 and args[2] is None:
            num = 2
        res = args[0]
        for i in range(num-1):
            res = fcombine(res, args[i+1])
        return res

    def _make_reduce(expr, axis, where=None):
        code = fcombine.__code__
        assert fcombine.__code__.co_argcount == 2
        expr = convert(expr)
        if isinstance(expr, _container.Array):
            size = len(expr)
            larr = []
            rarr = []
            dtypes = []
            for i in range(size):
                dtype = expr[i].dtype
                dtypes.append(dtype)
                lname = code.co_varnames[0] + '_' + str(i)
                larr.append(var(lname, dtype))
                rname = code.co_varnames[1] + '_' + str(i)
                rarr.append(var(rname, dtype))
            lhs = convert(larr)
            rhs = convert(rarr)
            result = fcombine(lhs, rhs)
            id_elem = fidentity(*dtypes)
        else:
            assert isinstance(expr, _expr.Expr)
            size = 1
            dtype = expr.dtype
            lvar = var(code.co_varnames[0], dtype)
            rvar = var(code.co_varnames[1], dtype)
            result = [fcombine(lvar, rvar)]
            id_elem = [fidentity(dtype)]
            lhs = convert([lvar])
            rhs = convert([rvar])
            expr = convert([expr])
        result = convert(result)
        id_elem = convert(id_elem)
        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
        axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
        if where is None:
            where = convert(True)
        outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i)
                        for i in range(size))
        return outputs[0] if size == 1 else outputs

    # pylint: disable=keyword-arg-before-vararg
    def reducer(expr, axis, where=None, *args):
        if isinstance(axis, (_schedule.IterVar, list, tuple)):
            assert not args
            return _make_reduce(expr, axis, where)
        if where is None:
            assert not args
            return _reduce_directly(expr, axis)
        return _reduce_directly(expr, axis, where, *args)

    doc_str = """Create a {0} expression over axis.

              Parameters
              ----------
              expr : Expr
                  The source expression.
              axis : IterVar
                  The reduction IterVar axis
              where : optional, Expr
                  Filtering predicate of the reduction.
              Returns
              -------
              value : Expr
                  The result value.

              Example
              -------
              .. code-block:: python

                m = tvm.var("m")
                n = tvm.var("n")
                A = tvm.placeholder((m, n), name="A")
                k = tvm.reduce_axis((0, n), name="k")

                # there are two ways to use this {0} reducer:
                # mode 1, accept (expr, axis, where) to produce a Reduce Expr
                B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")

                # mode 2, simply use it with multiple Exprs:
                {0}_res = tvm.{0}(m, n)
              """
    reducer.__doc__ = doc_str.format(name)
    return reducer


_init_api("tvm.api")
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
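

# Example usage of the builtin reducers (illustrative sketch):
#   m, n = tvm.var("m"), tvm.var("n")
#   A = tvm.placeholder((m, n), name="A")
#   k = tvm.reduce_axis((0, n), name="k")
#   row_sum = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k))
#   row_max = tvm.compute((m,), lambda i: tvm.max(A[i, k], axis=k))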