# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node
from . import _api_internal
from . import make as _make
from . import expr as _expr


class TensorSlice(NodeGeneric, _expr.ExprOp):
    """Auxiliary data structure that enables slicing syntax on a tensor.

    Each ``[]`` on a slice appends to the accumulated index tuple;
    :py:meth:`asnode` converts the slice into ``tensor(*indices)``.
    """
    def __init__(self, tensor, indices):
        if not isinstance(indices, tuple):
            indices = (indices,)
        self.tensor = tensor
        self.indices = indices

    def __getitem__(self, indices):
        if not isinstance(indices, tuple):
            indices = (indices,)
        return TensorSlice(self.tensor, self.indices + indices)

    def asnode(self):
        """Convert slice to node."""
        return self.tensor(*self.indices)

    @property
    def dtype(self):
        """Data type of the tensor."""
        return self.tensor.dtype


# Hook slots for the IterVar node class, which Tensor.__call__ accepts as an
# index. They start unset and are expected to be assigned from outside this
# file (NOTE(review): presumably by the module defining IterVar, to avoid a
# circular import -- confirm). The original file declared ``itervar_cls`` but
# read ``iter_var_cls``; both spellings are kept so either assignment works.
iter_var_cls = None
itervar_cls = None


def _resolve_iter_var_cls():
    """Return the registered IterVar class, or None if none was registered."""
    if iter_var_cls is not None:
        return iter_var_cls
    return itervar_cls


@register_node
class TensorIntrinCall(NodeBase):
    """Intermediate structure for calling a tensor intrinsic."""


@register_node
class Tensor(NodeBase, _expr.ExprOp):
    """Tensor object, to construct, see function.Tensor"""

    def __call__(self, *indices):
        """Create a call node that reads this tensor at ``indices``.

        Parameters
        ----------
        *indices : Expr or IterVar
            Exactly one index per dimension. An IterVar index (only accepted
            once an IterVar class has been registered on this module) is
            replaced by its ``var``.

        Returns
        -------
        call : Expr
            The node built by ``_make.Call`` with ``_expr.Call.Halide``,
            this tensor's op and its value index.

        Raises
        ------
        ValueError
            If the number of indices differs from ``ndim``, or an index is
            neither an Expr nor an instance of the registered IterVar class.
        """
        ndim = self.ndim
        if len(indices) != ndim:
            raise ValueError("Need to provide %d index in tensor slice" % ndim)
        indices = convert_to_node(indices)
        # Resolve once per call; None means no IterVar class was registered,
        # in which case non-Expr indices fall through to the ValueError below
        # instead of failing with NameError/TypeError.
        var_cls = _resolve_iter_var_cls()
        args = []
        for x in indices:
            if isinstance(x, _expr.Expr):
                args.append(x)
            elif var_cls is not None and isinstance(x, var_cls):
                args.append(x.var)
            else:
                raise ValueError("The indices must be expression")
        return _make.Call(self.dtype, self.op.name,
                          args, _expr.Call.Halide,
                          self.op, self.value_index)

    def __getitem__(self, indices):
        return TensorSlice(self, indices)

    def __hash__(self):
        return _api_internal._TensorHash(self)

    def __eq__(self, other):
        """Compare with another tensor, or build an equality expression.

        A non-Tensor ``ExprOp`` operand yields ``_expr.EqualOp(self, other)``;
        any other non-Tensor operand compares unequal. Two rank-0 tensors are
        rejected as ambiguous; otherwise the backend ``_TensorEqual`` decides.
        """
        if not isinstance(other, Tensor):
            if isinstance(other, _expr.ExprOp):
                return _expr.EqualOp(self, other)
            return False
        if self.ndim == 0 and other.ndim == 0:
            raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
                             "use Tensor.equal for content expression equvalence, "
                             "use Tensor.same_as for exact reference comparison")
        return _api_internal._TensorEqual(self, other)

    @property
    def ndim(self):
        """Dimension of the tensor."""
        return len(self.shape)

    @property
    def axis(self):
        """Axis of the tensor."""
        return self.__getattr__("axis")

    @property
    def op(self):
        """The corresponding :any:`Operation`."""
        return self.__getattr__("op")

    @property
    def value_index(self):
        """The output value index the tensor corresponds to."""
        return self.__getattr__("value_index")

    @property
    def shape(self):
        """The output shape of the tensor."""
        return self.__getattr__("shape")

    @property
    def name(self):
        """Name of the tensor.

        The op's name when the op has a single output, otherwise
        ``"<op name>.v<value_index>"``.
        """
        op = self.op
        if op.num_outputs == 1:
            return op.name
        return "%s.v%d" % (op.name, self.value_index)


class Operation(NodeBase):
    """Represent an operation that generates a tensor"""

    def output(self, index):
        """Get the index-th output of the operation

        Parameters
        ----------
        index : int
            The position of the requested output.

        Returns
        -------
        out : Tensor
            The index-th output.
        """
        return _api_internal._OpGetOutput(self, index)

    @property
    def num_outputs(self):
        """Number of outputs from this op."""
        return _api_internal._OpNumOutputs(self)

    @property
    def input_tensors(self):
        """List of input tensors to this op."""
        return _api_internal._OpInputTensors(self)


@register_node
class PlaceholderOp(Operation):
    """Placeholder operation."""


@register_node
class BaseComputeOp(Operation):
    """Compute operation."""

    @property
    def axis(self):
        """Represent the IterVar axis, defined when it is a ComputeOp"""
        return self.__getattr__("axis")

    @property
    def reduce_axis(self):
        """Represent axis of reductions, only defined when it is a ComputeOp"""
        return self.__getattr__("reduce_axis")


@register_node
class ComputeOp(BaseComputeOp):
    """Scalar operation."""


@register_node
class TensorComputeOp(BaseComputeOp):
    """Tensor operation."""


@register_node
class ScanOp(Operation):
    """Scan operation."""

    @property
    def scan_axis(self):
        """Represent the scan axis, only defined when it is a ScanOp"""
        return self.__getattr__("scan_axis")


@register_node
class ExternOp(Operation):
    """External operation."""


@register_node
class HybridOp(Operation):
    """Hybrid operation."""

    @property
    def axis(self):
        """Represent the IterVar axis, also defined when it is a HybridOp"""
        return self.__getattr__("axis")


@register_node
class Layout(NodeBase):
    """Layout is composed of upper cases, lower cases and numbers,
    where upper case indicates a primal axis and
    the corresponding lower case with factor size indicates the subordinate axis.
    For example, NCHW16c can describe a 5-D tensor of
    [batch_size, channel, height, width, channel_block].
    Here subordinate axis channel_block=16 is the factor size of
    the primal axis C (channel).

    Do not construct directly, use :any:`layout` instead.
    See the documentation of :any:`layout` for more details.

    See Also
    --------
    layout : Declare a layout
    """
    def __len__(self):
        return _api_internal._LayoutNdim(self)

    def __contains__(self, axis):
        """True when ``axis`` is a single letter that occurs in the layout name."""
        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name

    def __getitem__(self, index):
        """Return the axis at position ``index``.

        Negative indices count from the end, as for Python sequences.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[-len(self), len(self))``. This also
            terminates iteration via the sequence protocol.
        """
        ndim = len(self)
        if index < 0:
            # Normalize here: the backend call receives the raw integer and
            # the original code never range-checked negative values.
            index += ndim
        if index < 0 or index >= ndim:
            raise IndexError("Layout index out of range")
        return _api_internal._LayoutGetItem(self, index)

    def index_of(self, axis):
        """Get the index of an axis

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        index : int
            The index of the axis, -1 if not found.
        """
        return _api_internal._LayoutIndexOf(self, axis)

    def factor_of(self, axis):
        """Get the factor size of the subordinate axis.

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        factor : int
            the size of the subordinate-axis of axis (if axis is a primal-axis),
            or the size of axis itself (if axis is a subordinate-axis).
            Return -1 if axis is not in the layout.
        """
        return _api_internal._LayoutFactorOf(self, axis)


@register_node
class BijectiveLayout(NodeBase):
    """Bijective mapping for two layouts (src-layout and dst-layout).
    It provides shape and index conversion between each other.

    Do not construct directly, use :any:`bijective_layout` instead.
    See the documentation of :any:`bijective_layout` for more details.

    See Also
    --------
    bijective_layout : Declare a bijective layout converter
    """
    def forward_index(self, index):
        """Given the indices of the src-layout, infer the dst index.

        Parameters
        ----------
        index: Array of Expr
            The indices in src-layout.

        Returns
        -------
        dst_index: Array of Expr
            The inferred indices in dst-layout.
        """
        return _api_internal._BijectiveLayoutForwardIndex(self, index)

    def backward_index(self, index):
        """Given the indices of the dst-layout, infer the src index.

        Parameters
        ----------
        index: Array of Expr
            The indices in dst-layout.

        Returns
        -------
        src_index: Array of Expr
            The inferred indices in src-layout.
        """
        return _api_internal._BijectiveLayoutBackwardIndex(self, index)

    def forward_shape(self, shape):
        """Given the shape of the src-layout, infer the dst shape.

        Parameters
        ----------
        shape: Array of Expr
            The shape in src-layout.

        Returns
        -------
        dst_shape: Array of Expr
            The inferred shape in dst-layout.
        """
        return _api_internal._BijectiveLayoutForwardShape(self, shape)

    def backward_shape(self, shape):
        """Given the shape of the dst-layout, infer the src shape.

        Parameters
        ----------
        shape: Array of Expr
            The shape in dst-layout.

        Returns
        -------
        src_shape: Array of Expr
            The inferred shape in src-layout.
        """
        return _api_internal._BijectiveLayoutBackwardShape(self, shape)