/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/schedule.h
 * \brief Define a schedule.
 */
#ifndef TVM_SCHEDULE_H_
#define TVM_SCHEDULE_H_

#include <string>
#include <unordered_map>
#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "tensor_intrin.h"

namespace tvm {

// Node container for Stage
class StageNode;
// Node container for Schedule
class ScheduleNode;
// Node container for IterVarRelation
class IterVarRelationNode;
// Attribute of itervar.
class IterVarAttrNode;

/*! \brief the attachment type */
enum AttachType : int {
  kGroupRoot = 1,
  kInline = 2,
  kInlinedAlready = 3,
  kScope = 4,
  kScanUpdate = 5
};

/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
 public:
  Stage() {}
  explicit Stage(NodePtr<Node> n) : NodeRef(n) {}
  /*!
   * \brief Create a new stage for the given operation.
   * \param op The operation in the schedule.
   */
  explicit Stage(Operation op);
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const StageNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline StageNode* operator->();
  /*!
   * \brief set the memory scope of the stage
   * \param scope The memory scope.
   */
  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
  /*!
   * \brief specify the schedule to be computed at the parent schedule's scope.
   * \param parent The parent schedule.
   * \param scope The iteration point to carry the schedule.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);   // NOLINT(*)
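  //
  // Usage sketch (illustrative, not part of this header): compute stage B
  // at axis xo of stage C, where `s` is a Schedule and B, C are Tensors:
  //
  //   s[B].compute_at(s[C], xo);
  //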
  /*!
   * \brief Compute the function inline.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_inline();   // NOLINT(*)
  /*!
   * \brief Compute the function at group root.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_root();  // NOLINT(*)
  /*!
   * \brief Bind the IterVar to thread index.
   *
   * \param ivar The IterVar to be bound.
   * \param thread_ivar The thread axis to be bound.
   * \return reference to self.
   */
  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
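  //
  // Usage sketch (illustrative): bind the loops of a prior split to GPU
  // block/thread indices. Assumes `thread_axis` from operation.h and
  // IterVars xo, xi produced by split:
  //
  //   s[T].bind(xo, thread_axis(Range(), "blockIdx.x"));
  //   s[T].bind(xi, thread_axis(Range(), "threadIdx.x"));
  //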
  /*!
   * \brief Set the predicate to determine whether a store to the array should be performed.
   *  Use this when there are multiple threads performing the same store and we only
   *  need one of them to do the store.
   *
   * \note This is a dangerous scheduling primitive that can change the behavior of the program.
   *    Only use it when we are certain that there are duplicated stores.
   * \param predicate The condition to be checked.
   * \return reference to self.
   */
  TVM_DLL Stage& set_store_predicate(Expr predicate);
  /*!
   * \brief Specify environment threads that are launched around the group's scope.
   *  This can only be used in a group stage.
   * \param threads The threads to be launched around the scope.
   * \note Each thread can only appear in one env_threads.
   *    This is a beta feature.
   * \return reference to self.
   */
  TVM_DLL Stage& env_threads(Array<IterVar> threads);
  /*!
   * \brief Split the parent by factor, generating an outer and an inner iteration.
   * \param parent The parent iteration domain.
   * \param factor The split factor of the loop.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
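  //
  // Usage sketch (illustrative): split the first axis of a ComputeOp
  // tensor T by a factor of 32 (ComputeOpNode is declared in operation.h):
  //
  //   IterVar xo, xi;
  //   s[T].split(T->op.as<ComputeOpNode>()->axis[0], 32, &xo, &xi);
  //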
  /*!
   * \brief Split the iteration into a given number of parts.
   *
   * \param parent The parent domain.
   * \param nparts The number of parts in the outer domain.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
  /*!
   * \brief Fuse the outer and inner domains into a single target domain.
   * \param outer The outer domain to be fused.
   * \param inner The inner domain to be fused.
   * \param p_target The result target domain.
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
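  //
  // Usage sketch (illustrative): fuse the two loops produced by a prior
  // split back into a single loop:
  //
  //   IterVar fused;
  //   s[T].fuse(xo, xi, &fused);
  //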
  /*!
   * \brief Fuse all the axes together into a single axis.
   *
   * \param axes All the axes to be fused.
   * \param p_target The result target domain.
   *
   * \note axes can be an empty array;
   *       in that case, a singleton IterVar is created and
   *       inserted at the outermost loop.
   *       The fuse of an empty array is used to support zero-dimensional tensors.
   *
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Reorder the iteration
   * \param order The order of iteration variables.
   * \return reference to self.
   */
  TVM_DLL Stage& reorder(const Array<IterVar>& order);   // NOLINT(*)
  /*!
   * \brief Perform tiling on two dimensions
   *  The final loop order from outermost to innermost is
   *  [x_outer, y_outer, x_inner, y_inner]
   *
   * \param x_parent The original x dimension
   * \param y_parent The original y dimension
   * \param x_factor The stride factor on x axis
   * \param y_factor The stride factor on y axis
   * \param p_x_outer Outer axis of x dimension
   * \param p_y_outer Outer axis of y dimension
   * \param p_x_inner Inner axis of x dimension
   * \param p_y_inner Inner axis of y dimension
   * \return reference to self.
   */
  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
                     Expr x_factor, Expr y_factor,
                     IterVar* p_x_outer, IterVar* p_y_outer,
                     IterVar* p_x_inner, IterVar* p_y_inner);
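  //
  // Usage sketch (illustrative): tile the first two axes of a ComputeOp
  // tensor T with an 8x8 tile:
  //
  //   IterVar xo, yo, xi, yi;
  //   s[T].tile(T->op.as<ComputeOpNode>()->axis[0],
  //             T->op.as<ComputeOpNode>()->axis[1],
  //             8, 8, &xo, &yo, &xi, &yi);
  //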
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
   * \return reference to self.
   */
  TVM_DLL Stage& vectorize(IterVar var);   // NOLINT(*)
  /*!
   * \brief Replace computation of the current stage by tensor intrinsic f.
   * \param var The axis that marks the beginning of tensorization.
   *  Every operation inside the axis (including the axis itself) is tensorized.
   * \param f The Tensor compute intrinsics.
   * \return reference to self.
   */
  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);   // NOLINT(*)
  /*!
   * \brief Unroll iteration.
   * \param var The axis to be unrolled.
   * \return reference to self.
   */
  TVM_DLL Stage& unroll(IterVar var);   // NOLINT(*)
  /*!
   * \brief Parallelize iteration.
   * \param var The axis to be parallelized.
   * \return reference to self.
   */
  TVM_DLL Stage& parallel(IterVar var);   // NOLINT(*)
  /*!
   * \brief Annotate the iteration with pragma
   *
   * \param var The axis to be annotated.
   * \param pragma_type The pragma type.
   * \param pragma_value The pragma value
   *
   * \return reference to self.
   */
  TVM_DLL Stage& pragma(IterVar var,
                       const std::string& pragma_type,
                       const Expr& pragma_value = Expr());   // NOLINT(*)
  /*!
   * \brief Fetch data in advance.
   * \param domain the tensor to be prefetched
   * \param var the iteration point at which to apply prefetching
   * \param offset the number of iterations to be fetched in advance
   * \return reference to self
   */
  TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
  /*!
   * \brief Set alignment requirement for specific dimension.
   *
   *  Such that stride[axis] == k * factor + offset for some k.
   *
   * \param axis The dimension to be specified for alignment.
   * \param factor The factor multiple of alignment
   * \param offset The required offset factor.
   * \return reference to self
   */
  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
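  //
  // For example, storage_align(axis, 16, 8) constrains stride[axis] so that
  // stride[axis] % 16 == 8, i.e. admissible strides are 8, 24, 40, ...
  // elements; this kind of offset alignment can help avoid bank conflicts.
  //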
  /*!
   * \brief Compute current stage with double buffering.
   * \return reference to self.
   */
  TVM_DLL Stage& double_buffer();   // NOLINT(*)
  /*!
   * \brief Schedule for OpenGL fragment shader.
   * \return reference to self.
   */
  Stage& opengl(); // NOLINT(*)
  /*!
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
  bool is_scheduled() const;
  /*!
   * \brief Get attachment spec of current stage.
   *  If the stage is computed at group root, this function
   *  will traverse the group hierarchy to get the
   *  final spec from the group.
   * \return A stage representing the attach spec of the group.
   */
  Stage GetAttachSpec() const;
  // declare container type
  using ContainerType = StageNode;
};

/*!
 * \brief Global schedule container
 *  For operations and all the operations they depend on.
 *  The schedule per Operation is named as stage.
 */
class Schedule : public NodeRef {
 public:
  Schedule() {}
  explicit Schedule(NodePtr<Node> n) : NodeRef(n) {}
  /*!
   * \brief Get a copy of current schedule.
   * \return The copied schedule.
   */
  Schedule copy() const;
  /*!
   * \brief Get the stage that corresponds to the op
   * \param op The operation.
   */
  TVM_DLL Stage operator[](const Operation& op);
  /*!
   * \brief Short hand for getting the stage of tensor's operation.
   * \param tensor The tensor
   * \return The stage corresponding to the tensor's op
   */
  TVM_DLL Stage operator[](const Tensor& tensor) {
    return this->operator[](tensor->op);
  }
  /*!
   * \brief Create a new stage group for all intermediate
   *  operations between inputs and outputs.
   *
   * \param outputs The output boundary of the group.
   * \param inputs The input boundary of the group.
   * \param include_inputs Whether to include inputs if they are reachable from outputs.
   * \return The new grouped stage.
   */
  TVM_DLL Stage create_group(const Array<Tensor>& outputs,
                     const Array<Tensor>& inputs,
                     bool include_inputs = false);
  /*!
   * \brief Create a cache read of the original tensor for readers.
   *  This will mutate the body of the readers.
   *  A new stage will be created for the tensor.
   * \param tensor The tensor cached.
   * \param scope The scope of the cache.
   * \param readers The readers to redirect to the tensor.
   * \return The created tensor.
   */
  TVM_DLL Tensor cache_read(const Tensor& tensor,
                    const std::string& scope,
                    const Array<Operation>& readers);
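  //
  // Usage sketch (illustrative): cache tensor A in "shared" memory for the
  // reader op of tensor T:
  //
  //   Tensor AA = s.cache_read(A, "shared", {T->op});
  //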
  /*!
   * \brief Create a cache write tensor for producing tensor.
   *  The cache tensor will take over the body of the original tensor op.
   *
   *  This function can be used to do data layout transformation.
   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
   *  before cache_write is called, the intermediate cache stores
   *  the data in the layout given by the iteration order of the leaf axes.
   *  The data will be transformed back to the original layout in the original tensor.
   *  The user can further call compute_inline to inline the original layout and keep
   *  the data stored in the transformed layout.
   *
   * \param tensor The tensors to be produced.
   * \param scope The scope of the storage.
   * \return The created tensors.
   */
  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
  /*!
   * \brief Create a cache write tensor for producing tensor.
   *  The cache tensor will take over the body of the original tensor op.
   *
   *  This function can be used to do data layout transformation.
   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
   *  before cache_write is called, the intermediate cache stores
   *  the data in the layout given by the iteration order of the leaf axes.
   *  The data will be transformed back to the original layout in the original tensor.
   *  The user can further call compute_inline to inline the original layout and keep
   *  the data stored in the transformed layout.
   *
   * \param tensor The tensor to be produced.
   * \param scope The scope of the storage.
   * \return The created tensor.
   */
  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
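  //
  // Usage sketch (illustrative): stage the computation of T through a
  // "local" scratch buffer; TC and T can then be scheduled separately:
  //
  //   Tensor TC = s.cache_write(T, "local");
  //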
  /*!
   * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
   * This will create a new stage that generates the new tensor with axis
   * as the first dimension. The tensor's body will be rewritten as a reduction
   * over the factored tensor.
   *
   *  P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in Halide. CGO'17
   *
   * \param tensor The tensor to be factored.
   * \param axis The reduction axis in tensor's schedule to be factored.
   * \param factor_axis The position where the new axis is placed.
   * \return The created factored tensors.
   */
  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
                        const IterVar& axis,
                        int factor_axis = 0);
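  //
  // Usage sketch (illustrative): factor the reduction axis k of tensor R,
  // where k is the reduce IterVar used when R was defined, so that partial
  // results can be computed in parallel and then combined:
  //
  //   Array<Tensor> rf = s.rfactor(R, k);
  //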
  /*!
   * \brief Normalize the schedule.
   *  This is needed before bound inference.
   *  Insert necessary RebaseNode to make sure all leaf_iter_vars
   *  are in the form [0, extent)
   *
   * \return A normalized schedule, which can be the same as the current one.
   */
  Schedule normalize();
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const ScheduleNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline ScheduleNode* operator->();
  // declare container type
  using ContainerType = ScheduleNode;
};

/*!
 * \brief The schedule relation between IterVars
 *  can be Split, Fuse.
 */
class IterVarRelation : public NodeRef {
 public:
  IterVarRelation() {}
  explicit IterVarRelation(NodePtr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarRelationNode* operator->() const;
};

/*!
 * \brief Additional schedulable attributes about IterVar.
 */
class IterVarAttr : public NodeRef {
 public:
  IterVarAttr() {}
  explicit IterVarAttr(NodePtr<Node> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarAttrNode* operator->() const;
};

/*!
 * \brief Represents a stage.
 *
 *  The relations form a directed acyclic hypergraph in a bipartite manner.
 *  Each node is represented by an IterVar,
 *  and each hyper-edge is represented by an IterVarRelation.
 *  The relations connect the IterVars in the graph.
 *
 *  Besides typical stages that correspond to operations,
 *  there are also group stages, which group stages together.
 *  Each stage's group (given by the group field) represents a constraint:
 *  the stage can only be attached to stages within the group.
 *
 *  The group stage node can be attached to IterVars as a normal stage.
 */
class StageNode : public Node {
 public:
  /*!
   * \brief The operation of stage, can be different from original op.
   *  If it is null, then this stage is a group stage.
   */
  Operation op;
  /*!
   * \brief The original operator.
   *  The op field can change during schedule to alternate the dataflow,
   *  while origin_op remains fixed.
   */
  Operation origin_op;
  /*! \brief All the iteration variables of the stage. */
  Array<IterVar> all_iter_vars;
  /*! \brief The current active leaf iter vars in the stage. */
  Array<IterVar> leaf_iter_vars;
  /*!
   * \brief Specify threads to be launched at the stage.
   *  This is only valid for composite ops such as Scan.
   * \note Experimental primitive: used for thread persistence.
   */
  Array<IterVar> env_threads;
  /*!
   * \brief The predicate under which store can happen
   *  Use this when there can be duplicated threads doing the same store.
   * \note Experimental primitive: used by cross thread-reduction.
   */
  Expr store_predicate;
  /*! \brief The relations between IterVars */
  Array<IterVarRelation> relations;
  /*! \brief additional attributes about iter var. */
  Map<IterVar, IterVarAttr> iter_var_attrs;
  /*! \brief The attachment type of the schedule */
  AttachType attach_type{kGroupRoot};
  /*! \brief The attach point of this schedule. */
  IterVar attach_ivar;
  /*! \brief The stage this node attaches to */
  Stage attach_stage;
  /*! \brief The thread storage scope level of the stage */
  std::string scope;
  /*! \brief Whether this is an output stage */
  bool is_output{false};
  /*! \brief Whether this is an OpenGL stage */
  bool is_opengl{false};
  /*! \brief Whether apply double buffer optimization to this stage */
  bool double_buffer{false};
  /*!
   * \brief The parent group of the current stage.
   *  The stage cannot be assigned to stages outside the group.
   */
  Stage group;
  /*! \brief Number of direct child stages, only used for group stage.*/
  int num_child_stages{0};

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("op", &op);
    v->Visit("origin_op", &origin_op);
    v->Visit("all_iter_vars", &all_iter_vars);
    v->Visit("leaf_iter_vars", &leaf_iter_vars);
    v->Visit("env_threads", &env_threads);
    v->Visit("relations", &relations);
    v->Visit("iter_var_attrs", &iter_var_attrs);
    v->Visit("attach_type", &attach_type);
    v->Visit("attach_ivar", &attach_ivar);
    v->Visit("attach_stage", &attach_stage);
    v->Visit("scope", &scope);
    v->Visit("is_output", &is_output);
    v->Visit("is_opengl", &is_opengl);
    v->Visit("double_buffer", &double_buffer);
    v->Visit("group", &group);
    v->Visit("num_child_stages", &num_child_stages);
  }

  static constexpr const char* _type_key = "Stage";
  TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
};

/*! \brief node container for schedule */
class ScheduleNode : public Node {
 public:
  /*! \brief The output operations in original data flow graph */
  Array<Operation> outputs;
  /*!
   * \brief list of all stages for ops.
   *  The stages are sorted in dependency order.
   */
  Array<Stage> stages;
  /*!
   * \brief List of all stage groups.
   */
  Array<Stage> groups;
  /*! \brief map of original operation to the stages */
  Map<Operation, Stage> stage_map;
  /*!
   * \brief Internal stage map to map internal ops to stages.
   *  This is created on demand and can be invalidated.
   */
  std::unordered_map<const Node*, Stage> op2stage_cache_;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("outputs", &outputs);
    v->Visit("stages", &stages);
    v->Visit("groups", &groups);
    v->Visit("stage_map", &stage_map);
  }

  /*! \brief Initialize temp cache. */
  void InitCache();
  /*! \brief Invalidate temp cache. */
  void InvalidateCache();

  /*!
   * \brief Check if the schedule contains an Operation.
   * \param op The candidate Operation.
   * \return true if the schedule has the Operation. Otherwise, false.
   */
  TVM_DLL bool Contain(const Operation& op) const;

  /*!
   * \brief Check if the schedule contains a Tensor.
   * \param tensor The candidate tensor.
   * \return true if the schedule has the tensor. Otherwise, false.
   */
  TVM_DLL bool Contain(const Tensor& tensor) const {
    return Contain(tensor->op);
  }

  /*!
   * \brief Create a schedule for an array of ops (and their dependencies).
   * \param ops The ops to be scheduled.
   * \return sch The created Schedule.
   */
  TVM_DLL static Schedule make(Array<Operation> ops);

  static constexpr const char* _type_key = "Schedule";
  TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};

/*!
 * \brief Create a schedule for an array of ops (and their dependencies).
 * \param ops The ops to be scheduled.
 * \return sch The created Schedule.
 */
inline Schedule create_schedule(Array<Operation> ops) {
  return ScheduleNode::make(ops);
}
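
// A minimal end-to-end sketch (illustrative only): build a one-op schedule
// and split its axis. `placeholder`, `compute`, and `ComputeOpNode` are
// assumed from tensor.h/operation.h; only create_schedule and the Stage
// methods above come from this header.
//
//   Var m("m");
//   Tensor A = placeholder({m}, Float(32), "A");
//   Tensor B = compute({m}, [&](Var i) { return A(i) + 1.0f; }, "B");
//   Schedule s = create_schedule({B->op});
//   IterVar xo, xi;
//   s[B].split(B->op.as<ComputeOpNode>()->axis[0], 32, &xo, &xi);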

/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
 public:
  /*! \brief The iteration type. */
  IterVarType iter_type{kDataPar};
  /*! \brief The thread this iter Var binds, can be null */
  IterVar bind_thread;
  /*! \brief List of tensor to be prefetched in this loop */
  Array<Tensor> prefetch_data;
  /*! \brief The offset used in each prefetch */
  Array<Expr> prefetch_offset;
  /*!
   * \brief Tensor intrinsic used in tensorization,
   *   when the axis is marked as Tensorized
   */
  TensorIntrin tensor_intrin;
  /*! \brief Alignment factor of buffer dimension */
  int dim_align_factor{0};
  /*! \brief Alignment offset of buffer dimension */
  int dim_align_offset{0};
  /*!
   * \brief Additional pragma keys, array of StringImm
   */
  Array<Expr> pragma_keys;
  /*!
   * \brief Additional values of pragma, if any
   */
  Array<Expr> pragma_values;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("iter_type", &iter_type);
    v->Visit("bind_thread", &bind_thread);
    v->Visit("prefetch_data", &prefetch_data);
    v->Visit("prefetch_offset", &prefetch_offset);
    v->Visit("tensor_intrin", &tensor_intrin);
    v->Visit("dim_align_factor", &dim_align_factor);
    v->Visit("dim_align_offset", &dim_align_offset);
    v->Visit("pragma_keys", &pragma_keys);
    v->Visit("pragma_values", &pragma_values);
  }

  static constexpr const char* _type_key = "IterVarAttr";
  TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
};

/*! \brief base node of iteration variable relations */
class IterVarRelationNode : public Node {
 public:
  static constexpr const char* _type_key = "IterVarRelation";
  TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
};

/*!
 * \brief Split the parent domain into the product of
 *  outer and inner domains.
 */
class SplitNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The split factor */
  Expr factor;
  /*! \brief Number of parts, only factor or nparts can be given */
  Expr nparts;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("parent", &parent);
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("factor", &factor);
    v->Visit("nparts", &nparts);
  }

  static IterVarRelation make(IterVar parent,
                              IterVar outer,
                              IterVar inner,
                              Expr factor,
                              Expr nparts);

  static constexpr const char* _type_key = "Split";
  TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
};

/*!
 * \brief Fuse two domains into one domain.
 */
class FuseNode : public IterVarRelationNode {
 public:
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The target domain */
  IterVar fused;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("fused", &fused);
  }

  static IterVarRelation make(
      IterVar outer, IterVar inner, IterVar fused);

  static constexpr const char* _type_key = "Fuse";
  TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
};

/*!
 * \brief Rebase the iteration to make the min 0.
 *  This is useful to normalize the Schedule
 *  so that every leaf variable's min is 0.
 */
class RebaseNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The inner domain */
  IterVar rebased;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("parent", &parent);
    v->Visit("rebased", &rebased);
  }

  static IterVarRelation make(IterVar parent, IterVar rebased);

  static constexpr const char* _type_key = "Rebase";
  TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
};


/*!
 * \brief Singleton iterator [0, 1)
 */
class SingletonNode : public IterVarRelationNode {
 public:
  /*! \brief The singleton iterator */
  IterVar iter;

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("iter", &iter);
  }

  static IterVarRelation make(IterVar iter);

  static constexpr const char* _type_key = "Singleton";
  TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
};


// implementations
inline const StageNode* Stage::operator->() const {
  return static_cast<const StageNode*>(node_.get());
}
inline StageNode* Stage::operator->() {
  return static_cast<StageNode*>(node_.get());
}

inline const ScheduleNode* Schedule::operator->() const {
  return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
  return static_cast<ScheduleNode*>(node_.get());
}

inline const IterVarRelationNode* IterVarRelation::operator->() const {
  return static_cast<const IterVarRelationNode*>(node_.get());
}

inline const IterVarAttrNode* IterVarAttr::operator->() const {
  return static_cast<const IterVarAttrNode*>(node_.get());
}
}  // namespace tvm
#endif  // TVM_SCHEDULE_H_