# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Data layout."""
import tvm._ffi

from tvm.runtime import Object
from . import _ffi_api


@tvm._ffi.register_object
class Layout(Object):
    """Layout is composed of upper cases, lower cases and numbers,
    where upper case indicates a primal axis and
    the corresponding lower case with factor size indicates the subordinate axis.
    For example, NCHW16c can describe a 5-D tensor of
    [batch_size, channel, height, width, channel_block].
    Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).

    See Also
    --------
    layout : Declare a layout
    """

    def __len__(self):
        """Return the value of the ``LayoutNdim`` FFI call for this layout."""
        return _ffi_api.LayoutNdim(self)

    def __contains__(self, axis):
        """Check whether an axis name occurs in this layout.

        Parameters
        ----------
        axis : str
            The axis name to look for.

        Returns
        -------
        found : bool
            True only when ``axis`` is exactly one alphabetic character
            and that character appears in ``self.name``.
        """
        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name

    def __getitem__(self, index):
        """Get the layout item at the given position.

        Parameters
        ----------
        index : int
            The position to read.

        Returns
        -------
        item
            Whatever the ``LayoutGetItem`` FFI call returns for ``index``.

        Raises
        ------
        IndexError
            If ``index`` lies outside ``[-len(self), len(self))``.
        """
        ndim = len(self)
        # Bound the index on both sides so that out-of-range values surface
        # as IndexError, as the sequence protocol expects, instead of being
        # passed to the FFI. In-range negative indices are still forwarded
        # unchanged; their interpretation is left to LayoutGetItem.
        if index >= ndim or index < -ndim:
            raise IndexError("Layout index out of range")
        return _ffi_api.LayoutGetItem(self, index)

    def index_of(self, axis):
        """Get the index of an axis

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        index : int
            The index of the axis, -1 if not found.
        """
        return _ffi_api.LayoutIndexOf(self, axis)

    def factor_of(self, axis):
        """Get the factor size of the subordinate axis.

        Parameters
        ----------
        axis : str
            The axis name, need to be [a-z,A-Z]

        Returns
        -------
        factor : int
            the size of the subordinate-axis of axis (if axis is a primal-axis),
            or the size of axis itself (if axis is a subordinate-axis).
            Return -1 if axis is not in the layout.
        """
        return _ffi_api.LayoutFactorOf(self, axis)


@tvm._ffi.register_object
class BijectiveLayout(Object):
    """Bijective mapping for two layouts (src-layout and dst-layout).
    It provides shape and index conversion between each other.

    Do not construct directly, use :any:`bijective_layout` instead.
    See the documentation of :any:`bijective_layout` for more details.

    Parameters
    ----------
    src_layout : str or Layout
        source layout.

    dst_layout : str or Layout
        destination layout.

    See Also
    --------
    bijective_layout : Create a bijective layout mapping
    """

    def forward_index(self, index):
        """Given the indices of the src-layout, infer the dst index.

        Parameters
        ----------
        index: Array of Expr
            The indices in src-layout.

        Returns
        -------
        dst_index: Array of Expr
            The inferred indices in dst-layout.
        """
        return _ffi_api.BijectiveLayoutForwardIndex(self, index)

    def backward_index(self, index):
        """Given the indices of the dst-layout, infer the src index.

        Parameters
        ----------
        index: Array of Expr
            The indices in dst-layout.

        Returns
        -------
        src_index: Array of Expr
            The inferred indices in src-layout.
        """
        return _ffi_api.BijectiveLayoutBackwardIndex(self, index)

    def forward_shape(self, shape):
        """Given the shape of the src-layout, infer the dst shape.

        Parameters
        ----------
        shape: Array of Expr
            The shape in src-layout.

        Returns
        -------
        dst_shape: Array of Expr
            The inferred shape in dst-layout.
        """
        return _ffi_api.BijectiveLayoutForwardShape(self, shape)

    def backward_shape(self, shape):
        """Given the shape of the dst-layout, infer the src shape.

        Parameters
        ----------
        shape: Array of Expr
            The shape in dst-layout.

        Returns
        -------
        src_shape: Array of Expr
            The inferred shape in src-layout.
        """
        return _ffi_api.BijectiveLayoutBackwardShape(self, shape)


def layout(layout_str):
    """Create a layout node from a string.

    Parameters
    ----------
    layout_str : str
        A layout representation is composed of upper cases,
        lower cases and numbers, where upper case indicates a primal axis and
        the corresponding lower case with factor size indicates the subordinate axis.
        For example, NCHW16c can describe a 5-D tensor of
        [batch_size, channel, height, width, channel_block].
        Here subordinate axis channel_block=16 is the factor size of
        the primal axis C (channel).

    Returns
    -------
    layout : Layout
        The created layout
    """
    return _ffi_api.Layout(layout_str)


def bijective_layout(src_layout, dst_layout):
    """Create a bijective layout mapping.

    Parameters
    ----------
    src_layout : str or Layout
        source layout.

    dst_layout : str or Layout
        destination layout.

    Returns
    -------
    bijective_layout : BijectiveLayout
        The created bijective layout
    """
    # String arguments are converted with layout() first; any other value
    # is passed to the FFI constructor as given.
    if isinstance(src_layout, str):
        src_layout = layout(src_layout)
    if isinstance(dst_layout, str):
        dst_layout = layout(dst_layout)
    return _ffi_api.BijectiveLayout(src_layout, dst_layout)