tensor_intrin.py 2.98 KB
Newer Older
1 2 3 4 5 6 7 8
"""Tensor intrinsics"""
from __future__ import absolute_import as _abs
from . import _api_internal
from . import api as _api
from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
9
from .build_module import BuildConfig
10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
from ._ffi.node import NodeBase, register_node

@register_node
class TensorIntrin(NodeBase):
    """Node class for a tensor intrinsic function of a certain computation.

    The class adds no members of its own on top of ``NodeBase``.

    See Also
    --------
    decl_tensor_intrin: Construct a TensorIntrin
    """


def decl_tensor_intrin(op,
                       fcompute,
                       name="tensor_intrin",
                       binds=None):
    """Build a tensor intrinsic from a symbolic operation and its lowering.

    Parameters
    ----------
    op: Operation
        Symbolic description of the computation the intrinsic performs.

    fcompute: lambda function of inputs, outputs-> stmt
        Callback that produces the IR implementing the computation.
        It is called once with the input buffers and the output buffers.

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - One buffer per input tensor
             - **outs** (list of :any:`Buffer`) - One buffer per output tensor

             **Returns**

             - **stmt** (:any:`Stmt`, or tuple of three stmts)
             - A single stmt is taken as the body
             - A tuple of three stmts is taken as body, reduce_init,
               reduce_update

    name: str, optional
        Name given to the intrinsic.

    binds: dict of :any:`Tensor` to :any:`Buffer`, optional
        Explicit buffer to use for a given tensor, which fixes the data layout
        requirement for that tensor. Any tensor without an entry gets a newly
        declared buffer whose alignment and offset factor are taken from the
        current build configuration.

    Returns
    -------
    intrin: TensorIntrin
        A TensorIntrin that can be used in tensorize schedule.

    Raises
    ------
    TypeError
        If ``op`` is not an ``Operation``.
    ValueError
        If any input tensor of ``op`` is not produced by a placeholder op.
    """
    if not isinstance(op, _tensor.Operation):
        raise TypeError("expect Operation")

    inputs = op.input_tensors
    user_binds = binds or {}

    # Inputs come first, followed by every output of the operation.
    all_tensors = list(inputs)
    all_tensors.extend(op.output(idx) for idx in range(op.num_outputs))

    # Only operations fed directly by placeholders are accepted.
    if any(not isinstance(tensor.op, _tensor.PlaceholderOp) for tensor in inputs):
        raise ValueError("Donot yet support composition op")

    cfg = BuildConfig.current
    buffers = []
    for tensor in all_tensors:
        if tensor in user_binds:
            buffer = user_binds[tensor]
        else:
            buffer = _api.decl_buffer(tensor.shape, tensor.dtype, tensor.op.name,
                                      data_alignment=cfg.data_alignment,
                                      offset_factor=cfg.offset_factor)
        buffers.append(buffer)

    # Split the buffers back into the input part and the output part.
    num_inputs = len(inputs)
    result = fcompute(buffers[:num_inputs], buffers[num_inputs:])
    if isinstance(result, (_expr.Expr, _stmt.Stmt)):
        result = [result]

    # Bare expressions are wrapped in Evaluate; other entries pass through.
    stmts = []
    for item in result:
        if isinstance(item, _expr.Expr):
            item = _make.Evaluate(item)
        stmts.append(item)

    # The constructor is always given three trailing slots; pad missing ones.
    while len(stmts) < 3:
        stmts.append(None)

    return _api_internal._TensorIntrin(
        name, op, inputs, buffers, *stmts)