/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file infer_layout_util.h
 * \brief Utility functions to alter the layouts of operators or replace primitive operators with
          other expressions. This pass can be used for computing convolution in
          custom layouts or other general weight pre-transformation.
 */

#ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_
#define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>

#include <string>
#include <tuple>

#include "pattern_util.h"

namespace tvm {
namespace relay {

/*!
 * \brief Returns a new layout where the subordinate factors are adjusted based on the tensor
 *        shape.
 * \param src_layout The layout supplying the subordinate (lower-case) axes and their factors.
 * \param old_layout The old layout before any transformation.
 * \param old_shape The shape of the original tensor.
 * \return The adjusted Layout.
 */
inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
                                       const Array<tvm::PrimExpr>& old_shape) {
  // For each subordinate axis
  //   1) Find the corresponding dual axis.
  //   2) Find the Index of this dual axis in old_layout.
  //   3) Find the shape of that axis in old_shape.
  //   4) a) Adjust factor to 1, if that shape is 1. b) Else retain the factor.
  std::string new_layout;
  for (auto axis : src_layout->axes) {
    if (!LayoutAxis::Get(axis).IsPrimal()) {
      // 1) Find the corresponding dual axis
      const auto& dual_axis = LayoutAxis::Get(axis).ToPrimal();

      // 2) Find the index of this dual axis in old_layout
      int old_axis = old_layout.IndexOf(dual_axis);
      // IndexOf returns -1 when the dual axis is absent from old_layout; using
      // that as an index into old_shape below would read out of bounds, so
      // fail loudly with a diagnostic instead.
      CHECK_NE(old_axis, -1) << "Dual axis " << dual_axis.name()
                             << " not found in old layout " << old_layout.name();

      // 3) Find the shape of this index in old_shape
      auto shape_val = old_shape[old_axis];

      // 4) a) Check if this shape element is 1; if so the subordinate factor
      //       collapses to 1.
      bool is_shape_one = false;
      if (auto* shape_int = shape_val.as<IntImmNode>()) {
        if (shape_int->value == 1) {
          new_layout += "1";
          is_shape_one = true;
        }
      }

      // 4) b) If shape is not 1, retain the factor from src_layout.
      if (!is_shape_one) {
        auto new_shape_val = src_layout.FactorOf(dual_axis);
        new_layout += std::to_string(new_shape_val);
      }
    }
    new_layout += LayoutAxis::Get(axis).name();
  }
  return Layout(new_layout);
}

/*!
 * \brief Infer & correct function of node layout. See \p Layout for the layout convention.
 * \param attrs The attributes of the node.
 * \param new_in_layouts The layouts of input arguments after alter_op_layout.
 *                       This can be undefined, which means the function is called before
 *                       altering any operators.
 * \param old_in_layouts The layouts of input arguments before alter_op_layout.
 * \param old_in_types The types of the old input arguments.
 * \return A two-element array holding the inferred input layouts and the
 *         inferred output layouts, in that order.
 */
using FInferCorrectLayout = runtime::TypedPackedFunc<Array<Array<Layout>>(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types)>;

/*! \brief Take an arbitrary input layout and copy it to all inputs and the output. */
inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
                                                     const Array<Layout>& new_in_layouts,
                                                     const Array<Layout>& old_in_layouts,
                                                     const Array<tvm::relay::Type>& old_in_types) {
  // Pick a representative layout: prefer the first new layout when available,
  // otherwise fall back to the first defined old layout. If nothing is
  // defined, the chosen layout itself stays undefined.
  Layout chosen;
  if (new_in_layouts.defined()) {
    CHECK_GE(new_in_layouts.size(), 1);
    chosen = new_in_layouts[0];
  } else {
    for (const Layout& candidate : old_in_layouts) {
      if (candidate.defined()) {
        chosen = candidate;
        break;
      }
    }
  }

  // Broadcast the chosen layout over every input; the single output mirrors it.
  return Array<Array<Layout> >{Array<Layout>(old_in_layouts.size(), chosen), {chosen}};
}

/*!
 * \brief Infer layout for binary broadcast operators.
 * \param attrs The attributes of the node (unused here).
 * \param new_in_layouts The layouts of input arguments after alter_op_layout (may be undefined).
 * \param old_in_layouts The layouts of input arguments before alter_op_layout.
 * \param old_in_types The types of the old input arguments; each must be a TensorType.
 * \return A two-element array holding the inferred input layouts and the
 *         inferred output layouts, in that order.
 */
inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
                                                   const Array<Layout>& new_in_layouts,
                                                   const Array<Layout>& old_in_layouts,
                                                   const Array<tvm::relay::Type>& old_in_types) {
  Array<Layout> layouts;
  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    // Downcast once (the original code performed the cast twice per element)
    // and fail with a diagnostic if an operand is not a tensor.
    const auto* tensor_type = old_in_t.as<TensorTypeNode>();
    CHECK(tensor_type) << "BinaryBroadcastLayout requires tensor-typed inputs";
    old_in_shapes.push_back(tensor_type->shape);
  }

  if (new_in_layouts.defined()) {
    layouts.assign(new_in_layouts.begin(), new_in_layouts.end());
  } else {
    layouts.assign(old_in_layouts.begin(), old_in_layouts.end());
  }

  if (!layouts[0].defined() && !layouts[1].defined()) {
    // both undefined, infer fails
    return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
  } else if (!layouts[0].defined() || !layouts[1].defined()) {
    // only one is defined, use shape information to help infer
    int defined_idx = layouts[0].defined() ? 0 : 1;
    int undef_idx = 1 - defined_idx;

    if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
      // The defined operand has at least as many dimensions: take the
      // trailing sub-layout matching the undefined operand's rank.
      layouts.Set(undef_idx,
                  layouts[defined_idx].SubLayout(
                      old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
                      old_in_shapes[undef_idx].size()));
      return Array<Array<Layout> >{layouts, {layouts[defined_idx]}};
    } else {
      // only know the tensor with smaller dimensions,
      // so we cannot infer the final broadcasted output.
      // fails in this case.
      return Array<Array<Layout> >{{Layout::Undef()}, {Layout::Undef()}};
    }
  } else if (layouts[0].defined() && layouts[1].defined() &&
            (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) {
    // One operand is a scalar: propagate the other operand's layout.
    int scalar = layouts[0].ndim() == 0 ? 0 : 1;
    return Array<Array<Layout> >{layouts, {layouts[1 - scalar]}};
  } else {
    // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims
    // while transforming layout.
    int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
    int small_idx = 1 - large_idx;
    Layout ret = layouts[large_idx];

    if (old_in_layouts[0].Equals(old_in_layouts[1])) {
      // Support scenarios where original operands were of type [N, H, W, C] and [N, H, W, 1]
      // In this case, we might have NCHW16c coming for 1 operand. However, the other operand does
      // not have enough C dimension. To reuse broadcasting, we would want to use NCHW1c for the
      // second operand. The following section of code walks through the layouts and shapes to
      // perform that operation.
      // a in NCHW16c
      // b in NHW1
      // b = layout_transform(b) from NHW1 -> NCHW1c
      // add(a, b)
      auto old_small_shape = old_in_shapes[small_idx];
      auto old_small_layout = old_in_layouts[small_idx];
      auto new_small_layout =
          AdjustSubordinateFactors(layouts[large_idx], old_small_layout, old_small_shape);
      layouts.Set(small_idx, new_small_layout);
    } else {
      // Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case,
      // while transforming the layout, we expand dims to make C go to NHWC, and then use the
      // modified layout of the first operator to call the layout transform. E.g.
      // a in NCHW16c
      // b in C
      // b = expand_dims(b) from C -> NHWC
      // b = layout_transform(b) from NHWC -> NCHW16c
      // add(a, b)
      layouts.Set(small_idx, ret);
    }
    return Array<Array<Layout>>{layouts, {ret}};
  }
}

204 205 206 207 208 209 210
/*!
 * Call registered FInferCorrectLayout of an op.
 * Parameters are the same as the parameters for FInferCorrectLayout
 * Returns inferred_input_layout, inferred_output_layout, success
 */
static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts(
    const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
211
    const Array<tvm::relay::Type>& old_in_types) {
212 213 214 215 216 217 218 219 220
  static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
  if (!call->op.as<OpNode>()) {
    return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
  }

  Op op = Downcast<Op>(call->op);
  if (finfer_layout.count(op)) {
    Array<Array<Layout>> inferred_layouts;
    inferred_layouts =
221
        finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types);
222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
    CHECK_EQ(inferred_layouts.size(), 2)
        << "FInferCorrectLayout should return an array with size of 2";
    for (auto x : inferred_layouts) {
      for (auto y : x) {
        if (!y.defined()) {  // inference fails
          return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
        }
      }
    }
    return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true);
  } else {
    return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
  }
}

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_