# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test layout and bijective-layout node"""
import tvm
from topi.util import get_const_tuple


def test_layout():
    """Check factor, index, membership and subscript queries on "NCHW16c"."""
    layout = tvm.layout("NCHW16c")
    assert layout is not None
    assert isinstance(layout, tvm.tensor.Layout)

    # Both the lower-case and the upper-case channel axis are expected to
    # report the factor 16; an axis that is not split is expected to give -1.
    assert layout.factor_of("c") == 16
    assert layout.factor_of("C") == 16
    assert layout.factor_of("N") == -1

    # Position of each axis in the layout; -1 is expected for an absent axis.
    assert layout.index_of("N") == 0
    assert layout.index_of("C") == 1
    assert layout.index_of("H") == 2
    assert layout.index_of("W") == 3
    assert layout.index_of("c") == 4
    assert layout.index_of("O") == -1

    # Membership tests.
    assert "N" in layout
    assert "C" in layout
    assert "H" in layout
    assert "W" in layout
    assert "c" in layout
    assert "O" not in layout

    # Subscripting, including a negative index.
    assert layout[0] == "N"
    assert layout[1] == "C"
    assert layout[2] == "H"
    assert layout[3] == "W"
    assert layout[4] == "c"
    assert layout[-1] == "c"


def test_bilayout_convertible():
    """Check which layout pairs yield a bijective layout and which yield None."""
    # not convertible
    assert tvm.bijective_layout("NCHW", "ABCD") is None
    assert tvm.bijective_layout("__undef__", "NCHW") is None
    assert tvm.bijective_layout("NCHW", "__undef__") is None
    assert tvm.bijective_layout("__undef__", "__undef__") is None
    assert tvm.bijective_layout("", "NCHW") is None
    assert tvm.bijective_layout("NCHW", "") is None
    assert tvm.bijective_layout("", "") is None

    # convertible
    assert tvm.bijective_layout("NCHW", "NCHW16c") is not None


def test_bilayout_shape():
    """Check forward and backward shape conversion between NCHW and NCHW16c."""
    bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
    assert isinstance(bilayout, tvm.tensor.BijectiveLayout)

    # 32 channels are expected to become 2 outer blocks of 16.
    dst_shape = bilayout.forward_shape((1, 32, 7, 7))
    assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)

    # Converting back is expected to restore the original shape.
    src_shape = bilayout.backward_shape(dst_shape)
    assert get_const_tuple(src_shape) == (1, 32, 7, 7)


def test_bilayout_index():
    """Check forward and backward index conversion between NCHW and NCHW16c."""
    bilayout = tvm.bijective_layout("NCHW", "NCHW16c")

    # Channel 18 is expected to map to outer index 1, inner index 2
    # (18 == 1 * 16 + 2).
    dst_index = bilayout.forward_index([0, 18, 6, 6])
    assert get_const_tuple(dst_index) == (0, 1, 6, 6, 2)

    # The inverse mapping is expected to recover channel 18.
    src_index = bilayout.backward_index([0, 1, 6, 6, 2])
    assert get_const_tuple(src_index) == (0, 18, 6, 6)


if __name__ == "__main__":
    test_layout()
    test_bilayout_convertible()
    test_bilayout_shape()
    test_bilayout_index()