tensor.py 4.22 KB
Newer Older
1
"""Tensor and Operation class for computation declaration."""
2
# pylint: disable=invalid-name
tqchen committed
3
from __future__ import absolute_import as _abs
4
from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
5
from . import _api_internal
tqchen committed
6 7 8
from . import make as _make
from . import expr as _expr

9
class TensorSlice(NodeGeneric, _expr.ExprOp):
    """Auxiliary data structure that enables slicing syntax on a tensor.

    A slice remembers the tensor it came from plus the indices collected so
    far; indexing a slice again appends further indices to that collection.
    """
    def __init__(self, tensor, indices):
        self.tensor = tensor
        # A lone index is promoted to a one-element tuple so that later
        # concatenation in __getitem__ always operates on tuples.
        self.indices = indices if isinstance(indices, tuple) else (indices,)

    def __getitem__(self, indices):
        extra = indices if isinstance(indices, tuple) else (indices,)
        return TensorSlice(self.tensor, self.indices + extra)

    def asnode(self):
        """Convert the slice to a node by calling the tensor on its indices."""
        return self.tensor(*self.indices)

    @property
    def dtype(self):
        """Data type of the underlying tensor."""
        return self.tensor.dtype


32
# Hook holding the IterVar node class. It stays ``None`` in this module; it
# is presumably assigned by the module that defines IterVar (TODO confirm).
# ``Tensor.__call__`` uses it to accept IterVar objects as indices.
iter_var_cls = None
# Previous spelling of the hook above, kept so that code assigning
# ``tensor.itervar_cls`` is still honored by ``Tensor.__call__``.
itervar_cls = None


@register_node
class Tensor(NodeBase):
    """Tensor object; to construct one, see function.Tensor."""
    def __call__(self, *indices):
        """Build a Halide call expression reading this tensor at ``indices``.

        Parameters
        ----------
        *indices : Expr or IterVar
            Exactly ``ndim`` indices. An IterVar index is replaced by its
            ``var`` attribute.

        Returns
        -------
        call : node
            The call node built by ``_make.Call``.

        Raises
        ------
        ValueError
            If the number of indices differs from ``ndim``, or an index is
            neither an expression nor an instance of the IterVar hook class.
        """
        ndim = self.ndim
        if len(indices) != ndim:
            raise ValueError("Need to provide %d index in tensor slice" % ndim)
        indices = convert_to_node(indices)
        # Resolve the hook at call time, since it may be assigned after this
        # module is imported. Prefer the current name and fall back to the
        # old spelling. An unset hook (None) must not reach isinstance(),
        # which would raise TypeError instead of the intended ValueError.
        var_cls = iter_var_cls if iter_var_cls is not None else itervar_cls
        args = []
        for x in indices:
            if isinstance(x, _expr.Expr):
                args.append(x)
            elif var_cls is not None and isinstance(x, var_cls):
                args.append(x.var)
            else:
                raise ValueError("The indices must be expression")
        return _make.Call(self.dtype, self.op.name,
                          args, _expr.Call.Halide,
                          self.op, self.value_index)

    def __getitem__(self, indices):
        return TensorSlice(self, indices)

    def __hash__(self):
        return _api_internal._TensorHash(self)

    def __eq__(self, other):
        if not isinstance(other, Tensor):
            return False
        return _api_internal._TensorEqual(self, other)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it
        # explicitly so that the two operators always agree.
        return not self.__eq__(other)

    @property
    def ndim(self):
        """Dimension of the tensor."""
        return len(self.shape)

    @property
    def axis(self):
        """Axis of the tensor."""
        return self.__getattr__("axis")

    @property
    def op(self):
        """The corresponding :any:`Operation`."""
        return self.__getattr__("op")

    @property
    def value_index(self):
        """The output value index the tensor corresponds to."""
        return self.__getattr__("value_index")

    @property
    def shape(self):
        """The output shape of the tensor."""
        return self.__getattr__("shape")

    @property
    def name(self):
        """Name of the tensor.

        This is the op's name when the op has a single output. Otherwise it
        is ``"<op name>.v<value_index>"``.
        """
        op = self.op
        if op.num_outputs == 1:
            return op.name
        return "%s.v%d" % (op.name, self.value_index)
97

98 99

class Operation(NodeBase):
    """Represent an operation that generates one or more tensors."""
    def output(self, index):
        """Return the tensor produced at output position ``index``.

        Parameters
        ----------
        index : int
            Position of the requested output; presumably expected to lie in
            ``range(self.num_outputs)`` (TODO confirm backend checks).

        Returns
        -------
        out : Tensor
            The ``index``-th output of this operation.
        """
        result = _api_internal._OpGetOutput(self, index)
        return result

    @property
    def num_outputs(self):
        """How many outputs this operation has."""
        count = _api_internal._OpNumOutputs(self)
        return count

    @property
    def input_tensors(self):
        """Tensors consumed as inputs by this operation."""
        inputs = _api_internal._OpInputTensors(self)
        return inputs


127
@register_node
class PlaceholderOp(Operation):
    """Placeholder operation.

    Declares no members of its own beyond those inherited from Operation.
    """

132

133
@register_node
class ComputeOp(Operation):
    """Compute operation."""
    @property
    def axis(self):
        """IterVar axes of this compute operation."""
        field = "axis"
        return self.__getattr__(field)

    @property
    def reduce_axis(self):
        """IterVar axes that this compute operation reduces over."""
        field = "reduce_axis"
        return self.__getattr__(field)

146 147

@register_node
class ScanOp(Operation):
    """Scan operation."""
    @property
    def scan_axis(self):
        """The axis along which this scan operation iterates."""
        field = "scan_axis"
        return self.__getattr__(field)

155 156 157 158 159

@register_node
class ExternOp(Operation):
    """Extern operation.

    Declares no members of its own beyond those inherited from Operation.
    """