schedule_ops.cc 13.5 KB
Newer Older
1 2 3 4 5 6 7 8
/*!
 *  Copyright (c) 2016 by Contributors
 * \file schedule_ops.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
9
#include <tvm/operation.h>
10
#include <tvm/schedule_pass.h>
11 12 13
#include <utility>
#include <unordered_map>
#include <unordered_set>
14
#include "graph.h"
15 16
#include "../op/op_util.h"
#include "../pass/ir_util.h"
17 18 19 20 21 22

namespace tvm {
namespace schedule {

using namespace ir;

23
/*!
 * \brief Build the producer/consumer pipeline statement of one stage.
 *
 * The producer part comes from the operation's BuildProvide; when a
 * non-trivial consumer is given, the two are sequenced in a Block.
 * The result is wrapped by the operation's BuildRealize and annotated
 * with the realize_scope attribute (plus opengl_stage_scope when the
 * stage is flagged as OpenGL).
 *
 * \param s The stage whose pipeline is built.
 * \param dom_map Domain of each iteration variable.
 * \param consumer The statement that consumes the stage's result; may be undefined.
 * \param debug_keep_trivial_loop Forwarded to BuildProvide.
 * \return The assembled pipeline statement.
 */
Stmt MakePipeline(const Stage& s,
                  const std::unordered_map<IterVar, Range>& dom_map,
                  Stmt consumer,
                  bool debug_keep_trivial_loop) {
  const auto& op = s->op;

  // Producer side: only tag it as a producer when there is something to tag.
  Stmt produce = op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
  if (produce.defined()) {
    produce = ProducerConsumer::make(op, true, produce);
  }
  if (s->double_buffer) {
    produce = AttrStmt::make(op, ir::attr::double_buffer_scope, 1, produce);
  }

  // Consumer side: appended after the producer unless it is missing or a no-op.
  const bool has_consumer = consumer.defined() && !is_no_op(consumer);
  Stmt result = has_consumer
      ? Block::make(produce, ProducerConsumer::make(op, false, consumer))
      : produce;

  result = op->BuildRealize(s, dom_map, result);
  // use attribute to mark scope of the operation.
  result = AttrStmt::make(
      op, ir::attr::realize_scope, StringImm::make(s->scope), result);

  if (s->is_opengl) {
    result = AttrStmt::make(
        op, ir::attr::opengl_stage_scope, StringImm::make(""), result);
  }
  return result;
}

// inject the operator's realization on the stmt.
56
class InjectAttach : public IRMutator {
57
 public:
58
  InjectAttach(const Stage& stage,
59
               const Stage& attach_spec,
60
               const std::unordered_map<IterVar, Range>& dom_map,
61
               bool debug_keep_trivial_loop)
62
      : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
63
        debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
64 65 66 67 68 69

  Stmt Mutate(Stmt stmt) final {
    CHECK(stmt.defined());
    stmt =  IRMutator::Mutate(stmt);
    const AttrStmt* op = stmt.as<AttrStmt>();
    if (op != nullptr &&
70
        op->attr_key == attr::loop_scope) {
71 72 73 74 75
      if (attach_spec_->attach_type == kScope &&
          op->node == attach_spec_->attach_ivar) {
        CHECK(!found_attach)
            << "Find IterVar" << attach_spec_->attach_ivar
            << " in multiple places in the IR";
76 77
        found_attach = true;
        stmt = AttrStmt::make(
78
            op->node, op->attr_key, op->value,
79
            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
80 81 82 83
      }
    }
    return stmt;
  }
84 85 86 87
  // whether attach point is found
  bool found_attach{false};

 private:
88
  // The stage.
89
  const Stage& stage_;
90 91
  // The attach spec, may not contain op.
  const Stage& attach_spec_;
92
  // domain map
93
  const std::unordered_map<IterVar, Range>& dom_map_;
94 95 96
  // Whether keep trivial loops with extent of 1 during lowering.
  // This is a debug feature for dataflow/axis analysis
  bool debug_keep_trivial_loop_;
97 98 99 100 101 102 103
};

// inject the operator's realization on the stmt.
class InjectScanStep : public IRMutator {
 public:
  InjectScanStep(const Stage& stage,
                 const Operation& scan_op,
104
                 const std::unordered_map<IterVar, Range>& dom_map,
105
                 bool is_init,
106
                 bool debug_keep_trivial_loop)
107
      : stage_(stage), scan_op_(scan_op),
108
        dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
109 110 111 112

  Stmt Mutate(Stmt stmt) final {
    CHECK(stmt.defined());
    stmt =  IRMutator::Mutate(stmt);
113 114 115
    // update
    const AttrStmt* op = stmt.as<AttrStmt>();
    if (op != nullptr &&
116 117
        ((op->attr_key == attr::scan_update_scope && !is_init_) ||
         (op->attr_key == attr::scan_init_scope && is_init_))) {
118
      if (op->node.same_as(scan_op_)) {
119
        found_attach = true;
120
        stmt = AttrStmt::make(
121
            op->node, op->attr_key, op->value,
122
            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
123 124 125 126 127
      }
    }
    return stmt;
  }

128 129
  // whether attach point is found
  bool found_attach{false};
130 131 132 133 134 135

 private:
  // the operations to be carried
  const Stage& stage_;
  const Operation& scan_op_;
  // domain map
136
  const std::unordered_map<IterVar, Range>& dom_map_;
137 138
  // whether it is init.
  bool is_init_;
139 140 141
  // Whether keep trivial loops with extent of 1 during lowering.
  // This is a debug feature for dataflow/axis analysis
  bool debug_keep_trivial_loop_;
142 143
};

144 145 146 147 148
// Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public IRMutator {
 public:
  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
149 150 151 152 153 154 155 156 157
    auto it = replace_op_.find(op->func.get());
    if (it != replace_op_.end()) {
      Stmt body = this->Mutate(op->body);
      if (it->second.defined()) {
        return ProducerConsumer::make(
            it->second, op->is_producer, body);
      } else {
        return body;
      }
158 159 160 161 162 163 164 165 166 167 168 169 170 171
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    if (!HasSideEffect(op->value)) {
      var_value_[op->var.get()] = Mutate(op->value);
      return this->Mutate(op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
172 173
    if (op->attr_key == attr::loop_scope ||
        op->attr_key == attr::scan_init_scope) {
174
      return this->Mutate(op->body);
175
    } else if (op->attr_key == attr::scan_update_scope) {
176 177 178 179
      const ScanOpNode* scan = op->node.as<ScanOpNode>();
      CHECK(scan);
      var_value_[scan->scan_axis->var.get()] = op->value;
      return this->Mutate(op->body);
180 181 182 183
    } else if (op->attr_key == attr::thread_extent) {
      // delete duplicated thread extent attr
      auto it = thread_extent_scope_.find(op->node.get());
      if (it != thread_extent_scope_.end()) {
184
        CHECK(is_zero(ir::Simplify(it->second - op->value)));
185 186 187 188 189 190 191
        return this->Mutate(op->body);
      } else {
        thread_extent_scope_[op->node.get()] = op->value;
        Stmt ret = IRMutator::Mutate_(op, s);
        thread_extent_scope_.erase(op->node.get());
        return ret;
      }
192 193
    } else if (op->attr_key == ir::attr::realize_scope ||
               op->attr_key == ir::attr::double_buffer_scope) {
194 195 196 197
      auto it = replace_op_.find(op->node.get());
      if (it != replace_op_.end()) {
        if (it->second.defined()) {
          Stmt ret = AttrStmt::make(
198
              it->second, op->attr_key, op->value, op->body);
199
          return this->Mutate(ret);
200 201
        } else {
          return this->Mutate(op->body);
202 203 204 205 206 207 208 209 210 211 212 213 214
        }
      }
    } else if (op->attr_key == ir::attr::buffer_bind_scope) {
      Array<NodeRef> tuple(op->node.node_);
      Tensor tensor(tuple[1].node_);
      auto it = replace_op_.find(tensor->op.get());
      if (it != replace_op_.end()) {
        if (it->second.defined()) {
          return AttrStmt::make(
              Array<NodeRef>{tuple[0], it->second.output(tensor->value_index)},
              op->attr_key, op->value, Mutate(op->body));
        } else {
          return this->Mutate(op->body);
215 216 217 218 219 220 221 222 223 224 225 226
        }
      }
    } else if (op->attr_key == ir::attr::buffer_dim_align) {
      Tensor tensor(op->node.node_);
      auto it = replace_op_.find(tensor->op.get());
      if (it != replace_op_.end()) {
        if (it->second.defined()) {
          return AttrStmt::make(
              it->second.output(tensor->value_index),
              op->attr_key, op->value, Mutate(op->body));
        } else {
          return this->Mutate(op->body);
227
        }
228 229 230 231 232 233 234
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Realize* op, const Stmt& s) final {
    TensorKey key{op->func, op->value_index};
235 236 237 238 239 240
    auto it = replace_realize_.find(key);
    if (it != replace_realize_.end()) {
      if (it->second.defined()) {
        Stmt ret = Realize::make(
            it->second->op, it->second->value_index,
            op->type, op->bounds, op->condition, op->body);
241
        return this->Mutate(ret);
242 243 244
      } else {
        return this->Mutate(op->body);
      }
245 246 247 248 249 250 251
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    TensorKey key{op->func, op->value_index};
252 253
    auto it = replace_buffer_.find(key);
    if (it != replace_buffer_.end()) {
254
      const Tensor& dst = it->second;
255
      Stmt ret = Provide::make(
256 257
          dst->op, dst->value_index, op->value, op->args);
      return this->Mutate(ret);
258 259 260 261 262 263
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Expr Mutate_(const Call* op, const Expr& e) final {
264
    if (op->call_type == Call::Halide) {
265
      TensorKey key{op->func, op->value_index};
266 267
      auto it = replace_buffer_.find(key);
      if (it != replace_buffer_.end()) {
268
        const Tensor& dst = it->second;
269
        Expr ret = Call::make(
270
            op->type, dst->op->name, op->args,
271
            op->call_type, dst->op, dst->value_index);
272
        return this->Mutate(ret);
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
      }
    }
    return IRMutator::Mutate_(op, e);
  }

  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = var_value_.find(op);
    if (it != var_value_.end()) {
      return it->second;
    } else {
      return e;
    }
  }

  void Init(const Schedule& sch) {
    for (Stage s : sch->stages) {
289 290 291 292 293 294 295 296 297
      for (auto kv : s->iter_var_attrs) {
        // Update bind thread information.
        if (kv.second->bind_thread.defined()) {
          const Var& from = kv.first->var;
          const Var& to = kv.second->bind_thread->var;
          CHECK(!var_value_.count(from.get()));
          var_value_[from.get()] = to;
        }
      }
298 299
      // This must be checked for all ops, including scan.
      if (!s->op.same_as(s->origin_op)) {
300
        for (int i = 0; i < s->op->num_outputs(); ++i) {
301
          Tensor target = s->origin_op.output(i);
302 303 304
          AddReplace(s->op.output(i), target,
                     target, s->origin_op);
        }
305 306
      }
      // Specially add replacements for scan op.
307 308 309 310
      if (s->op.as<ScanOpNode>()) {
        const ScanOpNode* scan = s->op.as<ScanOpNode>();
        for (size_t i = 0; i < scan->update.size(); ++i) {
          Tensor t = s->origin_op.output(i);
311 312 313
          AddReplace(scan->init[i], t);
          AddReplace(scan->update[i], t);
          AddReplace(scan->state_placeholder[i], t);
314
        }
315 316 317 318 319
      }
    }
  }

 private:
320 321 322 323 324
  void AddReplace(Tensor src,
                  Tensor dst,
                  Tensor repl_realize = Tensor(),
                  Operation repl_op = Operation()) {
    TensorKey key{src->op, src->value_index};
325
    replace_buffer_[key] = dst;
326 327
    replace_realize_[key] = repl_realize;
    replace_op_[src->op.get()] = repl_op;
328
  }
329 330
  // The thread extent scope.
  std::unordered_map<const Node*, Expr> thread_extent_scope_;
331 332 333
  // The scan value
  std::unordered_map<const Variable*, Expr> var_value_;
  // buffer replacement
334
  std::unordered_map<TensorKey, Tensor> replace_buffer_;
335 336 337 338
  // buffere realization to be replaced
  std::unordered_map<TensorKey, Tensor> replace_realize_;
  // replace producer consumer.
  std::unordered_map<const Node*, Operation> replace_op_;
339 340
};

341
/*!
 * \brief Build the statement for a schedule by attaching every stage.
 *
 * Stages are visited in reverse order. Each non-placeholder stage is
 * either injected into the already-built body (scan init, scan update,
 * or loop-scope attach) or wrapped around it (group root). The result
 * is then passed through SchedulePostProc.
 *
 * \param sch The schedule to lower.
 * \param dom_map_ Domain of each iteration variable.
 * \param debug_keep_trivial_loop Forwarded to the pipeline builders.
 * \return The post-processed statement.
 */
Stmt ScheduleOps(
    Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
  Stmt body = Stmt();
  std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);

  // scan init and scan updates:
  // map from the operation of each scan init tensor to its scan operation.
  std::unordered_map<Operation, Operation> scan_init;
  for (Stage stage : sch->stages) {
    const ScanOpNode* scan = stage->op.as<ScanOpNode>();
    if (scan == nullptr) continue;
    for (Tensor t : scan->init) {
      auto known = scan_init.find(t->op);
      if (known == scan_init.end()) {
        scan_init.emplace(t->op, stage->op);
      } else {
        CHECK(known->second.same_as(stage->op))
            << "Scan init tensor can only belong to one scan";
      }
    }
  }

  // verify correctness of group.
  for (Stage g : sch->groups) {
    CHECK(!g->op.defined());
    CHECK_EQ(g->leaf_iter_vars.size(), 0U);
  }

  // reverse the post DFS order.
  for (size_t remaining = sch->stages.size(); remaining > 0; --remaining) {
    Stage s = sch->stages[remaining - 1];
    CHECK_NE(s->attach_type, kInline)
        << "call schedule.normalize before scheduleops";
    CHECK(s->op.defined());
    // no need to specify place holder op.
    if (s->op.as<PlaceholderOpNode>()) continue;
    // Remove grouping sugar, get the real attach spec.
    Stage attach_spec = s.GetAttachSpec();

    // Stage produces a scan init tensor: inject into the scan's init scope.
    auto init_it = scan_init.find(s->op);
    if (init_it != scan_init.end()) {
      CHECK(body.defined());
      InjectScanStep mu(s, init_it->second, dom_map, true, debug_keep_trivial_loop);
      body = mu.Mutate(body);
      CHECK(mu.found_attach)
          << "did not find attachment point for scan.init";
      continue;
    }

    // Handle scan update
    if (attach_spec->attach_type == kScanUpdate) {
      CHECK(body.defined());
      InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
      body = mu.Mutate(body);
      CHECK(mu.found_attach)
          << "did not find attachment point for scan.update";
      continue;
    }

    if (attach_spec->attach_type == kInlinedAlready) {
      // do nothing
      continue;
    }

    if (attach_spec->attach_type == kGroupRoot) {
      CHECK(!s->group.defined());
      body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
      continue;
    }

    // The only remaining legal case is attachment at a loop scope.
    CHECK_EQ(attach_spec->attach_type, kScope);
    CHECK(body.defined());
    InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
    body = mutator.Mutate(body);
    CHECK(mutator.found_attach)
        << "did not find attachment point for " << s << " in "
        << attach_spec->attach_stage->op  << " x " << attach_spec->attach_ivar
        << ", body:\n"
        << body;
  }

  SchedulePostProc post_proc;
  post_proc.Init(sch);
  return post_proc.Mutate(body);
}

}  // namespace schedule
}  // namespace tvm