/*!
 *  Copyright (c) 2016 by Contributors
 * \file expr.h
 * \brief The Expr and related elements in DataFlow construction.
 */
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_

#include <ir/Expr.h>
#include <ir/IRPrinter.h>
#include <ir/IROperator.h>
#include <string>
#include <algorithm>
#include "./base.h"
#include "./runtime/c_runtime_api.h"

namespace tvm {

using Halide::Type;
using Halide::Float;
using Halide::Bool;
using Halide::Int;
using Halide::UInt;
using Halide::Handle;
using Halide::ExprHash;
using Halide::ExprEqual;

using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;

using Halide::Internal::make_const;
using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;
using Halide::Internal::const_true;
using Halide::Internal::const_false;
using Halide::Internal::is_no_op;
using Halide::likely;
using Halide::likely_if_innermost;

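/*!
 * \brief Get the Halide Type matching the runtime shape index type
 *  tvm_index_t, taking signedness and bit-width from the C typedef.
 */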
inline Type TVMShapeIndexType() {
  if (std::is_signed<tvm_index_t>::value) {
    return Int(sizeof(tvm_index_t) * 8);
  } else {
    return UInt(sizeof(tvm_index_t) * 8);
  }
}

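/*!
 * \brief Convert a TVMType (the C ABI type descriptor) into the
 *  corresponding Halide Type.
 * \param t The TVMType to be converted.
 * \return the corresponding Type.
 */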
inline Type TVMType2Type(TVMType t) {
  return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}

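/*!
 * \brief Convert a Halide Type back into the TVMType C ABI descriptor.
 * \param t The Type to be converted.
 * \return the corresponding TVMType.
 */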
inline TVMType Type2TVMType(Type t) {
  TVMType ret;
  ret.code = static_cast<uint8_t>(t.code());
  ret.bits = static_cast<uint8_t>(t.bits());
  ret.lanes = static_cast<uint16_t>(t.lanes());
  return ret;
}
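
// Example (illustrative sketch): the two conversions round-trip.
//
//   Type t = Float(32, /*lanes=*/4);  // a 4-lane float32 vector type
//   TVMType dt = Type2TVMType(t);
//   Type back = TVMType2Type(dt);     // equal to t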

/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
 public:
  explicit Var(const std::string& name_hint = "v",
               Type t = Int(32)) : VarExpr(name_hint, t) {}
  explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
  explicit Var(VarExpr v) : VarExpr(v) {}
  /*!
   * \brief Make a new copy of the var with the same type, appending a suffix.
   * \param suffix The suffix to be appended.
   * \return the new Var copy
   */
  Var copy_with_suffix(const std::string& suffix) const {
    return Var((*this)->name_hint + suffix, (*this)->type);
  }
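  // Example (illustrative): Var("i").copy_with_suffix(".outer") yields a
  // fresh Int(32) variable named "i.outer".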
  /*! \brief type indicating the container type */
  using ContainerType = Variable;
};

/*! \brief container class of iteration variable. */
class IterVarNode;

/*!
 * \brief same as Halide::IR::Range,
 *  except it provides a constructor with (begin, end).
 *
 *  \note Traditional Halide Range has a constructor with
 *   (begin, extent), which does not match the convention in e.g. Python.
 *   We decided to correct it by removing that constructor in HalideIR,
 *   and adding it back in TVM's Range.
 */
class Range : public Halide::IR::Range {
 public:
  /*! \brief constructor */
  Range() {}
  explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
  /*!
   * \brief constructor by begin and end
   * \param begin The begin of the range.
   * \param end The end of the range.
   */
  Range(Expr begin, Expr end);

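  /*!
   * \brief construct a Range by min and extent,
   *  following the (min, extent) convention of Halide::IR::Range.
   * \param min The minimum of the range.
   * \param extent The extent (length) of the range.
   */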
  static Range make_with_min_extent(Expr min, Expr extent);
};
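
// Example (illustrative; n is an assumed extent Expr): two equivalent
// ways to build the half-open interval [0, n).
//
//   Range r1(0, n);                                // (begin, end)
//   Range r2 = Range::make_with_min_extent(0, n);  // (min, extent)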

/*!
 * \brief Type of iteration variable.
 *  Each IterVar has a specific type.
 *
 *  The type of an iter var can be overridden via
 *  stage.iter_var_attrs, given the types are compatible.
 */
enum IterVarType : int {
  /*!
   * \brief Data parallel iteration.
   *  This normally corresponds to axis of Tensor.
   *  Allow all IterVar manipulations.
   *
   * \note This does not mean the loop
   *  has to be executed in a parallel fashion.
   */
  kDataPar = 0,
  /*!
   * \brief The IterVar itself is a thread-index
   *  of a fixed thread launching group.
   *  Note that this is already assumed to be parallelized.
   *
   *  Disallow: split/fuse/vectorize/parallel
   */
  kThreadIndex = 1,
  /*!
   * \brief Commutative reduction.
   *  Cannot be directly parallelized.
   *
   *  Disallow: parallel/vectorize
   */
  kCommReduce = 2,
  /*!
   * \brief Serial loops with a loop-carried dependency;
   *  the iterations must execute in order.
   *  Cannot be re-ordered.
   *
   *  Disallow: reorder/parallel/vectorize
   */
  kOrdered = 3,
  /*!
   * \brief IterVar is opaque.
   *
   *  May not correspond to any generated loop.
   *  Disallow all IterVar manipulations and compute_at.
   *
   * \note This is usually used to implement composite ops
   *  or external ops.
   */
  kOpaque = 4,
  // The following are possible additional
  // types that are provided during schedule
  /*!
   * \brief The execution is unrolled.
   */
  kUnrolled = 5,
  /*!
   * \brief The loop is vectorized.
   */
  kVectorized = 6,
  /*!
   * \brief The loop is parallelized.
   */
  kParallelized = 7
};

/*!
 * \brief Iteration Variable,
 *  represents an iteration over an integer interval.
 */
class IterVar : public NodeRef {
 public:
  // construct a new iter var without a domain
  IterVar() {}
  // construct from shared ptr.
  explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarNode* operator->() const;
  /*!
   * \return the corresponding var in the IterVar.
   */
  inline operator Expr() const;
  /*! \brief specify container node */
  using ContainerType = IterVarNode;
};

/*!
 * \brief Create a new IterVar that represents an axis in thread.
 *
 * \param dom Optional, domain of the thread axis.
 * \param tag The thread tag of the axis.
 */
IterVar thread_axis(Range dom, std::string tag);

/*!
 * \brief Create a new IterVar for reduction operations.
 *
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis.
 */
IterVar reduce_axis(Range dom, std::string name = "rv");
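
// Example (illustrative sketch; n is an assumed extent Expr):
//
//   IterVar tx = thread_axis(Range(0, 64), "threadIdx.x");
//   IterVar k  = reduce_axis(Range(0, n), "k");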

using Domain = Array<Range>;

// functions
using Halide::cast;
using Halide::min;
using Halide::max;
using Halide::abs;
using Halide::select;

/*!
 * \brief sum of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 */
Expr sum(Expr source, Array<IterVar> axis);

/*!
 * \brief max of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 */
Expr max(Expr source, Array<IterVar> axis);

/*!
 * \brief min of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 */
Expr min(Expr source, Array<IterVar> axis);
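
// Example (illustrative sketch; A(k) stands for a hypothetical element
// expression and n for an assumed extent Expr):
//
//   IterVar k = reduce_axis(Range(0, n), "k");
//   Expr total = sum(A(k), {k});  // sum of A over [0, n)
//   Expr peak  = max(A(k), {k});  // maximum of A over the same axis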


// print functions for expr
std::ostream& operator<<(std::ostream& os, const NodeRef& n);  // NOLINT(*)

// definition of Node.
/*!
 * \brief An iteration variable representing an iteration
 *  over a one-dimensional interval.
 */
class IterVarNode : public Node {
 public:
  /*!
   * \brief the domain of the iteration, if known; can be None
   *  for intermediate schedule nodes, before scheduling.
   */
  Range dom;
  /*! \brief The looping variable */
  Var var;
  /*! \brief The type of the IterVar */
  IterVarType iter_type;
  /*!
   * \brief additional tag on the iteration variable,
   *  set this if it is already bound to a known thread tag.
   */
  std::string thread_tag;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dom", &dom);
    v->Visit("var", &var);
    v->Visit("iter_type", &iter_type);
    v->Visit("thread_tag", &thread_tag);
  }

  static IterVar make(Range dom, Var var,
                      IterVarType iter_type,
                      std::string thread_tag = "");

  static constexpr const char* _type_key = "IterVar";
  TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
};

// inline implementations
inline const IterVarNode* IterVar::operator->() const {
  return static_cast<const IterVarNode*>(node_.get());
}

inline IterVar::operator Expr() const {
  return (*this)->var;
}

inline const char* IterVarType2String(IterVarType t) {
  switch (t) {
    case kDataPar: return "DataPar";
    case kThreadIndex: return "ThreadIndex";
    case kCommReduce: return "CommReduce";
    case kOrdered: return "Ordered";
    case kOpaque: return "Opaque";
    case kUnrolled: return "Unrolled";
    case kVectorized: return "Vectorized";
    case kParallelized: return "Parallelized";
  }
  return "Unknown";
}

}  // namespace tvm

namespace std {
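// Hash specialization so IterVar can be used as a key in
// std::unordered_map / std::unordered_set.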
template <>
struct hash<::tvm::IterVar> {
  std::size_t operator()(const ::tvm::IterVar& k) const {
    return k.hash();
  }
};
}  // namespace std
#endif  // TVM_EXPR_H_