transform.h 9.77 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31 32
/*!
 * \file tvm/relay/attrs/transform.h
 * \brief Transform operators.
 */
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

33 34 35 36 37 38 39 40 41 42
/*!
 * \brief Attributes of the data type cast operator.
 */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
  /*! \brief Target data type of the cast. */
  DataType dtype;

  TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
    // No default is registered for this field.
    TVM_ATTR_FIELD(dtype).describe("Target data type");
  }
};  // struct CastAttrs

43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
  /*! \brief Axis at which the input array is expanded. */
  int axis;
  /*! \brief Number of axes to insert; bounded below by 0, defaults to 1. */
  int num_newaxis;

  TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") {
    // Adjacent string literals are concatenated verbatim, so every fragment
    // carries its own trailing space to keep the sentences apart.
    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input array is expanded. "
                  "Should lie in range `[-data.ndim - 1, data.ndim]`. "
                  "If `axis < 0`, it is the first axis inserted; "
                  "If `axis >= 0`, it is the last axis inserted in Python's negative indexing.");
    TVM_ATTR_FIELD(num_newaxis)
        .describe("Number of axises to be inserted. Should be >= 0.")
        .set_lower_bound(0)
        .set_default(1);
  }
};  // struct ExpandDimsAttrs

61 62 63 64 65 66 67 68 69 70 71 72 73
/*! \brief Attributes used in concatenate operators */
struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
  /*! \brief Axis at which the input arrays are concatenated; defaults to 0. */
  int axis;

  TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") {
    // The first fragment ends with a space because adjacent literals are
    // joined without any separator.
    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input arrays are concatenated. "
                  "Should lie in range `[-ndim, ndim)`.")
        .set_default(0);
  }
};  // struct ConcatenateAttrs

/*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
  /*! \brief Target axes order; the order is reversed when not specified. */
  Array<Integer> axes;

  TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
    TVM_ATTR_FIELD(axes)
        .describe("The target axes order, reverse order if not specified.");
  }
};  // struct TransposeAttrs

/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
  /*! \brief New shape; should be compatible with the original shape. */
  Array<Integer> newshape;
  /*! \brief When true, special values are inferred from right to left. Defaults to false. */
  bool reverse;

  TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
    TVM_ATTR_FIELD(newshape)
        .describe("The new shape. Should be compatible with the original shape.");
    TVM_ATTR_FIELD(reverse)
        .describe("Infer the special values from right to left if true")
        .set_default(false);
  }
};  // struct ReshapeAttrs

Siva committed
94
/*! \brief Attributes used in take operators */
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
  /*! \brief Axis over which to select values; defaults to a null Integer. */
  Integer axis;
  /*! \brief Out-of-bound index handling: "clip" (default) or "wrap". */
  std::string mode;

  TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
        .describe("The axis over which to select values.");
    // Fragments are joined verbatim by the compiler, so each one ends with an
    // explicit separator to keep the listed modes readable.
    TVM_ATTR_FIELD(mode).set_default("clip")
        .describe("Specify how out-of-bound indices will behave. "
                  "clip - clip to the range (default); "
                  "wrap - wrap around the indices");
  }
};  // struct TakeAttrs

108 109
/*! \brief Attributes that specify a tensor */
struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
  /*! \brief Target shape. */
  Array<IndexExpr> shape;
  /*! \brief Target data type; defaults to a null DataType. */
  DataType dtype;

  TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
    TVM_ATTR_FIELD(shape)
        .describe("Target shape.");
    TVM_ATTR_FIELD(dtype)
        .describe("Target data type.")
        .set_default(NullValue<DataType>());
  }
};  // struct InitOpAttrs
121

122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
/*!
 * \brief Attributes used in arange operators.
 *
 * The interval includes `start` and excludes `stop`; values are spaced by `step`.
 */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
  /*! \brief Start of the interval (included). Defaults to a Float(32) constant 0. */
  tvm::Expr start;
  /*! \brief Stop of the interval (not included). No default is registered. */
  tvm::Expr stop;
  /*! \brief Spacing between values. Defaults to a Float(32) constant 1. */
  tvm::Expr step;
  /*! \brief Target data type. Defaults to a null DataType. */
  DataType dtype;

  TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
    TVM_ATTR_FIELD(start)
        .set_default(make_const(Float(32), 0))
        .describe("Start of interval. The interval includes this value.");

    TVM_ATTR_FIELD(stop)
        .describe("Stop of interval. The interval does not include this value.");

    TVM_ATTR_FIELD(step)
        .set_default(make_const(Float(32), 1))
        .describe("Spacing between values.");

    TVM_ATTR_FIELD(dtype)
        .set_default(NullValue<DataType>())
        .describe("Target data type.");
  }
};  // struct ArangeAttrs

141 142 143 144 145 146 147 148 149
/*!
 * \brief Attributes used in stack operators.
 */
struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
  /*! \brief Axis in the result array along which the inputs are stacked. Defaults to 0. */
  Integer axis;

  TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") {
    TVM_ATTR_FIELD(axis)
        .set_default(0)
        .describe("The axis in the result array along which the input arrays are stacked.");
  }
};  // struct StackAttrs

150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
/*!
 * \brief Attributes used in repeat operators.
 */
struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
  /*! \brief Number of repetitions for each element. No default is registered. */
  Integer repeats;
  /*! \brief Axis along which values are repeated. Defaults to a null Integer. */
  Integer axis;

  TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
    TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element.");

    TVM_ATTR_FIELD(axis)
        .set_default(NullValue<Integer>())
        .describe(" The axis along which to repeat values.");
  }
};  // struct RepeatAttrs

/*! \brief Attributes used in tile operators */
struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
  /*! \brief Number of times the tensor is repeated, one entry per dimension of reps. */
  Array<Integer> reps;

  TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
    // Trailing space on the first fragment: adjacent literals are glued
    // together with no separator.
    TVM_ATTR_FIELD(reps)
        .describe("The number of times for repeating the tensor a. "
                  "Each dim sizeof reps must be a positive integer.");
  }
};  // struct TileAttrs

172 173 174 175 176 177 178 179 180
/*!
 * \brief Attributes used in reverse operators.
 */
struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
  /*! \brief Axis along which elements are reversed. Defaults to a null Integer. */
  Integer axis;

  TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
    TVM_ATTR_FIELD(axis)
        .set_default(NullValue<Integer>())
        .describe("The axis along which to reverse elements.");
  }
};  // struct ReverseAttrs

181 182
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
  // use axis to make the name numpy compatible.
  /*! \brief Axes to squeeze; defaults to a null array. */
  Array<Integer> axis;

  TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
    // Every fragment ends with a space so the concatenated description keeps
    // its sentences separated.
    TVM_ATTR_FIELD(axis)
        .describe("The axis to squeeze in the input tensor. "
                  "If `axis = None`, all axis of dimension 1 get squeezed; "
                  "Else, the dimension in axes get squeezed. "
                  "It is an error if an axis does not has dimension 1.")
        .set_default(NullValue<Array<Integer> >());
  }
};  // struct SqueezeAttrs

Siva committed
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
/*! \brief Attributes used in split operators */
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
  /*! \brief Either an integer section count or a tuple of sorted split indices. */
  NodeRef indices_or_sections;
  /*! \brief Axis along which the input is split; defaults to 0. */
  int axis;

  TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
    // Adjacent literals are concatenated as-is; each fragment therefore ends
    // with the whitespace (and punctuation) needed to separate it from the next.
    TVM_ATTR_FIELD(indices_or_sections)
        .describe("Indices or sections to split into. Accepts an int or a tuple. "
                  "If indices_or_sections is an integer, the input will be divided equally "
                  "along given axis. If such a split is not possible, an error is raised. "
                  "If indices_or_sections is a tuple of sorted integers, "
                  "the entries indicate where along axis the array is split.");
    TVM_ATTR_FIELD(axis).set_default(0)
        .describe("the axis to be splitted.");
  }
};  // struct SplitAttrs

212 213 214 215 216 217 218 219 220 221
/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
  /*! \brief Begin indices of the slice (inclusive). */
  Array<Integer> begin;
  /*! \brief End indices of the slice (exclusive). */
  Array<Integer> end;
  /*! \brief Stride values of the slice; defaults to an empty array. */
  Array<Integer> strides;

  TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
    TVM_ATTR_FIELD(begin)
        .describe("Indices for begin of slice, begin index is also inclusive");
    TVM_ATTR_FIELD(end)
        .describe("Indices for end of slice, end index is exclusive");
    TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
        .describe("Stride values of the slice");
  }
};  // struct StridedSliceAttrs
227 228 229 230 231 232 233 234 235 236 237 238

/*!
 * \brief Attributes for the slice_like operator.
 */
struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
  /*! \brief Axes on which the input is sliced to match the second input's size. */
  Array<Integer> axes;

  TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") {
    // One sentence per fragment; the concatenated text is unchanged.
    TVM_ATTR_FIELD(axes).describe(
        "List of axes on which input data will be sliced "
        "according to the corresponding size of the second input. "
        "By default will slice on all axes. "
        "Negative axes mean counting in reverse.");
  }
};  // struct SliceLikeAttrs

239
/*! \brief Attributes for Clip operator */
ziheng committed
240 241 242 243 244
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
  double a_min;
  double a_max;

  TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
245 246 247 248
    TVM_ATTR_FIELD(a_min)
      .describe("The minimum clip value.");
    TVM_ATTR_FIELD(a_max)
      .describe("The maximum clip value.");
ziheng committed
249 250 251
  }
};

252
/*! \brief Attributes for LayoutTransform operator */
253 254 255 256 257 258 259 260 261 262 263 264
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
  std::string src_layout;
  std::string dst_layout;

  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
    TVM_ATTR_FIELD(src_layout)
        .describe("The source layout of the tensor. (e.g. NCHW)");
    TVM_ATTR_FIELD(dst_layout)
        .describe("The destination layout of the tensor. (e.g. NCHW16c)");
  }
};

265 266 267 268 269 270 271 272 273 274 275
/*!
 * \brief Attributes for ShapeOf operator.
 */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
  /*! \brief Target data type. Defaults to a null DataType. */
  DataType dtype;

  TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") {
    TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>());
  }
};  // struct ShapeOfAttrs

276 277 278
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_