/*!
 *  Copyright (c) 2016 by Contributors
 * \file schedule.cc
 */
#include <tvm/schedule.h>

namespace tvm {

namespace {

// find first occurance location in leaf
size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
  const Node* n = v.get();
  for (size_t i = 0; i < array_node->data.size(); ++i) {
    if (array_node->data[i].get() == n) return i;
  }
  return array_node->data.size();
}

tqchen committed
20 21 22 23 24 25 26 27 28 29 30 31
size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
  size_t pos = FindIterVar(leaf_vars, v);
  if (pos < leaf_vars->data.size()) return pos;

  if (FindIterVar(all_vars, v) < all_vars->data.size()) {
    LOG(FATAL) << "Operate on iter var " << v
               << "that has already been splitted";
  } else {
    LOG(FATAL) << "Operate on iter var " << v
               << "that is not part of the schedule";
  }
  return 0;
tqchen committed
32 33
}

tqchen committed
34 35 36 37 38 39 40 41 42 43 44 45 46 47
void Split(ScheduleNode* self, IterVar parent,
           IterVar outer, IterVar inner, Expr factor) {
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);

  self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
  // add vars to all vars
  all_vars->data.push_back(outer.node_);
  all_vars->data.push_back(inner.node_);
  // replace the position.
  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
  leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner.node_);
  leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer.node_);
tqchen committed
48 49
}

tqchen committed
50 51
}  // namespace

// Create a schedule for `op` stored in `scope`.
//
// Initially nothing has been split or fused, so both the full set of
// iteration variables and the leaf set are the operation's root iteration
// variables.
Schedule::Schedule(Operation op, std::string scope) {
  auto n = std::make_shared<ScheduleNode>();
  n->op = op;
  n->scope = scope;
  n->all_iter_vars = op->root_iter_vars();
  n->leaf_iter_vars = op->root_iter_vars();
  node_ = n;
}

// Attach this schedule to `parent` at loop level `scope`.
//
// Preconditions, both checked:
//   - this schedule is not attached yet (attach_type == kNone);
//   - `scope` is one of `parent`'s current leaf iteration variables.
// Every precondition is checked before any field of this schedule is written.
// A failed check therefore cannot leave it half-attached (attach_type set but
// not registered in parent->children).
Schedule& Schedule::compute_at(Schedule parent, IterVar scope) {   // NOLINT(*)
  CHECK_EQ((*this)->attach_type, kNone);
  bool found = false;
  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
    if (scope == parent->leaf_iter_vars[i]) {
      found = true;
      break;
    }
  }
  CHECK(found)
      << "Cannot compute at a iteration variable that is not part of parent leaf vars";

  (*this)->attach_type = kScope;
  (*this)->attach_parent = scope;
  parent->children.push_back(*this);
  return *this;
}

// Mark this schedule as inlined and register it as a child of `parent`.
// Checks that the schedule has not been attached anywhere yet.
Schedule& Schedule::compute_inline(Schedule parent) {   // NOLINT(*)
  ScheduleNode* self = operator->();
  CHECK_EQ(self->attach_type, kNone);
  self->attach_type = kInline;
  parent->children.push_back(*this);
  return *this;
}

// Mark this schedule as computed at root level and register it as a child
// of `parent`. Checks that the schedule has not been attached anywhere yet.
Schedule& Schedule::compute_root(Schedule parent) {   // NOLINT(*)
  ScheduleNode* self = operator->();
  CHECK_EQ(self->attach_type, kNone);
  self->attach_type = kRoot;
  parent->children.push_back(*this);
  return *this;
}

// Split the leaf var `parent` by `factor` into two freshly created vars,
// named "<parent>.outer" and "<parent>.inner".
// The new vars are written to *p_outer / *p_inner only after the split has
// been applied, so a failed leaf check does not publish unused placeholders.
Schedule& Schedule::split(
    IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
  // Placeholders for the split results; their ranges are left unset here.
  IterVar outer(Range(), parent->var->name_hint + ".outer");
  IterVar inner(Range(), parent->var->name_hint + ".inner");
  Split(operator->(), parent, outer, inner, factor);
  *p_outer = outer;
  *p_inner = inner;
  return *this;
}

// Split the leaf var `parent` by `factor`, using the caller-supplied `outer`
// and a freshly created inner var named "<parent>.inner".
// *p_inner is written only after the split has been applied.
Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
  // Placeholder for the inner split result; its range is left unset here.
  IterVar inner(Range(), parent->var->name_hint + ".inner");
  Split(operator->(), parent, outer, inner, factor);
  *p_inner = inner;
  return *this;
}

// Fuse two adjacent leaf vars into a single var named
// "<outer>.<inner>.fused" and return it through *p_target.
//
// `inner` must sit immediately after `outer` in the leaf list. In the leaf
// list, the pair [outer, inner] is replaced by the single fused var. The
// fused var is also added to all_iter_vars, and a FuseNode relation is
// recorded.
// Both vars are validated before anything is recorded. A failed check
// therefore leaves no dangling relation or fused var in the schedule.
Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT(*)
  ScheduleNode* self = operator->();
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();

  size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
  size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
  CHECK_EQ(pos_inner, pos_outer + 1)
      << "Can only fuse iterations that are consecutive between each other";

  IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
  // FuseNode::make takes its arguments in (outer, inner, fused) order.
  self->relations.push_back(FuseNode::make(outer, inner, fused));
  all_vars->data.push_back(fused.node_);

  // Replace both outer and inner: erase the half-open range
  // [pos_outer, pos_inner + 1), then put fused where outer was.
  leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
                        leaf_vars->data.begin() + pos_inner + 1);
  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
                         fused.node_);
  *p_target = fused;
  return *this;
}

// Reorder a subset of the leaf vars.
//
// The slots currently occupied by the vars in `order` are refilled, in
// ascending slot order, with those vars in the order given. Leaves not
// mentioned in `order` keep their positions. Every var must be a current
// leaf and may appear at most once. Both conditions are checked before any
// slot is written.
Schedule& Schedule::reorder(const Array<IterVar>& order) {  // NOLINT(*)
  ScheduleNode* self = operator->();
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();

  // Leaf slot of each requested var, in the caller's order.
  std::vector<size_t> pos;
  pos.reserve(order.size());
  for (size_t i = 0; i < order.size(); ++i) {
    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
  }
  // Snapshot the nodes in the requested order before slots are overwritten.
  std::vector<std::shared_ptr<Node> > temp;
  temp.reserve(pos.size());
  for (size_t i = 0; i < pos.size(); ++i) {
    temp.emplace_back(leaf_vars->data[pos[i]]);
  }
  std::sort(pos.begin(), pos.end());
  // Suppose a var were listed twice. Two writes would then target its slot,
  // and the extra write would overwrite another leaf, which would be lost.
  for (size_t i = 1; i < pos.size(); ++i) {
    CHECK(pos[i - 1] != pos[i])
        << "Same iteration variable can not appear more than once in reorder";
  }
  for (size_t i = 0; i < pos.size(); ++i) {
    leaf_vars->data[pos[i]] = temp[i];
  }
  return *this;
}

// Tile the 2-D loop nest (x_parent, y_parent).
//
// Each parent is split by its factor. The four resulting vars are returned
// through the out-params and then reordered so that the leaf slots they
// occupy read x_outer, y_outer, x_inner, y_inner.
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
                         IterVar* p_x_outer, IterVar* p_y_outer,
                         IterVar* p_x_inner, IterVar* p_y_inner,
                         Expr x_factor, Expr y_factor) { // NOLINT(*)
  split(x_parent, p_x_outer, p_x_inner, x_factor);
  split(y_parent, p_y_outer, p_y_inner, y_factor);
  reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
  return *this;
}

// Build a relation recording that `parent` was split by `factor` into
// `outer` and `inner`.
IterVarRelation SplitNode::make(
    IterVar parent, IterVar outer,
    IterVar inner, Expr factor) {
  std::shared_ptr<SplitNode> node = std::make_shared<SplitNode>();
  node->factor = factor;
  node->inner = inner;
  node->outer = outer;
  node->parent = parent;
  return IterVarRelation(node);
}

// Build a relation recording that `outer` and `inner` were fused into
// `fused`. Argument order is (outer, inner, fused).
IterVarRelation FuseNode::make(
    IterVar outer, IterVar inner, IterVar fused) {
  std::shared_ptr<FuseNode> node = std::make_shared<FuseNode>();
  node->fused = fused;
  node->inner = inner;
  node->outer = outer;
  return IterVarRelation(node);
}

// Node-type registrations for the schedule-related node classes used above.
TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);

}  // namespace tvm