/*!
 *  Copyright (c) 2017 by Contributors
 * \file schedule_dataflow_rewrite.cc
 */
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./message_passing.h"
#include "../pass/ir_util.h"

namespace tvm {

// Find the first occurrence location in the leaf array.
template<typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Node* n = v.get();
  for (size_t i = 0; i < array_node->data.size(); ++i) {
    if (array_node->data[i].get() == n) return i;
  }
  return array_node->data.size();
}

// Replaces variables in an expression according to the given substitution map.
class VarReplacer : public ir::IRMutator {
 public:
  explicit VarReplacer(
      const std::unordered_map<const Variable*, Expr>& vsub)
      : vsub_(vsub) {}
  Expr Mutate_(const Variable* op, const Expr& e) {
    auto it = vsub_.find(op);
    if (it != vsub_.end()) return it->second;
    return e;
  }

 private:
  const std::unordered_map<const Variable*, Expr>& vsub_;
};
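
// Illustrative sketch (not from the original file; `x` and `y` are
// hypothetical): substituting one variable for another in an expression.
//
//   Var x("x"), y("y");
//   std::unordered_map<const Variable*, Expr> smap;
//   smap[x.get()] = y;
//   Expr out = VarReplacer(smap).Mutate(x + 1);  // yields y + 1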

// Replace the dataflow in all stages given the tensor replacement map.
// Also updates vmap when subsequent dataflow needs to be replaced.
void ReplaceDataFlow(const Array<Stage>& stages,
                     std::unordered_map<Tensor, Tensor>* vmap) {
  for (Stage s : stages) {
    Operation op = s->op->ReplaceInputs(s->op, *vmap);
    if (!op.same_as(s->op)) {
      for (int i = 0; i < op->num_outputs(); ++i) {
        (*vmap)[s->op.output(i)] = op.output(i);
      }
      s->op = op;
    }
  }
}
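
// Illustrative note: if *vmap initially maps T -> T', then any stage reading
// T is rebuilt to read T', and that stage's own old->new outputs are added to
// *vmap, so later stages in the list are rewritten in the same pass.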

Tensor Schedule::cache_read(const Tensor& tensor,
                            const std::string& scope,
                            const Array<Operation>& readers) {
  (*this)->InvalidateCache();
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }
  os << "." << scope;

  Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
      return tensor(Array<Expr>(i.begin(), i.end()));
    }, os.str());
  std::unordered_map<Tensor, Tensor> vsub;
  vsub[tensor] = cache;

  std::unordered_map<Tensor, Tensor> vmap;
  for (Operation op : readers) {
    Stage s = operator[](op);
    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
    CHECK(!repl_op.same_as(s->op))
        << "Cannot find " << tensor
        << " in the inputs of " << s->op;
    vmap[s->op.output(0)] = repl_op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  Stage op_stage = operator[](tensor->op);
  size_t pos = FindNodeRef(stages, op_stage);
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos + 1,
                      cache_stage.node_);
  (*this)->stage_map.Set(cache->op, cache_stage);
  // Update group
  cache_stage->group = op_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache;
}
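
// Usage sketch for cache_read (hypothetical; assumes the usual placeholder/
// compute/create_schedule helpers, and names n, A, B, s, AA are illustrative):
//
//   Var n("n");
//   Tensor A = placeholder({n}, Float(32), "A");
//   Tensor B = compute({n}, [&](Var i) { return A(i) + 1.0f; }, "B");
//   Schedule s = create_schedule({B->op});
//   // Stage a copy of A in shared memory for its reader B.
//   Tensor AA = s.cache_read(A, "shared", {B->op});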

Tensor Schedule::cache_write(const Tensor& tensor,
                             const std::string& scope) {
  (*this)->InvalidateCache();
  Stage orig_stage = operator[](tensor->op);
  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
  CHECK(compute)
      << "cache write only take ComputeOp as writers";
  CHECK_EQ(orig_stage->relations.size(), 0U)
      << "Create cache_write before doing split/fuse/reorder";
  compute = orig_stage->op.as<ComputeOpNode>();
  CHECK(compute);
  Array<Expr> args;
  Array<IterVar> new_axis;
  std::unordered_map<const Variable*, Expr> vsub;
  for (IterVar iv : compute->axis) {
    args.push_back(iv->var);
    IterVar new_iv = IterVarNode::make(
        iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
    vsub[iv->var.get()] = new_iv->var;
  }
  VarReplacer repl(vsub);
  Expr body = repl.Mutate(compute->body[tensor->value_index]);
  Operation cache_op = ComputeOpNode::make(
      compute->name + "." + scope, compute->tag, new_axis, {body});
  Tensor cache_tensor = cache_op.output(0);
  Operation orig_new_op = ComputeOpNode::make(
      compute->name, compute->tag, compute->axis,
      {cache_tensor(args)});

  std::unordered_map<Tensor, Tensor> vmap;
  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
  ReplaceDataFlow((*this)->stages, &vmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
  // create schedule for new cached stage.
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos,
                      cache_stage.node_);
  (*this)->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache_tensor;
}
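
// Usage sketch for cache_write (hypothetical; per the CHECK above it must be
// called before any split/fuse/reorder on C's stage):
//
//   Tensor C = compute({n}, [&](Var i) { return A(i) * 2.0f; }, "C");
//   Schedule s = create_schedule({C->op});
//   // Compute C into a "local" buffer first; the original stage becomes a
//   // copy-out from the cache.
//   Tensor CL = s.cache_write(C, "local");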

void RebaseNonZeroMinLoop(const Schedule& sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
    if (s->attach_type == kInlinedAlready) continue;

    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindNodeRef(leaf_vars, iv);
      auto it = s->iter_var_attrs.find(iv);
      // no need to rebase paths that are bound to threads.
      if (it != s->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        continue;
      }
      if (idx < leaf_vars->data.size()) {
        // insert rebase
        IterVar rebased = IterVarNode::make(
            Range(), iv->var.copy_with_suffix(""), iv->iter_type);
        s->relations.push_back(RebaseNode::make(iv, rebased));
        if (s->iter_var_attrs.count(iv)) {
          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
        }
        leaf_vars->data[idx] = rebased.node_;
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}
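
// Illustrative example: a leaf IterVar with domain [2, n) receives a Rebase
// relation to a fresh IterVar over [0, n-2); bound inference later recovers
// the original index as rebased + 2.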

void InjectInline(ScheduleNode* sch) {
  sch->InvalidateCache();

  std::vector<Array<Expr>> new_body(sch->stages.size());
  std::vector<bool> changed(sch->stages.size(), false);
  // inline all the ops
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage stage = sch->stages[i - 1];
    if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
      Array<Expr> body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
        CHECK(compute)
            << "can only inline compute op";
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
        body = compute->body;
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
        if (compute) {
          if (!new_body[j].size()) {
            new_body[j] = s->op.as<ComputeOpNode>()->body;
          }
          for (size_t k = 0; k < body.size(); ++k) {
            changed[j] = true;
            new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
                            stage->op, args, body[k]).as<ir::Evaluate>()->value);
          }
        }
      }
    }
  }
  std::unordered_map<Tensor, Tensor> repl;
  // rewrite dataflow
  for (size_t i = 0; i < sch->stages.size(); ++i) {
    Stage s = sch->stages[i];
    if (s->attach_type == kInlinedAlready) continue;
    if (new_body[i].size()) {
      // Logic mirrors ReplaceDataFlow.
      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
      CHECK(compute);
      Operation op = s->op;
      if (changed[i]) {
        op = ComputeOpNode::make(
            compute->name, compute->tag, compute->axis, new_body[i]);
      }
      op = op->ReplaceInputs(op, repl);
      if (!op.same_as(s->op)) {
        // Record all old->new output mappings before swapping in the new op;
        // assigning s->op inside the loop would make s->op.output(idx) refer
        // to the new op for idx >= 1.
        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
          repl[s->op.output(idx)] = op.output(idx);
        }
        s->op = op;
      }
    } else {
      Operation op = s->op->ReplaceInputs(s->op, repl);
      if (!op.same_as(s->op)) {
        for (int j = 0; j < op->num_outputs(); ++j) {
          repl[s->op.output(j)] = op.output(j);
        }
        s->op = op;
      }
    }
  }
}
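
// Illustrative example: with B = compute(..., A(i) + 1) marked kInline and
// C = compute(..., B(i) * 2), this pass rewrites C's body to (A(i) + 1) * 2
// and marks B's stage kInlinedAlready so it is skipped downstream.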

Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->());
  RebaseNonZeroMinLoop(sn);
  return sn;
}
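
// Usage sketch (hypothetical): normalize is applied to a scheduled graph
// before lowering, e.g.
//
//   Schedule sn = s.normalize();
//   // sn is a copy with inlined stages folded in and every leaf loop
//   // rebased to start at 0.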

// Handle reduction factor.
Array<Tensor> Schedule::rfactor(const Tensor& tensor,
                                const IterVar& axis) {
  (*this)->InvalidateCache();
  using ir::Reduce;
  CHECK_EQ(axis->iter_type, kCommReduce)
      << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  CHECK(compute_op) << "Can only factor ComputeOp";
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    CHECK_NE(axis_pos, leaf_vars->data.size())
        << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // Verify that normal axes are not touched.
  for (IterVar iv : compute_op->axis) {
    CHECK(!touch_map.count(iv))
        << "Factor axis touches normal axis.";
  }
  // Get the replacement index maps.
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, Expr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) dom_map[iv] = iv->dom;
  }
  schedule::PassDownDomain(reduce_stage, &dom_map, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  // Get the factored op node.
  auto n = std::make_shared<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // axis replacement.
    auto iv_node = std::make_shared<IterVarNode>();
    iv_node->dom = dom_map.at(axis);
    CHECK(is_zero(iv_node->dom->min))
        << "Can only factor reduction domain starting from 0";
    iv_node->var = axis->var;
    iv_node->iter_type = kDataPar;
    n->axis.push_back(IterVar(iv_node));

    for (IterVar iv : compute_op->axis) {
      n->axis.push_back(iv);
    }
  }
  // Predicate generation; copy untouched axes.
  int idx = tensor->value_index;
  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
  CHECK(reduce) << "Can only rfactor non-inline reductions";
  Expr predicate = reduce->condition;
  std::unordered_map<const Variable*, Expr> vsub;
  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      CHECK(value_map.count(iv));
      Expr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
      if (!index.same_as(iv->var)) {
        Expr cond = (index < dom_map.at(iv)->extent);
        if (is_one(predicate)) {
          predicate = cond;
        } else {
          predicate = predicate && cond;
        }
      }
    }
  }
  // Copy touched axes.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      CHECK_EQ(iv->iter_type, kCommReduce);
      auto ncpy = std::make_shared<IterVarNode>(*iv.operator->());
      ncpy->dom = dom_map.at(iv);
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
  VarReplacer replacer(vsub);
  Array<Expr> new_source = ir::UpdateArray(reduce->source,
    [&replacer] (const Expr& e) { return replacer.Mutate(e); });
  std::vector<Expr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce::make(reduce->combiner,
                                   new_source,
                                   n->reduce_axis,
                                   predicate,
                                   idx));
  }
  n->body = Array<Expr>(body);
  // refresh relations, keep the un-touched relations.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  Operation factor_op(n);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t stage_pos = FindNodeRef(stages, reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  CHECK_LT(stage_pos, stages->data.size());
  stages->data.insert(stages->data.begin() + stage_pos,
                      factor_stage.node_);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  IterVar repl_red_axis = reduce_axis(
      dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
    [&](const Array<Var>& i) {
      Array<Expr> indices;
      indices.push_back(repl_red_axis->var);
      for (Var v : i) {
        indices.push_back(v);
      }
      Array<Expr> factor_exprs;
      for (int idx = 0; idx < size; ++idx) {
        factor_exprs.push_back(factor_tensors[idx](indices));
      }
      Array<Expr> reductions;
      Array<IterVar> axis = {repl_red_axis};
      Expr cond = const_true();
      for (int idx = 0; idx < size; ++idx) {
        reductions.push_back(Reduce::make(reduce->combiner,
          factor_exprs, axis, cond, idx));
      }
      return reductions;
    }, reduce_stage->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap);
  // revamp the reduction stage.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}
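
// Usage sketch for rfactor (hypothetical; assumes the usual sum/compute
// helpers, and the split signature shown is an assumption):
//
//   Var n("n"), m("m");
//   Tensor A = placeholder({m, n}, Float(32), "A");
//   IterVar k = reduce_axis(Range(0, n), "k");
//   Tensor B = compute({m}, [&](Var i) { return sum(A(i, k), {k}); }, "B");
//   Schedule s = create_schedule({B->op});
//   IterVar ko, ki;
//   s[B].split(k, 16, &ko, &ki);
//   // BF computes per-ki partial sums; B's stage is rewritten to reduce BF
//   // over a fresh reduction axis derived from ki (suffix ".v").
//   Array<Tensor> BF = s.rfactor(B, ki);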
}  // namespace tvm