/*!
 *  Copyright (c) 2018 by Contributors
 * \file tvm/relay/attrs/transform.h
 * \brief Attribute definitions for Relay transform operators.
 */
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

15 16 17 18 19 20 21 22 23 24
/*! \brief data type cast */
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
  DataType dtype;

  TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") {
    TVM_ATTR_FIELD(dtype)
        .describe("Target data type");
  }
};  // struct CastAttrs.

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/*! \brief Attributes used in expand_dims operators */
struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
  int axis;
  int num_newaxis;

  TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") {
    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input array is expanded."
                  "Should lie in range `[-data.ndim - 1, data.ndim]`."
                  "If `axis < 0`, it is the first axis inserted;"
                  "If `axis >= 0`, it is the last axis inserted in Python's negative indexing.");
    TVM_ATTR_FIELD(num_newaxis)
        .describe("Number of axises to be inserted. Should be >= 0.")
        .set_lower_bound(0)
        .set_default(1);
  }
};  // struct ExpandDimsAttrs

43 44 45 46 47 48 49 50 51 52 53 54 55
/*! \brief Attributes used in concatenate operators */
struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
  int axis;
  TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") {
    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input arrays are concatenated."
                  "Should lie in range `[-ndim, ndim)`.")
        .set_default(0);
  }
};  // struct ConcatenateAttrs

/*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
  /*! \brief Permutation of the input axes; reverse order if not specified. */
  Array<Integer> axes;

  TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
    TVM_ATTR_FIELD(axes)
        .describe("The target axes order, reverse order if not specified.");
  }
};  // struct TransposeAttrs

/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
  /*! \brief Requested output shape; must be compatible with the input shape. */
  Array<Integer> newshape;

  TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
    TVM_ATTR_FIELD(newshape)
        .describe("The new shape. Should be compatible with the original shape.");
  }
};  // struct ReshapeAttrs

Siva committed
72
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
73
  Integer axis;
Siva committed
74 75

  TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
76
    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
Siva committed
77 78 79 80
        .describe("The axis over which to select values.");
  }
};

81 82
/*! \brief Attributes that specify a tensor */
struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
83 84 85
  Array<IndexExpr> shape;
  DataType dtype;

86
  TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") {
87 88 89 90
    TVM_ATTR_FIELD(shape)
      .describe("Target shape.");
    TVM_ATTR_FIELD(dtype)
      .describe("Target data type.")
91
      .set_default(NullValue<DataType>());
92
  }
93
};  // struct InitOpAttrs
94

95 96
/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
97 98
  // use axis to make the name numpy compatible.
  Array<Integer> axis;
99 100

  TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
101 102 103
    TVM_ATTR_FIELD(axis)
        .describe("The axis to squeeze in the input tensor."
                  "If `axis = None`, all axis of dimension 1 get squeezed;"
104
                  "Else, the dimension in axes get squeezed."
105 106
                  "It is an error if an axis does not has dimension 1.")
        .set_default(NullValue<Array<Integer> >());
107 108 109
  }
};  // struct SqueezeAttrs

Siva committed
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
  NodeRef indices_or_sections;
  int axis;

  TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
    TVM_ATTR_FIELD(indices_or_sections)
        .describe("Indices or sections to split into. Accepts an int or a tuple"
                  "If indices_or_sections is an integer, the input will be divided equally"
                  "along given axis. If such a split is not possible, an error is raised."
                  "If indices_or_sections is a tuple of sorted integers,"
                  "the entries indicate where along axis the array is split.");
    TVM_ATTR_FIELD(axis).set_default(0)
        .describe("the axis to be splitted.");
  }
};

126 127 128 129 130 131 132 133 134 135
/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
  Array<Integer> begin;
  Array<Integer> end;
  Array<Integer> strides;

  TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
    TVM_ATTR_FIELD(begin)
        .describe("Indices for begin of slice, begin index is also inclusive");
    TVM_ATTR_FIELD(end)
136
        .describe("Indices for end of slice, end index is exclusive");
137 138 139 140
    TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
        .describe("Stride values of the slice");
  }
};
141 142 143 144 145 146 147 148 149 150 151 152 153


/*!
 * \brief Attributes used in slice_like operators: the data is sliced to match
 *        the size of a second input along the selected axes.
 */
struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
  /*! \brief Axes to slice; per the description, all axes when not given. */
  Array<Integer> axes;

  TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") {
    TVM_ATTR_FIELD(axes)
        .describe("List of axes on which input data will be sliced according to "
                  "the corresponding size of the second input. By default will "
                  "slice on all axes. Negative axes mean counting in reverse.");
  }
};  // struct SliceLikeAttrs

ziheng committed
154 155 156 157 158 159 160 161 162 163 164 165 166
// Clip
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
  double a_min;
  double a_max;

  TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
  TVM_ATTR_FIELD(a_min)
    .describe("The minimum clip value.");
  TVM_ATTR_FIELD(a_max)
    .describe("The maximum clip value.");
  }
};

167 168 169 170 171 172 173 174 175 176 177 178 179

struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
  std::string src_layout;
  std::string dst_layout;

  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
    TVM_ATTR_FIELD(src_layout)
        .describe("The source layout of the tensor. (e.g. NCHW)");
    TVM_ATTR_FIELD(dst_layout)
        .describe("The destination layout of the tensor. (e.g. NCHW16c)");
  }
};

180 181 182
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_TRANSFORM_H_