/*!
 *  Copyright (c) 2016 by Contributors
 * \file schedule.h
 * \brief Define a schedule.
 */
#ifndef TVM_SCHEDULE_H_
#define TVM_SCHEDULE_H_

#include <string>
#include <unordered_map>
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"

namespace tvm {

// Node container for Stage
class StageNode;
// Node container for Schedule
class ScheduleNode;
// Node container for IterVarRelation
class IterVarRelationNode;
// Attribute of itervar.
class IterVarAttrNode;

/*! \brief the attachment type */
enum AttachType : int {
  kGroupRoot = 1,
  kInline = 2,
  kInlinedAlready = 3,
  kScope = 4,
  kScanUpdate = 5
};

/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
 public:
  Stage() {}
  explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief create a new stage for op.
   * \param op The operator of the stage.
   */
  explicit Stage(Operation op);
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const StageNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline StageNode* operator->();
  /*!
   * \brief set the memory scope of the stage
   * \param scope The memory scope.
   */
  Stage& set_scope(std::string scope);  // NOLINT(*)
  /*!
   * \brief Specify that this stage is computed at the given scope of the parent stage.
   * \param parent The parent stage.
   * \param scope The iteration point in the parent to attach the computation to.
   * \return reference to self.
   */
  Stage& compute_at(Stage parent, IterVar scope);   // NOLINT(*)
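  // Usage sketch (illustrative only; `s` is an assumed Schedule, `B` and `C`
  // assumed Tensors, and `i` an IterVar of C's stage, all defined elsewhere):
  //
  //   s[B].compute_at(s[C], i);   // compute B inside C's loop over i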
  /*!
   * \brief Compute the function inline.
   * \return reference to self.
   */
  Stage& compute_inline();   // NOLINT(*)
  /*!
   * \brief Compute the function at group root.
   * \return reference to self.
   */
  Stage& compute_root();  // NOLINT(*)
  /*!
   * \brief Bind the ivar to a thread index.
   *
   * \param ivar The IterVar to be bound.
   * \param thread_ivar The thread axis to bind to.
   * \return reference to self.
   */
  Stage& bind(IterVar ivar, IterVar thread_ivar);
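  // Usage sketch (illustrative only; `stage` is a Stage, `xo` one of its
  // IterVars, and `bx` an IterVar tagged as the "blockIdx.x" thread axis,
  // all assumed to be created elsewhere):
  //
  //   stage.bind(xo, bx);   // run the xo loop across GPU thread blocks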
  /*!
   * \brief Set the predicate under which store to the array can be performed.
   *  Use this when there are duplicated threads doing the same store and we only
   *  need one of them to do the store.
   *
   * \note This is a dangerous scheduling primitive that can change the behavior of the program.
   *    Only use it when we are certain that there are duplicated stores.
   * \param predicate The condition to be checked.
   * \return reference to self.
   */
  Stage& set_store_predicate(Expr predicate);
  /*!
   * \brief Specify environment threads that are launched around the group's scope.
   *  This can only be used in a group stage.
   * \param threads The threads to be launched around the scope.
   * \note Each thread can only appear in one env_threads.
   *    This is a beta feature.
   * \return reference to self.
   */
  Stage& env_threads(Array<IterVar> threads);
  /*!
   * \brief Split the parent domain by factor, generating an outer and an inner iteration.
   * \param parent The parent iteration domain.
   * \param factor The split factor of the loop.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
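  // Usage sketch (illustrative only; `stage` is a Stage and `i` one of its
  // leaf iteration variables, both assumed to exist elsewhere):
  //
  //   IterVar xo, xi;
  //   stage.split(i, 32, &xo, &xi);   // i is replaced by xo (outer) and xi (inner)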
  /*!
   * \brief Split the iteration with a given number of parts.
   *
   * \param parent The parent domain.
   * \param nparts The number of parts in the outer domain.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
  /*!
   * \brief Fuse the inner and outer domains into one target domain.
   * \param inner The inner domain to be fused.
   * \param outer The outer domain to be fused.
   * \param p_target The result target domain.
   * \return reference to self.
   */
  Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target);  // NOLINT(*)
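  // Usage sketch (illustrative only; `i` and `j` are assumed adjacent leaf
  // IterVars of `stage`, with `j` inner to `i`):
  //
  //   IterVar ij;
  //   stage.fuse(j, i, &ij);   // iterate i and j as a single fused loop ij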
  /*!
   * \brief Reorder the iteration.
   * \param order The desired order of iteration variables.
   * \return reference to self.
   */
  Stage& reorder(const Array<IterVar>& order);   // NOLINT(*)
  /*!
   * \brief Perform tiling on two dimensions.
   *  The final loop order from outermost to innermost is
   *  [x_outer, y_outer, x_inner, y_inner].
   *
   * \param x_parent The original x dimension.
   * \param y_parent The original y dimension.
   * \param x_factor The stride factor on the x axis.
   * \param y_factor The stride factor on the y axis.
   * \param p_x_outer Outer axis of x dimension.
   * \param p_y_outer Outer axis of y dimension.
   * \param p_x_inner Inner axis of x dimension.
   * \param p_y_inner Inner axis of y dimension.
   * \return reference to self.
   */
  Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
              Expr x_factor, Expr y_factor,
              IterVar* p_x_outer, IterVar* p_y_outer,
              IterVar* p_x_inner, IterVar* p_y_inner);
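  // Usage sketch (illustrative only; `x` and `y` are assumed axes of `stage`):
  //
  //   IterVar xo, yo, xi, yi;
  //   stage.tile(x, y, 8, 8, &xo, &yo, &xi, &yi);
  //   // roughly: split x by 8, split y by 8, then reorder({xo, yo, xi, yi})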
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
   * \return reference to self.
   */
  Stage& vectorize(IterVar var);   // NOLINT(*)
  /*!
   * \brief Unroll iteration.
   * \param var The axis to be unrolled.
   * \return reference to self.
   */
  Stage& unroll(IterVar var);   // NOLINT(*)
  /*!
   * \brief Parallelize iteration.
   * \param var The axis to be parallelized.
   * \return reference to self.
   */
  Stage& parallel(IterVar var);   // NOLINT(*)
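  // Usage sketch (illustrative only; `xo` and `xi` are assumed to come from an
  // earlier split of one of this stage's axes):
  //
  //   stage.parallel(xo);    // run the outer loop across CPU threads
  //   stage.vectorize(xi);   // emit vector code for the inner loop
  //   // stage.unroll(xi);   // alternative: fully unroll the inner loop instead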
  /*!
   * \brief Whether the stage has been scheduled.
   * \return Whether the stage has been scheduled.
   */
  bool is_scheduled() const;
  /*!
   * \brief Get the attachment spec of the current stage.
   *  If the stage is computed at group root, this function
   *  will traverse the group hierarchy to get the
   *  final attach spec from the group.
   * \return A stage representing the attach spec of the group.
   */
  Stage GetAttachSpec() const;
  // declare container type
  using ContainerType = StageNode;
};

/*!
 * \brief Global schedule container
 *  for operations and all the operations they depend on.
 *  The schedule for each Operation is called a stage.
 */
class Schedule : public NodeRef {
 public:
  Schedule() {}
  explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief Get a copy of the current schedule.
   * \return The copied schedule.
   */
  Schedule copy() const;
  /*!
   * \brief Get the stage corresponding to the op.
   * \param op The operation.
   * \return The stage corresponding to the op.
   */
  Stage operator[](const Operation& op);
  /*!
   * \brief Short hand for getting the stage of a tensor's operation.
   * \param tensor The tensor.
   * \return The stage corresponding to the tensor's op.
   */
  Stage operator[](const Tensor& tensor) {
    return this->operator[](tensor->op);
  }
  /*!
   * \brief Create a new stage group for all intermediate
   *  operations between inputs and outputs.
   *
   * \param outputs The output boundary of the group.
   * \param inputs The input boundary of the group.
   * \param include_inputs Whether to include the inputs if they are reachable from the outputs.
   * \return The new grouped stage.
   */
  Stage create_group(const Array<Tensor>& outputs,
                     const Array<Tensor>& inputs,
                     bool include_inputs = false);
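  // Usage sketch (illustrative only; `sch` is this Schedule, `A` an input
  // Tensor, `B` an intermediate Tensor, `C` an output Tensor, and `i` an axis
  // of C's stage, all assumed to exist elsewhere):
  //
  //   Stage g = sch.create_group({B}, {A});   // group the intermediate ops between A and B
  //   g.compute_at(sch[C], i);                // attach the whole group under C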
  /*!
   * \brief Create a cache read of the original tensor for the readers.
   *  This will mutate the body of the readers.
   *  A new stage will be created for the tensor.
   * \param tensor The tensor to be cached.
   * \param scope The scope of the cache.
   * \param readers The readers to redirect to the cache.
   * \return The created tensor.
   */
  Tensor cache_read(const Tensor& tensor,
                    const std::string& scope,
                    const Array<Operation>& readers);
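  // Usage sketch (illustrative only; `A` and `B` are assumed Tensors where B
  // reads from A, and `sch` is this Schedule):
  //
  //   Tensor AA = sch.cache_read(A, "shared", {B->op});
  //   // B now reads from the cache stage AA, which can be scheduled on its own.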
  /*!
   * \brief Create a cache write tensor for producing the given tensor.
   *  The cache tensor will take over the body of the original tensor op.
   *  The original tensor's body will be changed to an identity read
   *  from the corresponding cache.
   * \param tensor The tensor to be produced.
   * \param scope The scope of the storage.
   * \return The created tensor.
   */
  Tensor cache_write(const Tensor& tensor, const std::string& scope);
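  // Usage sketch (illustrative only; `C` is an assumed output Tensor):
  //
  //   Tensor CC = sch.cache_write(C, "local");
  //   // CC computes the original body into "local" storage; C becomes a copy.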
  /*!
   * \brief Factor a reduction axis in the tensor's schedule to be an explicit axis.
   *  This will create a new stage that generates the new tensor with axis
   *  as the first dimension. The tensor's body will be rewritten as a reduction
   *  over the factored tensor.
   *
   * \param tensor The tensor to be factored.
   * \param axis The reduction axis in the tensor's schedule to be factored.
   * \return The created factored tensors.
   */
  Array<Tensor> rfactor(const Tensor& tensor,
                        const IterVar& axis);
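  // Usage sketch (illustrative only; `B` is an assumed reduction Tensor whose
  // reduction axis has been split into `ko` and `ki` beforehand):
  //
  //   Array<Tensor> BF = sch.rfactor(B, ko);
  //   // BF[0] computes partial reductions over ki, with ko as an extra leading
  //   // dimension; B is rewritten to reduce BF[0] over that dimension.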
  /*!
   * \brief Normalize the schedule.
   *  This is needed before bound inference.
   *  Insert necessary RebaseNodes to make sure all leaf_iter_vars
   *  are in the form [0, extent).
   *
   * \return A normalized schedule, which can be the same as the current one.
   */
  Schedule normalize();
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const ScheduleNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline ScheduleNode* operator->();
  // declare container type
  using ContainerType = ScheduleNode;
};

/*!
 * \brief The schedule relation between IterVars,
 *  which can be Split or Fuse.
 */
class IterVarRelation : public NodeRef {
 public:
  IterVarRelation() {}
  explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarRelationNode* operator->() const;
};

/*!
 * \brief Additional schedulable attributes about IterVar.
 */
class IterVarAttr : public NodeRef {
 public:
  IterVarAttr() {}
  explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarAttrNode* operator->() const;
};

/*!
 * \brief Represents a stage.
 *
 *  The relations form a directed acyclic hypergraph in a bipartite manner,
 *  where each node is represented by an IterVar
 *  and each hyper-edge is represented by an IterVarRelation.
 *  The relations connect the IterVars in the graph.
 *
 *  Besides the typical stages that correspond to operations,
 *  there are also group stages, which group stages together.
 *  Each stage's group (given by the group field) represents a constraint:
 *  the stage can only be attached to stages within the group.
 *
 *  A group stage node can be attached to IterVars as in a normal stage.
 */
class StageNode : public Node {
 public:
  /*!
   * \brief The operation of the stage, can be different from the original op.
   *  If it is null, then this stage is a group stage.
   */
  Operation op;
  /*!
   * \brief The original operator.
   *  The op field can change during scheduling to alter the dataflow,
   *  while origin_op remains fixed.
   */
  Operation origin_op;
  /*! \brief All the iteration variables in this stage. */
  Array<IterVar> all_iter_vars;
  /*! \brief The current active leaf iter vars in the stage. */
  Array<IterVar> leaf_iter_vars;
  /*!
   * \brief Specify threads to be launched at the stage.
   *  This is only valid for composite ops such as Scan.
   * \note Experimental primitive: used for thread persistence.
   */
  Array<IterVar> env_threads;
  /*!
   * \brief The predicate under which store can happen.
   *  Use this when there can be duplicated threads doing the same store.
   * \note Experimental primitive: used by cross-thread reduction.
   */
  Expr store_predicate;
  /*! \brief The relations between IterVars */
  Array<IterVarRelation> relations;
  /*! \brief Additional attributes about iter vars. */
  Map<IterVar, IterVarAttr> iter_var_attrs;
  /*! \brief The attachment type of the stage */
  AttachType attach_type{kGroupRoot};
  /*! \brief The attach point of this stage. */
  IterVar attach_ivar;
  /*! \brief The stage this node attaches to */
  Stage attach_stage;
  /*! \brief The thread storage scope level of the stage */
  std::string scope;
  /*! \brief Whether this is an output stage */
  bool is_output{false};
  /*!
   * \brief The parent group of the current stage.
   *  The stage cannot be attached to stages outside the group.
   */
  Stage group;
  /*! \brief Number of direct child stages, only used for group stages. */
  int num_child_stages{0};

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("op", &op);
    v->Visit("origin_op", &origin_op);
    v->Visit("all_iter_vars", &all_iter_vars);
    v->Visit("leaf_iter_vars", &leaf_iter_vars);
    v->Visit("env_threads", &env_threads);
    v->Visit("relations", &relations);
    v->Visit("iter_var_attrs", &iter_var_attrs);
    v->Visit("attach_type", &attach_type);
    v->Visit("attach_ivar", &attach_ivar);
    v->Visit("attach_stage", &attach_stage);
    v->Visit("scope", &scope);
    v->Visit("is_output", &is_output);
    v->Visit("group", &group);
    v->Visit("num_child_stages", &num_child_stages);
  }

  static constexpr const char* _type_key = "Stage";
  TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
};

/*! \brief node container for schedule */
class ScheduleNode : public Node {
 public:
  /*! \brief The output operations in the original data flow graph */
  Array<Operation> outputs;
  /*!
   * \brief List of all stages for ops.
   *  The stages are sorted in dependency order.
   */
  Array<Stage> stages;
  /*!
   * \brief List of all stage groups.
   */
  Array<Stage> groups;
  /*! \brief Map from the original operations to the stages */
  Map<Operation, Stage> stage_map;
  /*!
   * \brief Internal stage map that maps internal ops to stages.
   *  This is created on demand and can be invalidated.
   */
  std::unordered_map<const Node*, Stage> op2stage_cache_;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("outputs", &outputs);
    v->Visit("stages", &stages);
    v->Visit("groups", &groups);
    v->Visit("stage_map", &stage_map);
  }

  /*! \brief Initialize temp cache. */
  void InitCache();
  /*! \brief Invalidate temp cache. */
  void InvalidateCache();

  /*!
   * \brief Create a schedule for an array of ops (and their dependencies).
   * \param ops The ops to be scheduled.
   * \return The created schedule.
   */
  static Schedule make(Array<Operation> ops);

  static constexpr const char* _type_key = "Schedule";
  TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};

/*!
 * \brief Create a schedule for an array of ops (and their dependencies).
 * \param ops The ops to be scheduled.
 * \return The created schedule.
 */
inline Schedule create_schedule(Array<Operation> ops) {
  return ScheduleNode::make(ops);
}
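
// Usage sketch (illustrative only; `C` is an assumed output Tensor defined
// elsewhere):
//
//   Schedule s = create_schedule({C->op});
//   // s[C] gives the Stage of C, ready for split/fuse/bind/compute_at/etc.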

/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
 public:
  /*! \brief The iteration type. */
  IterVarType iter_type{kDataPar};
  /*! \brief The thread this IterVar binds to, can be null */
  IterVar bind_thread;
  /*! \brief List of tensors to be prefetched in this loop */
  Array<Tensor> prefetch_data;
  /*! \brief The offset used in each prefetch */
  Array<Expr> prefetch_offset;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("iter_type", &iter_type);
    v->Visit("bind_thread", &bind_thread);
    v->Visit("prefetch_data", &prefetch_data);
    v->Visit("prefetch_offset", &prefetch_offset);
  }

  static constexpr const char* _type_key = "IterVarAttr";
  TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
};

/*! \brief base node of iteration var relation */
class IterVarRelationNode : public Node {
 public:
  static constexpr const char* _type_key = "IterVarRelation";
  TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
};

/*!
 * \brief Split the parent domain into product of
 *  outer and inner.
 */
class SplitNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The split factor */
  Expr factor;
  /*! \brief Number of parts, only factor or nparts can be given */
  Expr nparts;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("parent", &parent);
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("factor", &factor);
    v->Visit("nparts", &nparts);
  }

  static IterVarRelation make(IterVar parent,
                              IterVar outer,
                              IterVar inner,
                              Expr factor,
                              Expr nparts);

  static constexpr const char* _type_key = "Split";
  TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
};

/*!
 * \brief Fuse two domains into one domain.
 */
class FuseNode : public IterVarRelationNode {
 public:
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The target domain */
  IterVar fused;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("fused", &fused);
  }

  static IterVarRelation make(
      IterVar outer, IterVar inner, IterVar fused);

  static constexpr const char* _type_key = "Fuse";
  TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
};

/*!
 * \brief Rebase the iteration to make its min 0.
 *  This is useful for normalizing the Schedule
 *  so that every leaf variable's min is 0.
 */
class RebaseNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The rebased domain */
  IterVar rebased;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("parent", &parent);
    v->Visit("rebased", &rebased);
  }

  static IterVarRelation make(IterVar parent, IterVar rebased);

  static constexpr const char* _type_key = "Rebase";
  TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
};


// implementations
inline const StageNode* Stage::operator->() const {
  return static_cast<const StageNode*>(node_.get());
}
inline StageNode* Stage::operator->() {
  return static_cast<StageNode*>(node_.get());
}

inline const ScheduleNode* Schedule::operator->() const {
  return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
  return static_cast<ScheduleNode*>(node_.get());
}

inline const IterVarRelationNode* IterVarRelation::operator->() const {
  return static_cast<const IterVarRelationNode*>(node_.get());
}

inline const IterVarAttrNode* IterVarAttr::operator->() const {
  return static_cast<const IterVarAttrNode*>(node_.get());
}
}  // namespace tvm
#endif  // TVM_SCHEDULE_H_