expr.h 13 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/expr.h
 * \brief The Expr and related elements in DataFlow construction.
 */
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_

27
#include <string>
tqchen committed
28
#include <algorithm>
29
#include <unordered_map>
30
#include "base.h"
31 32 33
#include "dtype.h"
#include "node/container.h"
#include "node/ir_functor.h"
34
#include "runtime/c_runtime_api.h"
tqchen committed
35 36 37

namespace tvm {

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
/*! \brief Base node of all expressions. */
class ExprNode : public Node {
 public:
  /*! \brief The data type of the expression. */
  DataType type;

  static constexpr const char* _type_key = "Expr";
  TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
};

/*! \brief Container of all expressions. */
class Expr : public NodeRef {
 public:
  Expr() {}
  explicit Expr(NodePtr<Node> ptr) : NodeRef(ptr) {}
  /*!
   * \brief construct from integer.
   * \param value The value to be constructed.
   */
  TVM_DLL Expr(int32_t value);  // NOLINT(*)
  /*!
   * \brief construct from float.
   * \param value The value to be constructed.
   */
  TVM_DLL Expr(float value);  // NOLINT(*)
  /*!
   * \brief construct from string.
   * \param str The value to be constructed.
   */
  TVM_DLL Expr(std::string str);  // NOLINT(*)

  /*! \return the data type of this expression. */
  DataType type() const {
    return static_cast<const ExprNode*>(get())->type;
72 73
  }

74 75 76
  /*! \brief type indicate the container type */
  using ContainerType = ExprNode;
};
77

78 79 80 81 82 83
/*! \brief Base node of all statements. */
class StmtNode : public Node {
 public:
  static constexpr const char* _type_key = "Stmt";
  TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node);
};
84

85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
/*! \brief Container of all statements */
class Stmt : public NodeRef {
 public:
  TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode);
};

class Var;
/*!
 * \brief A variable node in the IR.
 *
 * A variable is uniquely identified by its address.
 *
 * Each variable is only bound once in the following nodes:
 * - Allocate
 * - For
 * - Let
 * - LetStmt
 */
class Variable : public ExprNode {
 public:
  /*!
   * \brief The hint to the variable name.
   * \note Each variable is uniquely identified by its address.
   */
  std::string name_hint;

  /*!
   * \brief Build a Var from a data type and a name hint (defined out of line).
   * \param dtype The data type of the variable.
   * \param name_hint The name hint of the variable.
   * \return A Var referencing the resulting node.
   */
  static Var make(DataType dtype, std::string name_hint);

  // Reflection: exposes the fields under the attribute names "dtype" and "name".
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("name", &name_hint);
  }

  static constexpr const char* _type_key = "Variable";
  TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode);
};

/*! \brief a named variable in TVM */
123
class Var : public Expr {
124
 public:
125 126 127
  explicit Var(NodePtr<Node> n) : Expr(n) {}
  TVM_DLL explicit Var(std::string name_hint = "v",
                       Type t = Int(32));
128 129 130 131 132 133 134 135
  /*!
   * \brief Make a new copy of var with same type, append suffix
   * \param suffix The suffix to be appended.
   * \return the new Var copy
   */
  Var copy_with_suffix(const std::string& suffix) const {
    return Var((*this)->name_hint + suffix, (*this)->type);
  }
136 137 138 139 140 141 142 143 144 145 146 147 148 149
  /*!
   * \brief Get pointer to the internal value.
   * \return the corresponding Variable.
   */
  const Variable* operator->() const {
    return get();
  }
  /*!
   * \brief Get pointer to the internal value.
   * \return the corresponding Variable.
   */
  const Variable* get() const {
    return static_cast<Variable*>(node_.get());
  }
150 151
  /*! \brief type indicate the container type */
  using ContainerType = Variable;
152
};
tqchen committed
153

154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
// Backward compatibility, will be removed later.
// Old names kept as aliases of the current types so existing code still compiles.
using VarExpr = Var;
using BaseExprNode = ExprNode;
using ExprHash = NodeHash;
using ExprEqual = NodeEqual;

class Integer;
/*! \brief ExprNode: constant integer. */
class IntImm : public ExprNode {
 public:
  /*! \brief the Internal value. */
  int64_t value;

  // Reflection: exposes the fields under the attribute names "dtype" and "value".
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dtype", &type);
    v->Visit("value", &value);
  }

  /*!
   * \brief Build a constant integer expression (defined out of line).
   * \param t The data type of the constant.
   * \param value The integer value.
   * \return An Integer referencing the resulting node.
   */
  TVM_DLL static Integer make(DataType t, int64_t value);

  static constexpr const char* _type_key = "IntImm";
  TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode);
};

/*!
179
 * \brief Container of constant integer (IntImm).
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
 *
 * This is used to store and automate type check
 * attributes that must be constant integer.
 */
class Integer : public Expr {
 public:
  Integer() : Expr() {}
  /*!
   * \brief constructor from node.
   */
  explicit Integer(NodePtr<Node> node) : Expr(node) {}
  /*!
   * \brief Construct integer from int value.
   */
  Integer(int value) : Expr(value) {}  // NOLINT(*)
  /*!
   * \brief Assign an expression to integer.
   * \param other another expression.
   */
  Integer& operator=(const Integer& other) {
    node_ = other.node_;
    return *this;
  }
  /*!
   * \brief Get pointer to the internal value.
   * \return the content of the integer.
   */
  const IntImm* operator->() const {
    return static_cast<const IntImm*>(node_.get());
  }
  /*!
   * \brief convert to int64_t
   */
  operator int64_t() const {
    CHECK(node_ != nullptr)
215
        << " Trying to reference a null Integer";
216 217 218 219 220 221
    return (*this)->value;
  }
  /*! \brief type indicate the container type */
  using ContainerType = IntImm;
};

222 223 224 225 226 227 228 229 230 231
/*! \brief range over one dimension */
class RangeNode : public Node {
 public:
  /*! \brief beginning of the node */
  Expr min;
  /*! \brief the extend of range */
  Expr extent;
  /*! \brief constructor */
  RangeNode() {}
  RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
232

233 234 235 236
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("min", &min);
    v->Visit("extent", &extent);
  }
tqchen committed
237

238 239 240 241 242 243
  static constexpr const char* _type_key = "Range";
  TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node);
};

/*! \brief Range constainer  */
class Range : public NodeRef {
tqchen committed
244 245 246 247 248 249
 public:
  /*!
   * \brief constructor by begin and end
   * \param begin The begin of the range.
   * \param end The end of the range.
   */
250
  TVM_DLL Range(Expr begin, Expr end);
251 252 253 254 255 256 257 258 259 260 261 262
  /*!
   * \brief construct a new range with min and extent
   *  The corresponding constructor is removed,
   *  because that is counter convention of tradition meaning
   *  of range(begin, end)
   *
   * \param min The minimum range.
   * \param extent The extent of the range.
   */
  static Range make_by_min_extent(Expr min, Expr extent);
  // declare range.
  TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode);
tqchen committed
263 264
};

265 266 267
/*! \brief container class of iteration variable. */
class IterVarNode;

268 269
using Region = Array<Range>;

tqchen committed
270
/*!
 * \brief Type of iteration variable.
 *  Each IterVar has a specific type.
 *
 *  The type of an iter var can be overridden via
 *  stage.iter_var_attrs given they are compatible.
 */
enum IterVarType : int {
  /*!
   * \brief Data parallel iteration.
   *  This normally corresponds to axis of Tensor.
   *  Allow all IterVar manipulations.
   *
   * \note This does not mean the loop
   *  has to be executed in parallel fashion.
   */
  kDataPar = 0,
  /*!
   * \brief The IterVar itself is a thread-index
   *  of a fixed thread launching group.
   *  Note that this is already assumed to be parallelized.
   *
   *  Disallow: split/fuse/vectorize/parallel
   */
  kThreadIndex = 1,
  /*!
   * \brief Commutative reduction.
   *  Cannot be directly parallelized.
   *
   *  Disallow: parallel/vectorize
   */
  kCommReduce = 2,
  /*!
   * \brief Serial loops with loop carry dependency,
   *  the iteration must execute in order.
   *  Cannot be re-ordered.
   *
   *  Disallow: reorder/parallel/vectorize
   */
  kOrdered = 3,
  /*!
   * \brief IterVar is opaque.
   *
   *  May not correspond to any generated loop.
   *  Disallow all IterVar manipulations and compute_at.
   *
   * \note This is usually used to implement composite op
   *  or external op.
   */
  kOpaque = 4,
  // The following are possible additional
  // types that are provided during schedule
  /*!
   * \brief The execution is unrolled.
   */
  kUnrolled = 5,
  /*!
   * \brief The loop is vectorized.
   */
  kVectorized = 6,
  /*!
   * \brief The loop is parallelized.
   */
  kParallelized = 7,
  /*!
   * \brief Marks boundary of tensorization intrinsic.
   */
  kTensorized = 8
};

/*!
tqchen committed
341 342 343 344 345 346 347 348
 * \brief Iteration Variable,
 *  represents an iteration over an integer interval.
 */
class IterVar : public NodeRef {
 public:
  // construct a new iter var without a domain
  IterVar() {}
  // construct from shared ptr.
349
  explicit IterVar(NodePtr<Node> n) : NodeRef(n) {}
tqchen committed
350 351 352 353 354 355 356 357 358 359 360 361 362
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarNode* operator->() const;
  /*!
   * \return the corresponding var in the IterVar.
   */
  inline operator Expr() const;
  /*! \brief specify container node */
  using ContainerType = IterVarNode;
};

363 364 365 366 367 368
/*!
 * \brief Create a new IterVar that represents an axis in thread.
 *
 * \param dom Optional, domain of the thread axis.
 * \param tag The thread tag of the axis.
 */
369
TVM_DLL IterVar thread_axis(Range dom, std::string tag);
370 371 372 373 374 375 376

/*!
 * \brief Create a new IterVar for reduction operations.
 *
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis.
 */
377
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
378

tqchen committed
379 380
using Domain = Array<Range>;

381 382 383 384 385 386
/*!
 * \brief Dump the node to stderr, used for debug purposes.
 * \param node The input node
 */
TVM_DLL void Dump(const NodeRef& node);

tqchen committed
387 388 389 390 391 392 393 394 395 396 397 398
// definition of Node.
/*!
 * \brief An iteration variable representing an iteration
 *  over a one dimensional interval.
 */
class IterVarNode : public Node {
 public:
  /*!
   * \brief the domain of iteration, if known, can be None
   *  For the intermediate schedule node, before schedule.
   */
  Range dom;
399 400
  /*! \brief The looping variable */
  Var var;
401 402
  /*! \brief The type of the IterVar */
  IterVarType iter_type;
tqchen committed
403 404 405 406 407 408 409 410
  /*!
   * \brief additional tag on the iteration variable,
   *  set this if this is binded already to a known thread tag.
   */
  std::string thread_tag;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("dom", &dom);
411
    v->Visit("var", &var);
412
    v->Visit("iter_type", &iter_type);
tqchen committed
413 414 415
    v->Visit("thread_tag", &thread_tag);
  }

416 417 418
  TVM_DLL static IterVar make(Range dom, Var var,
                              IterVarType iter_type,
                              std::string thread_tag = "");
419

tqchen committed
420
  static constexpr const char* _type_key = "IterVar";
421
  TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
tqchen committed
422 423 424 425 426 427 428 429 430 431 432
};

// inline implementations
// Defined after IterVarNode because both need the complete node type.
inline const IterVarNode* IterVar::operator->() const {
  return static_cast<const IterVarNode*>(node_.get());
}

// Converts to the loop variable; dereferences the node without a null check.
inline IterVar::operator Expr() const {
  return (*this)->var;
}

/*!
 * \brief Get the printable name of an iteration variable type.
 * \param t The iteration type.
 * \return The name of t, or "Unknown" for a value that matches no enumerator.
 */
inline const char* IterVarType2String(IterVarType t) {
  switch (t) {
    case kDataPar: return "DataPar";
    case kThreadIndex: return "ThreadIndex";
    case kCommReduce: return "CommReduce";
    case kOrdered: return "Ordered";
    case kOpaque: return "Opaque";
    case kUnrolled: return "Unrolled";
    case kVectorized: return "Vectorized";
    case kParallelized: return "Parallelized";
    case kTensorized: return "Tensorized";
  }
  // Reached only for values cast into the enum from outside its enumerators.
  return "Unknown";
}

/*!
 * \brief Construct a new Var expression
 * \param name_hint The name hint for the expression
 * \param t The type of the expression
 */
453
TVM_DLL Var var(std::string name_hint, Type t = Int(32));
454

455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470
/*!
 * \brief Template function to convert Map to unordered_map
 *  Sometimes useful for API gluing when internal uses unordered_map
 * \param dmap The container map
 * \return The corresponding unordered_map.
 * \tparam K the key of the Map.
 * \tparam V the value of the Map.
 */
template<typename K, typename V>
inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
  std::unordered_map<K, V> ret;
  // Bind by const reference; avoids copying each entry if the iterator
  // yields a reference, and is harmless if it yields a temporary.
  for (const auto& kv : dmap) {
    ret[kv.first] = kv.second;
  }
  return ret;
}

// Printer infra.
/*!
 * \brief A pretty printer class to print the IR.
 *
 * Print routines for individual node types are registered in vtable().
 */
class IRPrinter {
 public:
  /*! \brief Destination stream; held by reference, so it must outlive the printer. */
  std::ostream& stream;
  /*! \brief Current indentation level, starting at zero. */
  int indent = 0;

  /*!
   * \brief Bind a printer to an output stream.
   * \param stream The stream that receives the printed text.
   */
  explicit IRPrinter(std::ostream& stream) : stream(stream) {}  // NOLINT(*)

  /*!
   * \brief Print a node to the stream.
   * \param node The node to be printed.
   */
  TVM_DLL void Print(const NodeRef& node);
  /*! \brief Write the current indentation to the stream. */
  TVM_DLL void PrintIndent();

  /*! \brief Signature of a print routine that can be registered. */
  using FType = IRFunctor<void(const NodeRef&, IRPrinter*)>;
  /*! \brief The registration table for print routines. */
  TVM_DLL static FType& vtable();
};

/*!
 * \brief Default stream output for all nodes, delegating to IRPrinter.
 * \param os The output stream.
 * \param n The node to print.
 * \return os, so insertions can be chained.
 */
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
  IRPrinter(os).Print(n);
  return os;
}

}  // namespace tvm
498 499 500 501 502 503 504 505 506

namespace std {
template <>
struct hash<::tvm::IterVar> {
  std::size_t operator()(const ::tvm::IterVar& k) const {
    return k.hash();
  }
};
}
tqchen committed
507
#endif  // TVM_EXPR_H_