/*!
 *  Copyright (c) 2017 by Contributors
 * \file schedule_dataflow_rewrite.cc
 */
#include <tvm/schedule.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "message_passing.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {

// find the first occurrence location in leaf
template<typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Node* n = v.get();
  for (size_t i = 0; i < array_node->data.size(); ++i) {
    if (array_node->data[i].get() == n) return i;
  }
  return array_node->data.size();
}

// Variable replacer used by the cache and rfactor rewrites.
class VarReplacer : public ir::IRMutator {
 public:
  explicit VarReplacer(
      const std::unordered_map<const Variable*, Expr>& vsub)
      : vsub_(vsub) {}
  Expr Mutate_(const Variable* op, const Expr& e) {
    auto it = vsub_.find(op);
    if (it != vsub_.end()) return it->second;
    return e;
  }

  ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
    // Replace free variables in combiner
    auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) {
      return this->Mutate(e);
      });
    auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) {
      return this->Mutate(e);
      });

    if (combiner->identity_element.same_as(new_identity) &&
        combiner->result.same_as(new_result)) {
      return combiner;
    } else {
      return ir::CommReducerNode::make(
        combiner->lhs, combiner->rhs, new_result, new_identity);
    }
  }

  Expr Mutate_(const ir::Reduce* op, const Expr& e) {
    Expr new_e = IRMutator::Mutate_(op, e);
    const ir::Reduce* new_reduce = new_e.as<ir::Reduce>();
    ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
    if (op->combiner.same_as(new_combiner)) {
      return new_e;
    } else {
      return ir::Reduce::make(
        new_combiner,
        new_reduce->source,
        new_reduce->axis,
        new_reduce->condition,
        new_reduce->value_index);
    }
  }

 private:
  const std::unordered_map<const Variable*, Expr>& vsub_;
};

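// Fold a list of predicates into `body`. Illustrative sketch of the behavior
// (hypothetical exprs): InjectPredicate({i < n}, A[i] + 1) yields
//   Select(i < n, A[i] + 1, 0)
// while a Reduce body instead gets `i < n` AND-ed into its condition.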
Expr InjectPredicate(const Array<Expr>& predicates,
                     Expr body) {
  using ir::Reduce;
  using ir::Select;
  if (predicates.size() == 0) return body;
  const Reduce* reduce = body.as<Reduce>();
  if (reduce) {
    auto n = make_node<Reduce>(*reduce);
    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
    return Expr(n);
  }
  return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
                      body,
                      make_zero(body.type()));
}

// Replace the dataflow of all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// Keep the transitive closure property of vmap up to date via a reverse map.
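// Worked example (hypothetical tensors): if vmap already maps A -> B and this
// pass now rewrites B's op into C, the reverse entry rvmap[B] = A lets us
// update the map in place to A -> C instead of leaving a stale chain
// A -> B -> C.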
void ReplaceDataFlow(const Array<Stage>& stages,
                     std::unordered_map<Tensor, Tensor>* vmap,
                     std::unordered_map<Tensor, Tensor>* rvmap) {
  for (Stage s : stages) {
    Operation op = s->op->ReplaceInputs(s->op, *vmap);
    if (!op.same_as(s->op)) {
      for (int i = 0; i < op->num_outputs(); ++i) {
        auto it = rvmap->find(s->op.output(i));
        if (it != rvmap->end()) {
          (*vmap)[it->second] = op.output(i);
        } else {
          (*vmap)[s->op.output(i)] = op.output(i);
          (*rvmap)[op.output(i)] = s->op.output(i);
        }
      }
      s->op = op;
    }
  }
}

inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  return (a->combiner.same_as(b->combiner)) &&
         (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) &&
         (a->condition.same_as(b->condition));
}

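// Create a cache stage that reads `tensor` in memory of `scope` on behalf of
// `readers`. Hypothetical usage sketch (names invented for illustration):
//   Tensor AA = sch.cache_read(A, "shared", {B->op});
// adds a stage A.shared that copies A, and rewires B's op to read from it.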
Tensor Schedule::cache_read(const Tensor& tensor,
                            const std::string& scope,
                            const Array<Operation>& readers) {
  (*this)->InvalidateCache();
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }
  os << "." << scope;

  std::unordered_map<Tensor, Tensor> vsub;
  Stage s = operator[](tensor->op);
  Tensor sugar_tensor = s->op.output(tensor->value_index);
  Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
      return sugar_tensor(Array<Expr>(i.begin(), i.end()));
    }, os.str());
  vsub[sugar_tensor] = cache;

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (Operation op : readers) {
    Stage s = operator[](op);
    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
    CHECK(!repl_op.same_as(s->op))
        << "Cannot find " << tensor
        << " in the inputs of " << s->op;
    vmap[s->op.output(0)] = repl_op.output(0);
    rvmap[repl_op.output(0)] = s->op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  Stage op_stage = operator[](tensor->op);
  size_t pos = FindNodeRef(stages, op_stage);
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos + 1,
                      cache_stage.node_);
  (*this)->stage_map.Set(cache->op, cache_stage);
  // Update group
  cache_stage->group = op_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache;
}

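// Shared setup for cache_write: collect the reduction axes, rebuild the
// stage's data-parallel leaf iteration as fresh ".c" axes, and compute the
// variable substitutions and bound-check predicates needed to relayout the
// body.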
template<typename OpType>
void PrepareAxisMapping(Stage orig_stage,
                        OpType* op,
                        std::unordered_set<IterVar>* p_red_axis,
                        Array<IterVar>* p_new_axis,
                        std::unordered_map<IterVar, Range>* p_dom_map,
                        std::unordered_map<const Variable*, Expr>* p_vsub,
                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
                        std::vector<Expr>* p_predicates) {
  auto& red_axis = *p_red_axis;
  auto& new_axis = *p_new_axis;
  auto& dom_map = *p_dom_map;
  auto& vsub = *p_vsub;
  auto& vsub2newvar = *p_vsub2newvar;
  auto& predicates = *p_predicates;

  for (IterVar iv : op->reduce_axis) {
    red_axis.insert(iv);
  }
  for (IterVar iv : op->axis) {
    dom_map[iv] = iv->dom;
  }
  schedule::PassDownDomain(orig_stage, &dom_map, true);
  {
    // The source->cache
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      CHECK_EQ(iv->iter_type, kDataPar)
          << "Can only relayout within data parallel dimensions";
      Range dom = dom_map.at(iv);
      IterVar new_iv = IterVarNode::make(
          dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
      new_axis.push_back(new_iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
        vsub2newvar[iv->var.get()] = new_iv->var;
      }
    }
    // skip reduction iteration.
    std::unordered_set<IterVar> skip_bound_check;
    for (IterVar iv : op->reduce_axis) {
      skip_bound_check.insert(iv);
    }
    schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
    predicates = schedule::MakeBoundCheck(
        orig_stage, dom_map, value_map, true, skip_bound_check);
    // The root axis
    for (IterVar iv : op->axis) {
      if (value_map.count(iv)) {
        vsub[iv->var.get()] = value_map.at(iv);
      }  // to handle tensor axis
    }
  }
}

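// Swap the rewritten op into the original stage and splice the cache stage
// into the stage list just before it, keeping group bookkeeping consistent.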
Array<Tensor> ReplaceOriginalOp(Schedule sch,
                                Stage orig_stage,
                                const std::string& scope,
                                Operation cache_op,
                                Operation orig_new_op,
                                size_t tensor_size) {
  Array<Tensor> cache_tensor_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_tensor_list.push_back(cache_tensor);
  }
  // Replace the dataflow, one entry per output.
  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (size_t i = 0; i < tensor_size; i++) {
    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
  }
  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
  orig_stage->relations = Array<IterVarRelation>();
  // create schedule for new cached stage.
  ArrayNode* stages = sch->stages.CopyOnWrite();
  size_t pos = FindNodeRef(stages, orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages->data.size());
  stages->data.insert(stages->data.begin() + pos,
                      cache_stage.node_);
  sch->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache_tensor_list;
}


// Cache write and relayout the data according to loop pattern
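// Illustrative sketch (hypothetical shapes): for C = compute({m, n}, f) and
// scope "local", this builds C.local over the relaid-out ".c" axes carrying
// the real computation, and turns the original C into a plain copy of C.local.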
Array<Tensor> CacheWriteWithReLayout(Schedule sch,
                                     const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const Variable*, Expr> vsub;
  std::unordered_map<const Variable*, Expr> vsub2newvar;
  std::vector<Expr> predicates;

  PrepareAxisMapping(orig_stage, compute,
    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);

  Expr body;
  Array<Expr> body_list;
  const ir::Reduce* first_reduce = nullptr;
  for (auto cbody : compute->body) {
    body = VarReplacer(vsub).Mutate(cbody);
    body = InjectPredicate(predicates, body);
    body = VarReplacer(vsub2newvar).Mutate(body);
    // Reduce nodes in ONE ComputeOp must be the same except for value_index.
    // This holds only if the original body keeps the Reduce nodes identical.
    if (body->is_type<ir::Reduce>()) {
      const ir::Reduce* reduce_body = body.as<ir::Reduce>();
      if (first_reduce != nullptr) {
        CHECK(ReduceEqual(reduce_body, first_reduce));
        body = ir::Reduce::make(first_reduce->combiner,
                                first_reduce->source,
                                first_reduce->axis,
                                first_reduce->condition,
                                reduce_body->value_index);
      } else {
        first_reduce = reduce_body;
      }
    } else {
      CHECK(first_reduce == nullptr)
        << "cannot mix Reduce and other nodes in ONE compute body";
    }
    body_list.push_back(body);
  }
  // The reader args
  Array<Expr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : compute->axis) {
      value_map[iv] = iv->var;
    }
    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
  }
  Operation cache_op = ComputeOpNode::make(
      compute->name + "." + scope, compute->tag, compute->attrs,
      new_axis, body_list);

  Array<Expr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op = ComputeOpNode::make(
      compute->name, compute->tag, compute->attrs,
      compute->axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope,
    cache_op, orig_new_op, tensor_size);
}


// for tensor compute op
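// Same idea as CacheWriteWithReLayout, specialized to TensorComputeOp: the
// tensorized axes beyond schedulable_ndim are copied through unchanged and the
// intrinsic's input regions are rewritten with the substituted loop variables.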
Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
                                           const Array<Tensor>& tensor_array,
                                           const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
  CHECK_EQ(tensor_op->num_outputs(), 1)
      << "cache write only support single output tensor_compute_op";

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const Variable*, Expr> vsub;
  std::unordered_map<const Variable*, Expr> vsub2newvar;
  std::vector<Expr> predicates;

  PrepareAxisMapping(orig_stage, tensor_op,
    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);


  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar new_iv = IterVarNode::make(
      iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
  }
  Array<Region> new_regions;
  for (Region old_region : tensor_op->input_regions) {
    Region region;
    for (Range r : old_region) {
      Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
      Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
      region.push_back(Range::make_by_min_extent(min, extent));
    }
    new_regions.push_back(region);
  }

  Operation cache_op = TensorComputeOpNode::make(
      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
      tensor_op->intrin, tensor_op->inputs, new_regions);

  // axis will be used in generating compute op
  Array<IterVar> compute_axis = tensor_op->axis;
  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
    compute_axis.Set(i, aiv);
  }

  // The reader args
  Array<Expr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, Expr> value_map;
    for (IterVar iv : compute_axis) {
      value_map[iv] = iv->var;
    }
    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
    // tensorized region axis
    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
      IterVar iv = compute_axis[i];
      args.push_back(value_map.at(iv));
    }
  }

  Array<Expr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op = ComputeOpNode::make(
      tensor_op->name, tensor_op->tag, {},
      compute_axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope,
    cache_op, orig_new_op, tensor_size);
}

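// Hypothetical usage sketch: for two outputs {C0, C1} of one ComputeOp,
//   Array<Tensor> CC = sch.cache_write({C0, C1}, "local");
// computes both results inside a single C.local stage and leaves the original
// stage as a copy out of the cache.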
Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
                             const std::string& scope) {
  (*this)->InvalidateCache();
  CHECK(tensor_array.size() > 0)
      << "size of tensor_array must be greater than 0";
  Tensor tensor = tensor_array[0];
  Stage orig_stage = operator[](tensor->op);
  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
      << "size of input tensor list must be same as number of stage outputs";
  for (size_t i = 1; i < tensor_array.size(); i++) {
    Stage tmp_stage = operator[](tensor_array[i]->op);
    CHECK(orig_stage.same_as(tmp_stage))
        << "Input tensor list must be generated by ONE computeOp";
  }
  return CacheWriteWithReLayout(*this, tensor_array, scope);
}

Tensor Schedule::cache_write(const Tensor& tensor,
                             const std::string& scope) {
  // support both original compute op and tensor compute op
  (*this)->InvalidateCache();
  const char* type_key = tensor->op->type_key();
  if (!strcmp(type_key, "ComputeOp")) {
    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
  } else if (!strcmp(type_key, "TensorComputeOp")) {
    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
  } else {
    LOG(FATAL) << "cache write only takes ComputeOp or TensorComputeOp as writers";
    return Tensor();
  }
}

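// Rebase every unbound leaf iteration variable so its loop starts from 0,
// by inserting a Rebase relation and swapping the rebased IterVar into the
// leaf list; attach points are then remapped to the rebased variables.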
void RebaseNonZeroMinLoop(const Schedule& sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
    if (s->attach_type == kInlinedAlready) continue;

    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindNodeRef(leaf_vars, iv);
      auto it = s->iter_var_attrs.find(iv);
      // no need to rebase iteration variables that are bound to threads.
      if (it != s->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        continue;
      }
      if (idx < leaf_vars->data.size()) {
        // insert rebase
        IterVar rebased = IterVarNode::make(
            Range(), iv->var.copy_with_suffix(""), iv->iter_type);
        s->relations.push_back(RebaseNode::make(iv, rebased));
        if (s->iter_var_attrs.count(iv)) {
          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
        }
        leaf_vars->data[idx] = rebased.node_;
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}

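// Inline all stages marked kInline into their downstream consumers, then
// rewrite the remaining dataflow. Reduce bodies need care: all outputs of a
// multi-output ComputeOp share one Reduce node, differing only in value_index.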
void InjectInline(ScheduleNode* sch) {
  sch->InvalidateCache();

  std::vector<Array<Expr> > new_body(sch->stages.size());
  std::vector<bool> changed(sch->stages.size(), false);
  // inline all the ops
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage stage = sch->stages[i - 1];
    if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
      Expr body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
        CHECK(compute)
            << "can only inline compute op";
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
        CHECK_EQ(compute->body.size(), 1U)
            << "can only inline compute op with 1 output";
        body = compute->body[0];
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
        if (compute) {
          if (!new_body[j].size()) {
            new_body[j] = s->op.as<ComputeOpNode>()->body;
          }
          if (new_body[j][0]->is_type<ir::Reduce>()) {
            // specially handle reduction inline for multiple reductions.
            const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
            for (size_t k = 1; k < new_body[j].size(); ++k) {
              const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
              CHECK(reduce_);
              CHECK(ReduceEqual(reduce_, reduce))
                  << "The Reduce inputs of ComputeOp should "
                  << "have the same attribute except value_index";
            }
            Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
                                        stage->op, args, body).as<ir::Evaluate>()->value;
            if (!new_value.same_as(new_body[j][0])) {
              changed[j] = true;
              const ir::Reduce* r = new_value.as<ir::Reduce>();
              CHECK(r != nullptr);
              CHECK_EQ(new_body[j].size(), r->source.size());
              for (size_t k = 0; k < new_body[j].size(); ++k) {
                auto n = make_node<ir::Reduce>(*r);
                n->value_index = static_cast<int>(k);
                n->type = r->source[k].type();
                new_body[j].Set(k, Expr(n));
              }
            }
          } else {
            for (size_t k = 0; k < new_body[j].size(); ++k) {
              Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
                                          stage->op, args, body).as<ir::Evaluate>()->value;
              if (!new_value.same_as(new_body[j][k])) {
                new_body[j].Set(k, new_value);
                changed[j] = true;
              }
            }
          }
        }
      }
    }
  }
  std::unordered_map<Tensor, Tensor> repl;
  // rewrite dataflow
  for (size_t i = 0; i < sch->stages.size(); ++i) {
    Stage s = sch->stages[i];
    if (s->attach_type == kInlinedAlready) continue;
    if (new_body[i].size()) {
      // Logic from ReplaceDataFlow
      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
      CHECK(compute);
      Operation op = s->op;
      if (changed[i]) {
        op = ComputeOpNode::make(
            compute->name, compute->tag, compute->attrs,
            compute->axis, new_body[i]);
      }
      op = op->ReplaceInputs(op, repl);
      if (!op.same_as(s->op)) {
        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
          repl[s->op.output(idx)] = op.output(idx);
        }
        s->op = op;
      }
    } else {
      Operation op = s->op->ReplaceInputs(s->op, repl);
      if (!op.same_as(s->op)) {
        for (int j = 0; j < op->num_outputs(); ++j) {
          repl[s->op.output(j)] = op.output(j);
        }
        s->op = op;
      }
    }
  }
}

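// Normalize the schedule to a legal form: apply pending inlines and rebase
// loops so that every leaf iteration variable starts from zero.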
Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->());
  RebaseNonZeroMinLoop(sn);
  return sn;
}

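// Illustrative sketch (hypothetical names): for B = sum(A[k] for k in K) with
// k split into (ko, ki), rfactor(B, ki) creates B.rf, which reduces only over
// ko with ki as a new data-parallel axis, and rewrites B to reduce B.rf over a
// fresh ki.v axis. This exposes ki for parallel or tensorized scheduling.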
// Handle reduction factor.
Array<Tensor> Schedule::rfactor(const Tensor& tensor,
                                const IterVar& axis,
                                int factor_axis) {
  (*this)->InvalidateCache();
  using ir::Reduce;
  CHECK_EQ(axis->iter_type, kCommReduce)
      << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  CHECK(compute_op) << "Can only factor ComputeOp";
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    CHECK_NE(axis_pos, leaf_vars->data.size())
        << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // skip reduction iteration.
  std::unordered_set<IterVar> skip_bound_check;
  // Verify normal axes are not touched.
  for (IterVar iv : compute_op->axis) {
    CHECK(!touch_map.count(iv))
        << "Factor axis touches normal axis.";
    skip_bound_check.insert(iv);
  }
  // Get the replace index
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, Expr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) {
      dom_map[iv] = iv->dom;
    } else {
      skip_bound_check.insert(iv);
    }
  }
  schedule::PassDownDomain(reduce_stage, &dom_map, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  std::vector<Expr> predicates = schedule::MakeBoundCheck(
      reduce_stage, dom_map, value_map, true, skip_bound_check);

  // Get the factored op node.
  const int factor_axis_pos =
      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
  CHECK_LE(factor_axis_pos, compute_op->axis.size());
  auto n = make_node<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // axis replacement.
    auto iv_node = make_node<IterVarNode>();
    iv_node->dom = dom_map.at(axis);
    CHECK(is_zero(iv_node->dom->min))
        << "Can only factor reduction domain starting from 0";
    iv_node->var = axis->var;
    iv_node->iter_type = kDataPar;

    const int size = compute_op->axis.size();
    for (int idx = 0; idx < size; ++idx) {
      if (factor_axis_pos == idx) {
        n->axis.push_back(IterVar(iv_node));
      }
      n->axis.push_back(compute_op->axis[idx]);
    }
    if (factor_axis_pos == size) {
      n->axis.push_back(IterVar(iv_node));
    }
  }
  // predicate generation, copy the untouched axes.
  int idx = tensor->value_index;
  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
  CHECK(reduce) << "Can only rfactor non-inline reductions";
  predicates.push_back(reduce->condition);
  Expr predicate = arith::ComputeReduce<ir::And>(predicates, Expr());

  std::unordered_map<const Variable*, Expr> vsub;

  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      CHECK(value_map.count(iv));
      Expr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
    }
  }

  // Copy touched axes.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      CHECK_EQ(iv->iter_type, kCommReduce);
      auto ncpy = make_node<IterVarNode>(*iv.operator->());
      ncpy->dom = dom_map.at(iv);
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
  VarReplacer replacer(vsub);
  Array<Expr> new_source = ir::UpdateArray(reduce->source,
    [&replacer] (const Expr& e) { return replacer.Mutate(e); });
  std::vector<Expr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce::make(reduce->combiner,
                                   new_source,
                                   n->reduce_axis,
                                   predicate,
                                   idx));
  }
  n->body = Array<Expr>(body);
  // refresh relations, keep the un-touched relations.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  Operation factor_op(n);
  ArrayNode* stages = (*this)->stages.CopyOnWrite();
  size_t stage_pos = FindNodeRef(stages, reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  CHECK_LT(stage_pos, stages->data.size());
  stages->data.insert(stages->data.begin() + stage_pos,
                      factor_stage.node_);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  IterVar repl_red_axis = reduce_axis(
      dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
    [&](const Array<Var>& i) {
      Array<Expr> indices;
      const int idx_size = static_cast<int>(i.size());
      for (int idx = 0; idx < idx_size; ++idx) {
        if (factor_axis_pos == idx) {
          indices.push_back(repl_red_axis->var);
        }
        indices.push_back(i[idx]);
      }
      if (factor_axis_pos == idx_size) {
        indices.push_back(repl_red_axis->var);
      }
      Array<Expr> factor_exprs;
      for (int idx = 0; idx < size; ++idx) {
        factor_exprs.push_back(factor_tensors[idx](indices));
      }
      Array<Expr> reductions;
      Array<IterVar> axis = {repl_red_axis};
      Expr cond = const_true();
      for (int idx = 0; idx < size; ++idx) {
        reductions.push_back(Reduce::make(reduce->combiner,
          factor_exprs, axis, cond, idx));
      }
      return reductions;
    }, reduce_stage->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
    rvmap[repl_tensors[idx]] = old_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  // revamp the reduction stage.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}

}  // namespace tvm