/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/attrs/transform.h
 * \brief Transform operators.
 */
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_

#include <tvm/attrs.h>
#include <tvm/relay/base.h>
#include <string>

namespace tvm {
namespace relay {

/*!
 * \brief Attributes for the data type cast operator.
 *
 * Carries only the target data type of the cast.
 */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
  /*! \brief The data type to cast to. */
  DataType dtype;

  TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
    TVM_ATTR_FIELD(dtype).describe("Target data type");
  }
};  // struct CastAttrs

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
  int axis;
  int num_newaxis;

  TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") {
    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input array is expanded."
                  "Should lie in range `[-data.ndim - 1, data.ndim]`."
                  "If `axis < 0`, it is the first axis inserted;"
                  "If `axis >= 0`, it is the last axis inserted in Python's negative indexing.");
    TVM_ATTR_FIELD(num_newaxis)
        .describe("Number of axises to be inserted. Should be >= 0.")
        .set_lower_bound(0)
        .set_default(1);
  }
};  // struct ExpandDimsAttrs

62 63 64 65 66 67 68 69 70 71 72 73 74
/*! \brief Attributes used in concatenate operators */
struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
  int axis;
  TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") {
    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input arrays are concatenated."
                  "Should lie in range `[-ndim, ndim)`.")
        .set_default(0);
  }
};  // struct ConcatenateAttrs

/*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
  /*!
   * \brief Target order of the axes.
   * Per the field description, the order is reversed when not specified.
   */
  Array<Integer> axes;

  TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
    TVM_ATTR_FIELD(axes)
        .describe("The target axes order, reverse order if not specified.");
  }
};  // struct TransposeAttrs

/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
  /*! \brief The new shape; must be compatible with the original shape. */
  Array<Integer> newshape;
  /*! \brief If true, special values are inferred from right to left (default false). */
  bool reverse;

  TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
    TVM_ATTR_FIELD(newshape)
        .describe("The new shape. Should be compatible with the original shape.");
    TVM_ATTR_FIELD(reverse)
        .describe("Infer the special values from right to left if true")
        .set_default(false);
  }
};  // struct ReshapeAttrs

Siva committed
95
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
96
  Integer axis;
97
  std::string mode;
Siva committed
98 99

  TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
100
    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
Siva committed
101
        .describe("The axis over which to select values.");
102 103 104
    TVM_ATTR_FIELD(mode).set_default("clip")
        .describe("Specify how out-of-bound indices will behave."
                  "clip - clip to the range (default)"
105 106
                  "wrap - wrap around the indices"
                  "fast - no clip or wrap around (user must make sure indices are in-bound)");
Siva committed
107 108 109
  }
};

110 111
/*! \brief Attributes that specify a tensor */
struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
112 113 114
  Array<IndexExpr> shape;
  DataType dtype;

115
  TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
116 117 118 119
    TVM_ATTR_FIELD(shape)
      .describe("Target shape.");
    TVM_ATTR_FIELD(dtype)
      .describe("Target data type.")
120
      .set_default(NullValue<DataType>());
121
  }
122
};  // struct InitOpAttrs
123

124 125
/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
126 127 128
  Expr start;
  Expr stop;
  Expr step;
129 130 131
  DataType dtype;

  TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
132
    TVM_ATTR_FIELD(start)
133 134 135
        .describe("Start of interval. The interval includes this value.");
    TVM_ATTR_FIELD(stop)
        .describe("Stop of interval. The interval does not include this value.");
136
    TVM_ATTR_FIELD(step)
137
        .describe("Spacing between values.");
138
    TVM_ATTR_FIELD(dtype)
139 140 141 142
        .describe("Target data type.");
  }
};  // struct ArangeAttrs

/*!
 * \brief Attributes used in stack operators.
 */
struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
  /*! \brief Axis of the result along which inputs are stacked; defaults to 0. */
  Integer axis;

  TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") {
    TVM_ATTR_FIELD(axis)
        .set_default(0)
        .describe("The axis in the result array along which the input arrays are stacked.");
  }
};  // struct StackAttrs

152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
/*! \brief Attributes used in repeat operators */
struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
  Integer repeats;
  Integer axis;
  TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
    TVM_ATTR_FIELD(repeats)
        .describe("The number of repetitions for each element.");
    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
        .describe(" The axis along which to repeat values.");
  }
};  // struct RepeatAttrs

/*! \brief Attributes used in tile operators */
struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
  /*! \brief Number of times to repeat the tensor; each entry must be a positive integer. */
  Array<Integer> reps;

  TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
    // Trailing space keeps the concatenated literals from running together.
    TVM_ATTR_FIELD(reps)
        .describe("The number of times for repeating the tensor a. "
                  "Each dim size of reps must be a positive integer.");
  }
};  // struct TileAttrs

/*!
 * \brief Attributes used in reverse operators.
 */
struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
  /*! \brief Axis along which elements are reversed; null unless set. */
  Integer axis;

  TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
    TVM_ATTR_FIELD(axis)
        .set_default(NullValue<Integer>())
        .describe("The axis along which to reverse elements.");
  }
};  // struct ReverseAttrs

183 184
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
185 186
  // use axis to make the name numpy compatible.
  Array<Integer> axis;
187 188

  TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
189 190 191
    TVM_ATTR_FIELD(axis)
        .describe("The axis to squeeze in the input tensor."
                  "If `axis = None`, all axis of dimension 1 get squeezed;"
192
                  "Else, the dimension in axes get squeezed."
193 194
                  "It is an error if an axis does not has dimension 1.")
        .set_default(NullValue<Array<Integer> >());
195 196 197
  }
};  // struct SqueezeAttrs

Siva committed
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
  NodeRef indices_or_sections;
  int axis;

  TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
    TVM_ATTR_FIELD(indices_or_sections)
        .describe("Indices or sections to split into. Accepts an int or a tuple"
                  "If indices_or_sections is an integer, the input will be divided equally"
                  "along given axis. If such a split is not possible, an error is raised."
                  "If indices_or_sections is a tuple of sorted integers,"
                  "the entries indicate where along axis the array is split.");
    TVM_ATTR_FIELD(axis).set_default(0)
        .describe("the axis to be splitted.");
  }
};

214 215 216 217 218 219 220 221 222 223
/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
  Array<Integer> begin;
  Array<Integer> end;
  Array<Integer> strides;

  TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
    TVM_ATTR_FIELD(begin)
        .describe("Indices for begin of slice, begin index is also inclusive");
    TVM_ATTR_FIELD(end)
224
        .describe("Indices for end of slice, end index is exclusive");
225 226 227 228
    TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
        .describe("Stride values of the slice");
  }
};
229 230 231 232 233 234 235 236 237 238 239 240

/*!
 * \brief Attributes for SliceLike operator.
 */
struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
  /*!
   * \brief Axes of the data input to slice to the matching size of the second input.
   * Negative values count in reverse.
   */
  Array<Integer> axes;

  TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") {
    TVM_ATTR_FIELD(axes).describe(
        "List of axes on which input data will be sliced according to the "
        "corresponding size of the second input. By default will slice "
        "on all axes. Negative axes mean counting in reverse.");
  }
};  // struct SliceLikeAttrs

241
/*! \brief Attributes for Clip operator */
ziheng committed
242 243 244 245 246
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
  double a_min;
  double a_max;

  TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
247 248 249 250
    TVM_ATTR_FIELD(a_min)
      .describe("The minimum clip value.");
    TVM_ATTR_FIELD(a_max)
      .describe("The maximum clip value.");
ziheng committed
251 252 253
  }
};

254
/*! \brief Attributes for LayoutTransform operator */
255 256 257 258 259 260 261 262 263 264 265 266
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
  std::string src_layout;
  std::string dst_layout;

  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
    TVM_ATTR_FIELD(src_layout)
        .describe("The source layout of the tensor. (e.g. NCHW)");
    TVM_ATTR_FIELD(dst_layout)
        .describe("The destination layout of the tensor. (e.g. NCHW16c)");
  }
};

/*!
 * \brief Attributes for ShapeOf operator.
 */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
  /*! \brief Target data type; a null DataType unless set. */
  DataType dtype;

  TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") {
    TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>());
  }
};  // struct ShapeOfAttrs

278 279 280 281 282 283 284 285 286 287 288 289
struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
  double mask_value;
  int axis;

  TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") {
    TVM_ATTR_FIELD(mask_value).set_default(0)
      .describe("The masking value.");
    TVM_ATTR_FIELD(axis).set_default(0)
      .describe("The axis of the length dimension. Can only be 0 or 1.");
  }
};  // struct SequenceMaskAttrs.

/*!
 * \brief Attributes for ndarray_size operator.
 */
struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
  /*! \brief Target data type; a null DataType unless set. */
  DataType dtype;

  TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") {
    TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue<DataType>());
  }
};  // struct NdarraySizeAttrs

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_