/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file schedule_ops.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <tvm/schedule_pass.h>

#include <utility>
#include <unordered_map>
#include <unordered_set>

#include "graph.h"
#include "../op/op_util.h"
#include "../pass/ir_util.h"

namespace tvm {
namespace schedule {

using namespace ir;

// Create the statement that evaluates stage `s` (the producer) and then
// runs `consumer`, wrapping the pair in realize/scope annotations.
Stmt MakePipeline(const Stage& s,
                  const std::unordered_map<IterVar, Range>& dom_map,
                  Stmt consumer,
                  bool debug_keep_trivial_loop) {
  // Lower the stage's compute rule into provide statements.
  Stmt provide = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
  if (provide.defined()) {
    // Tag the produced region so later passes know who writes it.
    provide = ProducerConsumer::make(s->op, true, provide);
  }
  if (s->double_buffer) {
    // The stage requested double buffering for its storage.
    provide = AttrStmt::make(
        s->op, ir::attr::double_buffer_scope, 1, provide);
  }
  Stmt result;
  if (consumer.defined() && !is_no_op(consumer)) {
    // Only emit the consumer half when it actually does something.
    consumer = ProducerConsumer::make(s->op, false, consumer);
    result = Block::make(provide, consumer);
  } else {
    result = provide;
  }
  // Surround the pipeline with the realize region of this stage.
  result = s->op->BuildRealize(s, dom_map, result);
  // Use an attribute to mark the scope of the operation.
  result = AttrStmt::make(
      s->op, ir::attr::realize_scope,
      StringImm::make(s->scope),
      result);
  if (s->is_opengl) {
    // OpenGL stages carry an extra stage-scope marker.
    result = AttrStmt::make(
        s->op, ir::attr::opengl_stage_scope, StringImm::make(""), result);
  }
  return result;
}

// inject the operator's realization on the stmt.
class InjectAttach : public IRMutator {
76
 public:
77
  InjectAttach(const Stage& stage,
78
               const Stage& attach_spec,
79
               const std::unordered_map<IterVar, Range>& dom_map,
80
               bool debug_keep_trivial_loop)
81
      : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
82
        debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
83 84 85 86 87 88

  Stmt Mutate(Stmt stmt) final {
    CHECK(stmt.defined());
    stmt =  IRMutator::Mutate(stmt);
    const AttrStmt* op = stmt.as<AttrStmt>();
    if (op != nullptr &&
89
        op->attr_key == attr::loop_scope) {
90 91 92 93 94
      if (attach_spec_->attach_type == kScope &&
          op->node == attach_spec_->attach_ivar) {
        CHECK(!found_attach)
            << "Find IterVar" << attach_spec_->attach_ivar
            << " in multiple places in the IR";
95 96
        found_attach = true;
        stmt = AttrStmt::make(
97
            op->node, op->attr_key, op->value,
98
            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
99 100 101 102
      }
    }
    return stmt;
  }
103 104 105 106
  // whether attach point is found
  bool found_attach{false};

 private:
107
  // The stage.
108
  const Stage& stage_;
109 110
  // The attach spec, may not contain op.
  const Stage& attach_spec_;
111
  // domain map
112
  const std::unordered_map<IterVar, Range>& dom_map_;
113 114 115
  // Whether keep trivial loops with extent of 1 during lowering.
  // This is a debug feature for dataflow/axis analysis
  bool debug_keep_trivial_loop_;
116 117 118 119 120 121 122
};

// inject the operator's realization on the stmt.
class InjectScanStep : public IRMutator {
 public:
  InjectScanStep(const Stage& stage,
                 const Operation& scan_op,
123
                 const std::unordered_map<IterVar, Range>& dom_map,
124
                 bool is_init,
125
                 bool debug_keep_trivial_loop)
126
      : stage_(stage), scan_op_(scan_op),
127
        dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {}
128 129 130 131

  Stmt Mutate(Stmt stmt) final {
    CHECK(stmt.defined());
    stmt =  IRMutator::Mutate(stmt);
132 133 134
    // update
    const AttrStmt* op = stmt.as<AttrStmt>();
    if (op != nullptr &&
135 136
        ((op->attr_key == attr::scan_update_scope && !is_init_) ||
         (op->attr_key == attr::scan_init_scope && is_init_))) {
137
      if (op->node.same_as(scan_op_)) {
138
        found_attach = true;
139
        stmt = AttrStmt::make(
140
            op->node, op->attr_key, op->value,
141
            MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
142 143 144 145 146
      }
    }
    return stmt;
  }

147 148
  // whether attach point is found
  bool found_attach{false};
149 150 151 152 153 154

 private:
  // the operations to be carried
  const Stage& stage_;
  const Operation& scan_op_;
  // domain map
155
  const std::unordered_map<IterVar, Range>& dom_map_;
156 157
  // whether it is init.
  bool is_init_;
158 159 160
  // Whether keep trivial loops with extent of 1 during lowering.
  // This is a debug feature for dataflow/axis analysis
  bool debug_keep_trivial_loop_;
161 162
};

// Postprocessing of schedule op
// Replace the init and update's expression by scan's buffer.
// The replacement tables are filled in by Init() and consulted by the
// Mutate_ overloads while walking the lowered statement.
class SchedulePostProc : public IRMutator {
 public:
  // Rewrite producer/consumer markers whose op was replaced during Init().
  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
    auto it = replace_op_.find(op->func.get());
    if (it != replace_op_.end()) {
      Stmt body = this->Mutate(op->body);
      if (it->second.defined()) {
        // Re-tag the region with the replacement operation.
        return ProducerConsumer::make(
            it->second, op->is_producer, body);
      } else {
        // Undefined replacement: drop the marker entirely.
        return body;
      }
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  // Inline side-effect-free let bindings; the value is recorded so that
  // Mutate_(Variable) substitutes it at every use site.
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    if (!HasSideEffect(op->value)) {
      var_value_[op->var.get()] = Mutate(op->value);
      return this->Mutate(op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  // Strip scheduling-only attributes and remap attributes that refer
  // to replaced ops/tensors.
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::loop_scope ||
        op->attr_key == attr::scan_init_scope) {
      // These markers are only needed during injection; drop them.
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::scan_update_scope) {
      // Bind the scan axis variable to the scope's value expression.
      const ScanOpNode* scan = op->node.as<ScanOpNode>();
      CHECK(scan);
      var_value_[scan->scan_axis->var.get()] = op->value;
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::thread_extent) {
      // delete duplicated thread extent attr
      auto it = thread_extent_scope_.find(op->node.get());
      if (it != thread_extent_scope_.end()) {
        // Nested extent for the same thread var must agree with the outer one.
        CHECK(is_zero(ir::Simplify(it->second - op->value)));
        return this->Mutate(op->body);
      } else {
        thread_extent_scope_[op->node.get()] = op->value;
        Stmt ret = IRMutator::Mutate_(op, s);
        thread_extent_scope_.erase(op->node.get());
        return ret;
      }
    } else if (op->attr_key == ir::attr::realize_scope ||
               op->attr_key == ir::attr::double_buffer_scope) {
      // Remap the scope attribute to the replacement op, if any.
      auto it = replace_op_.find(op->node.get());
      if (it != replace_op_.end()) {
        if (it->second.defined()) {
          Stmt ret = AttrStmt::make(
              it->second, op->attr_key, op->value, op->body);
          return this->Mutate(ret);
        } else {
          return this->Mutate(op->body);
        }
      }
    } else if (op->attr_key == ir::attr::buffer_bind_scope) {
      // Node is a (buffer, tensor) tuple; remap the tensor half.
      Array<NodeRef> tuple(op->node.node_);
      Tensor tensor(tuple[1].node_);
      auto it = replace_op_.find(tensor->op.get());
      if (it != replace_op_.end()) {
        if (it->second.defined()) {
          return AttrStmt::make(
              Array<NodeRef>{tuple[0], it->second.output(tensor->value_index)},
              op->attr_key, op->value, Mutate(op->body));
        } else {
          return this->Mutate(op->body);
        }
      }
    } else if (op->attr_key == ir::attr::buffer_dim_align) {
      // Node is the tensor itself; remap it to the replacement op's output.
      Tensor tensor(op->node.node_);
      auto it = replace_op_.find(tensor->op.get());
      if (it != replace_op_.end()) {
        if (it->second.defined()) {
          return AttrStmt::make(
              it->second.output(tensor->value_index),
              op->attr_key, op->value, Mutate(op->body));
        } else {
          return this->Mutate(op->body);
        }
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  // Redirect realize regions of replaced tensors (or drop them).
  Stmt Mutate_(const Realize* op, const Stmt& s) final {
    TensorKey key{op->func, op->value_index};
    auto it = replace_realize_.find(key);
    if (it != replace_realize_.end()) {
      if (it->second.defined()) {
        Stmt ret = Realize::make(
            it->second->op, it->second->value_index,
            op->type, op->bounds, op->condition, op->body);
        return this->Mutate(ret);
      } else {
        return this->Mutate(op->body);
      }
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  // Redirect stores into replaced buffers to their replacement tensor.
  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    TensorKey key{op->func, op->value_index};
    auto it = replace_buffer_.find(key);
    if (it != replace_buffer_.end()) {
      const Tensor& dst = it->second;
      Stmt ret = Provide::make(
          dst->op, dst->value_index, op->value, op->args);
      return this->Mutate(ret);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  // Redirect Halide-style reads of replaced buffers to their replacement.
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->call_type == Call::Halide) {
      TensorKey key{op->func, op->value_index};
      auto it = replace_buffer_.find(key);
      if (it != replace_buffer_.end()) {
        const Tensor& dst = it->second;
        Expr ret = Call::make(
            op->type, dst->op->name, op->args,
            op->call_type, dst->op, dst->value_index);
        return this->Mutate(ret);
      }
    }
    return IRMutator::Mutate_(op, e);
  }

  // Substitute variables recorded via LetStmt / scan scopes / thread binds.
  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = var_value_.find(op);
    if (it != var_value_.end()) {
      return it->second;
    } else {
      return e;
    }
  }

  // Populate the replacement tables from the schedule before mutation.
  void Init(const Schedule& sch) {
    for (Stage s : sch->stages) {
      for (auto kv : s->iter_var_attrs) {
        // Update bind thread information.
        if (kv.second->bind_thread.defined()) {
          const Var& from = kv.first->var;
          const Var& to = kv.second->bind_thread->var;
          CHECK(!var_value_.count(from.get()));
          var_value_[from.get()] = to;
        }
      }
      // This must be checked for all ops, including scan.
      if (!s->op.same_as(s->origin_op)) {
        for (int i = 0; i < s->op->num_outputs(); ++i) {
          Tensor target = s->origin_op.output(i);
          AddReplace(s->op.output(i), target,
                     target, s->origin_op);
        }
      }
      // Specially add replacements for scan op.
      if (const ScanOpNode* scan = s->op.as<ScanOpNode>()) {
        for (size_t i = 0; i < scan->update.size(); ++i) {
          Tensor t = s->origin_op.output(i);
          // init/update/state all alias the scan's output buffer.
          AddReplace(scan->init[i], t);
          AddReplace(scan->update[i], t);
          AddReplace(scan->state_placeholder[i], t);
        }
      }
    }
  }

 private:
  // Register that reads/writes of `src` should target `dst`; optionally
  // remap its realize region to `repl_realize` and its op to `repl_op`
  // (defaults mean "drop the realize/producer marker").
  void AddReplace(Tensor src,
                  Tensor dst,
                  Tensor repl_realize = Tensor(),
                  Operation repl_op = Operation()) {
    TensorKey key{src->op, src->value_index};
    replace_buffer_[key] = dst;
    replace_realize_[key] = repl_realize;
    replace_op_[src->op.get()] = repl_op;
  }
  // The thread extent scope.
  std::unordered_map<const Node*, Expr> thread_extent_scope_;
  // The scan value
  std::unordered_map<const Variable*, Expr> var_value_;
  // buffer replacement
  std::unordered_map<TensorKey, Tensor> replace_buffer_;
  // buffer realization to be replaced
  std::unordered_map<TensorKey, Tensor> replace_realize_;
  // replace producer consumer.
  std::unordered_map<const Node*, Operation> replace_op_;
};

// Lower a schedule into a single statement by injecting each stage's
// pipeline at its attach point in reverse post-DFS order, then running
// the post-processing pass.
Stmt ScheduleOps(
    Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
  Stmt body = Stmt();
  std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
  // Map each scan-init tensor's op to the scan operation that owns it.
  std::unordered_map<Operation, Operation> scan_init;
  for (Stage stage : sch->stages) {
    const ScanOpNode* scan = stage->op.as<ScanOpNode>();
    if (scan == nullptr) continue;
    for (Tensor t : scan->init) {
      auto it = scan_init.find(t->op);
      if (it == scan_init.end()) {
        scan_init[t->op] = stage->op;
      } else {
        it->second.same_as(stage->op);
        CHECK(scan_init.at(t->op).same_as(stage->op))
            << "Scan init tensor can only belong to one scan";
      }
    }
  }
  // Groups are pure sugar: they must carry no op and no leaf itervars.
  for (Stage g : sch->groups) {
    CHECK(!g->op.defined());
    CHECK_EQ(g->leaf_iter_vars.size(), 0U);
  }
  // Walk the stages in reverse post-DFS order, weaving each into body.
  for (size_t idx = sch->stages.size(); idx != 0; --idx) {
    Stage s = sch->stages[idx - 1];
    CHECK_NE(s->attach_type, kInline)
        << "call schedule.normalize before scheduleops";
    CHECK(s->op.defined());
    // Placeholder ops produce no statements.
    if (s->op.as<PlaceholderOpNode>() != nullptr) continue;
    // Remove grouping sugar, get the real attach spec.
    Stage attach_spec = s.GetAttachSpec();

    if (scan_init.count(s->op)) {
      // This stage provides the init value of a scan.
      CHECK(body.defined());
      InjectScanStep injector(
          s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
      body = injector.Mutate(body);
      CHECK(injector.found_attach)
          << "did not find attachment point for scan.init";
    } else if (attach_spec->attach_type == kScanUpdate) {
      // This stage provides the update value of a scan.
      CHECK(body.defined());
      InjectScanStep injector(
          s, attach_spec->attach_stage->op, dom_map, false,
          debug_keep_trivial_loop);
      body = injector.Mutate(body);
      CHECK(injector.found_attach)
          << "did not find attachment point for scan.update";
    } else if (attach_spec->attach_type == kInlinedAlready) {
      // Already inlined: nothing to emit.
    } else if (attach_spec->attach_type == kGroupRoot) {
      // Root of a group: simply prepend its pipeline.
      CHECK(!s->group.defined());
      body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
    } else {
      // Scoped attachment inside another stage's loop nest.
      CHECK_EQ(attach_spec->attach_type, kScope);
      CHECK(body.defined());
      InjectAttach injector(s, attach_spec, dom_map, debug_keep_trivial_loop);
      body = injector.Mutate(body);
      CHECK(injector.found_attach)
          << "did not find attachment point for " << s << " in "
          << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
          << ", body:\n"
          << body;
    }
  }
  // Final pass: rewrite scan buffers and drop scheduling-only markers.
  SchedulePostProc post_proc;
  post_proc.Init(sch);
  return post_proc.Mutate(body);
}

}  // namespace schedule
}  // namespace tvm