/*!
 *  Copyright (c) 2016 by Contributors
 * \file tvm/expr.h
 * \brief The Expr and related elements in DataFlow construction.
 */
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_

#include <ir/Expr.h>
#include <ir/IROperator.h>
#include <ir/IRPrinter.h>
#include <string>
#include <algorithm>
#include <functional>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "./base.h"
#include "./runtime/c_runtime_api.h"

namespace tvm {

using HalideIR::Type;
using HalideIR::Float;
using HalideIR::Bool;
using HalideIR::Int;
using HalideIR::UInt;
using HalideIR::Handle;
using HalideIR::ExprHash;
using HalideIR::ExprEqual;

using HalideIR::Expr;
using HalideIR::VarExpr;
using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable;

using HalideIR::Internal::make_const;
using HalideIR::Internal::make_zero;
using HalideIR::Internal::as_const_int;
using HalideIR::Internal::as_const_uint;
using HalideIR::Internal::const_true;
using HalideIR::Internal::const_false;
using HalideIR::Internal::is_no_op;

/*! \brief Return the Type used for shape indexing, matching tvm_index_t. */
inline Type TVMShapeIndexType() {
  if (std::is_signed<tvm_index_t>::value) {
    return Int(sizeof(tvm_index_t) * 8);
  } else {
    return UInt(sizeof(tvm_index_t) * 8);
  }
}

/*! \brief Convert the runtime TVMType to the corresponding HalideIR Type. */
inline Type TVMType2Type(TVMType t) {
  return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes);
}

/*! \brief Convert a HalideIR Type to the runtime TVMType. */
inline TVMType Type2TVMType(Type t) {
  TVMType ret;
  ret.code = static_cast<uint8_t>(t.code());
  ret.bits = static_cast<uint8_t>(t.bits());
  ret.lanes = static_cast<uint16_t>(t.lanes());
  return ret;
}
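// Illustrative sketch (not part of the API): TVMType2Type and Type2TVMType
// are intended to be inverses, so a dtype should round-trip between the
// HalideIR and runtime representations (assuming the usual HalideIR
// constructors such as Float(bits, lanes)):
//   Type t = Float(32, 4);         // float32 vector with 4 lanes
//   TVMType dt = Type2TVMType(t);  // dt.bits == 32, dt.lanes == 4
//   Type back = TVMType2Type(dt);  // expected to equal the original t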

// Get the number of bytes needed for dtype, taking vector lanes into account.
inline int GetVectorBytes(Type dtype) {
  int data_bits = dtype.bits() * dtype.lanes();
  CHECK_EQ(data_bits % 8, 0U)
      << "Need to load/store by multiple of bytes";
  return data_bits / 8;
}
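// Illustrative sketch: the byte count scales with the number of lanes,
// and types narrower than a byte are rejected by the CHECK above, e.g.:
//   GetVectorBytes(Float(32));     // 32 * 1 / 8 == 4
//   GetVectorBytes(Float(32, 4));  // 32 * 4 / 8 == 16
//   GetVectorBytes(Int(1));        // fails the CHECK: 1 bit is not a multiple of 8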

/*! \brief a named variable in TVM */
class Var : public HalideIR::VarExpr {
 public:
  EXPORT explicit Var(const std::string& name_hint = "v",
               Type t = Int(32)) : VarExpr(name_hint, t) {}
  explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
  explicit Var(VarExpr v) : VarExpr(v) {}
  /*!
   * \brief Make a new copy of the Var with the same type,
   *  appending a suffix to the name hint.
   * \param suffix The suffix to be appended.
   * \return the new Var copy
   */
  Var copy_with_suffix(const std::string& suffix) const {
    return Var((*this)->name_hint + suffix, (*this)->type);
  }
  /*! \brief type indicating the container type */
  using ContainerType = Variable;
};
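// Illustrative sketch of Var usage (hypothetical names, not part of the API):
//   Var i("i");                                  // int32 variable named "i"
//   Var buf("buf", Handle());                    // handle (pointer) typed variable
//   Var i_outer = i.copy_with_suffix(".outer");  // new Var named "i.outer"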


/*! \brief container class of iteration variable. */
class IterVarNode;

/*!
 * \brief Same as HalideIR::IR::Range,
 *  except that it provides a constructor taking (begin, end).
 *
 *  \note Traditional Halide's Range has a constructor taking
 *   (begin, extent), which does not match the convention in e.g. Python.
 *   We decided to correct this by removing that constructor from HalideIR
 *   and adding the (begin, end) form back in TVM's Range.
 */
class Range : public HalideIR::IR::Range {
 public:
  /*! \brief constructor */
  Range() {}
  explicit Range(std::shared_ptr<Node> n) : HalideIR::IR::Range(n) {}
  /*!
   * \brief constructor by begin and end
   * \param begin The begin of the range.
   * \param end The end of the range.
   */
  TVM_DLL Range(Expr begin, Expr end);

  TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
};
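// Illustrative sketch (hypothetical names): the two ways to build a Range.
//   Var n("n");
//   Range r1(0, n);                              // the interval [0, n)
//   Range r2 = Range::make_by_min_extent(0, n);  // the same interval via (min, extent)
//   Range r3(2, n);                              // [2, n), i.e. min 2 and extent n - 2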

/*!
 * \brief Type of an iteration variable.
 *  Each IterVar has a specific type.
 *
 *  The type of an IterVar can be overridden via
 *  stage.iter_var_attrs, provided the types are compatible.
 */
enum IterVarType : int {
  /*!
   * \brief Data parallel iteration.
   *  This normally corresponds to an axis of a Tensor.
   *  Allow all IterVar manipulations.
   *
   * \note This does not mean the loop
   *  has to be executed in a parallel fashion.
   */
  kDataPar = 0,
  /*!
   * \brief The IterVar itself is a thread-index
   *  of a fixed thread launching group.
   *  Note that this is already assumed to be parallelized.
   *
   *  Disallow: split/fuse/vectorize/parallel
   */
  kThreadIndex = 1,
  /*!
   * \brief Commutative reduction.
   *  Cannot be directly parallelized.
   *
   *  Disallow: parallel/vectorize
   */
  kCommReduce = 2,
  /*!
   * \brief Serial loops with loop-carried dependency;
   *  the iterations must execute in order.
   *  Cannot be re-ordered.
   *
   *  Disallow: reorder/parallel/vectorize
   */
  kOrdered = 3,
  /*!
   * \brief The IterVar is opaque.
   *
   *  It may not correspond to any generated loop.
   *  Disallow all IterVar manipulations and compute_at.
   *
   * \note This is usually used to implement a composite op
   *  or an external op.
   */
  kOpaque = 4,
  // The following are additional types
  // that may be introduced during scheduling.
  /*!
   * \brief The execution is unrolled.
   */
  kUnrolled = 5,
  /*!
   * \brief The loop is vectorized.
   */
  kVectorized = 6,
  /*!
   * \brief The loop is parallelized.
   */
  kParallelized = 7,
  /*!
   * \brief Marks boundary of tensorization intrinsic.
   */
  kTensorized = 8
};

/*!
 * \brief Iteration Variable,
 *  represents an iteration over an integer interval.
 */
class IterVar : public NodeRef {
 public:
  // construct a new iter var without a domain
  IterVar() {}
  // construct from shared ptr.
  explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarNode* operator->() const;
  /*!
   * \return the corresponding var in the IterVar.
   */
  inline operator Expr() const;
  /*! \brief specify container node */
  using ContainerType = IterVarNode;
};

/*!
 * \brief Create a new IterVar that represents a thread index axis.
 *
 * \param dom Optional, the domain of the thread axis.
 * \param tag The thread tag of the axis.
 * \return The created IterVar.
 */
TVM_DLL IterVar thread_axis(Range dom, std::string tag);

/*!
 * \brief Create a new IterVar for reduction operations.
 *
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis.
 * \return The created IterVar representing the reduction axis.
 */
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
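// Illustrative sketch (hypothetical tags and names) of the axis helpers:
//   IterVar tx = thread_axis(Range(0, 64), "threadIdx.x");  // thread index with domain
//   IterVar bx = thread_axis(Range(), "blockIdx.x");        // domain left unspecified
//   IterVar k  = reduce_axis(Range(0, 16), "k");            // reduction axis over [0, 16)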

using Domain = Array<Range>;

// print functions for expr
TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n);  // NOLINT(*)
// definition of IterVarNode.
/*!
 * \brief An iteration variable representing an iteration
 *  over a one-dimensional interval.
 */
class IterVarNode : public Node {
 public:
  /*!
   * \brief The domain of the iteration, if known; can be None
   *  for intermediate schedule nodes, before scheduling.
   */
  Range dom;
  /*! \brief The looping variable */
  Var var;
  /*! \brief The type of the IterVar */
  IterVarType iter_type;
  /*!
   * \brief additional tag on the iteration variable,
   *  set this if it is already bound to a known thread tag.
   */
  std::string thread_tag;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dom", &dom);
    v->Visit("var", &var);
    v->Visit("iter_type", &iter_type);
    v->Visit("thread_tag", &thread_tag);
  }

  TVM_DLL static IterVar make(Range dom, Var var,
                              IterVarType iter_type,
                              std::string thread_tag = "");

  static constexpr const char* _type_key = "IterVar";
  TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
};
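// Illustrative sketch (hypothetical values): constructing an IterVar directly.
//   IterVar iv = IterVarNode::make(Range(0, 16), Var("i"), kDataPar);
//   Expr extent = iv->dom->extent;  // IterVarNode fields are reached via operator->
//   Expr e = iv;                    // implicit conversion yields the underlying var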

// inline implementations
inline const IterVarNode* IterVar::operator->() const {
  return static_cast<const IterVarNode*>(node_.get());
}

inline IterVar::operator Expr() const {
  return (*this)->var;
}

inline const char* IterVarType2String(IterVarType t) {
  switch (t) {
    case kDataPar: return "DataPar";
    case kThreadIndex: return "ThreadIndex";
    case kCommReduce: return "CommReduce";
    case kOrdered: return "Ordered";
    case kOpaque: return "Opaque";
    case kUnrolled: return "Unrolled";
    case kVectorized: return "Vectorized";
    case kParallelized: return "Parallelized";
    case kTensorized: return "Tensorized";
  }
  return "Unknown";
}

/*!
 * \brief Construct a new Var expression
 * \param name_hint The name hint for the expression
 * \param t The type of the expression
 */
TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));
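// Illustrative sketch (hypothetical names): variables created by var() combine
// into expressions through the operators pulled in from <ir/IROperator.h>.
//   Var n = var("n");        // int32 by default
//   Expr bound = n * 4 + 1;  // an int32 Expr built from n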

/*!
 * \brief Template function to convert a Map to an unordered_map.
 *  Sometimes useful for API gluing when internals use unordered_map.
 * \param dmap The container map
 * \return The corresponding unordered_map.
 * \tparam K the key of the Map.
 * \tparam V the value of the Map.
 */
template<typename K, typename V>
inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
  std::unordered_map<K, V> ret;
  for (auto kv : dmap) {
    ret[kv.first] = kv.second;
  }
  return ret;
}
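// Illustrative sketch (hypothetical names): copying an IterVar-keyed Map into a
// std::unordered_map; this relies on the std::hash<::tvm::IterVar> specialization
// provided at the end of this header and on NodeRef equality.
//   Map<IterVar, Expr> extents;  // filled elsewhere
//   std::unordered_map<IterVar, Expr> lookup = as_unordered_map(extents);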
}  // namespace tvm

namespace std {
template <>
struct hash<::tvm::IterVar> {
  std::size_t operator()(const ::tvm::IterVar& k) const {
    return k.hash();
  }
};
}  // namespace std
#endif  // TVM_EXPR_H_