/*!
 *  Copyright (c) 2017 by Contributors
 * \file narrow_channel_access.cc
 * \brief Narrow channel access to a smaller range
 *  when possible by bringing it to the internal loop.
 */
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <tvm/channel.h>
#include "./ir_util.h"

namespace tvm {
namespace ir {

using namespace arith;

/*!
 * \brief Deduces the index region of one channel buffer touched by a statement.
 *
 * Only one access direction is collected: Load indices when constructed with
 * read_access == true, Store indices otherwise. Every For loop entered during
 * the walk is relaxed to the interval [min, extent - 1] (min must be zero), so
 * the resulting set stays symbolic only in variables bound outside the walked
 * statement. Let / LetStmt nodes abort the walk with a fatal error.
 */
class ChannelAccessBound : public IRVisitor {
 public:
  /*!
   * \param buf_var The buffer variable whose accesses are collected.
   * \param read_access true to collect Load indices, false to collect Store
   *  indices.
   */
  ChannelAccessBound(const Variable* buf_var, bool read_access)
      : buf_var_(buf_var), read_access_(read_access) {}

  void Visit_(const Load* op) final {
    RecordAccess(op->buffer_var.get(), op->index, true);
    IRVisitor::Visit_(op);
  }

  void Visit_(const Store* op) final {
    RecordAccess(op->buffer_var.get(), op->index, false);
    IRVisitor::Visit_(op);
  }

  void Visit_(const For* op) final {
    CHECK(is_zero(op->min));
    // We know that the extent of the loop won't depend on relaxed scope.
    // TODO(tqchen) have a verification pass.
    loop_dom_[op->loop_var.get()] = IntSet::interval(op->min, op->extent - 1);
    IRVisitor::Visit_(op);
  }

  // Let bindings are rejected outright.
  void Visit_(const Let*) final {
    LOG(FATAL) << "cannot pass through let";
  }

  void Visit_(const LetStmt*) final {
    LOG(FATAL) << "cannot pass through let";
  }

  /*!
   * \brief Walk stmt and return the union of every recorded index set.
   * \param stmt The statement to inspect.
   * \return Union of the index sets of all matching accesses.
   */
  IntSet Eval(const Stmt& stmt) {
    Visit(stmt);
    return Union(touched_);
  }

 private:
  // Record the relaxed set of index when buffer is the tracked channel buffer
  // and the access direction (is_read) matches the one requested.
  void RecordAccess(const Variable* buffer, const Expr& index, bool is_read) {
    if (is_read != read_access_) return;
    if (buffer != buf_var_) return;
    touched_.emplace_back(EvalSet(index, loop_dom_));
  }

  // The channel buffer variable being tracked.
  const Variable* buf_var_;
  // Collect Load (true) or Store (false) accesses.
  bool read_access_{true};
  // One index set per matching access, in visit order.
  std::vector<IntSet> touched_;
  // Relaxed domains of the loops entered so far.
  std::unordered_map<const Variable*, IntSet> loop_dom_;
};

/*!
 * \brief Shifts every matching access of one channel buffer by a fixed offset.
 *
 * The index of each Load (read_access == true) or Store
 * (read_access == false) on buf_var is replaced by Simplify(index - min).
 * Accesses to other buffers, or in the other direction, are left unchanged.
 */
class ChannelAccessIndexRewriter : public IRMutator {
 public:
  /*!
   * \param buf_var The channel buffer variable whose accesses are rebased.
   * \param min The offset subtracted from every matching index.
   * \param read_access true to rewrite Loads, false to rewrite Stores.
   */
  ChannelAccessIndexRewriter(const Variable* buf_var,
                             Expr min,
                             bool read_access)
      : buf_var_(buf_var), offset_(min), read_access_(read_access) {}

  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr mutated = IRMutator::Mutate_(op, e);
    const Load* load = mutated.as<Load>();
    if (!IsTarget(load->buffer_var.get(), true)) return mutated;
    return Load::make(load->type, load->buffer_var,
                      Rebase(load->index), load->predicate);
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt mutated = IRMutator::Mutate_(op, s);
    const Store* store = mutated.as<Store>();
    if (!IsTarget(store->buffer_var.get(), false)) return mutated;
    return Store::make(store->buffer_var, store->value,
                       Rebase(store->index), store->predicate);
  }

 private:
  // True when buffer is the tracked channel buffer and the access kind
  // matches the direction being rewritten.
  bool IsTarget(const Variable* buffer, bool is_read) const {
    return is_read == read_access_ && buffer == buf_var_;
  }

  // Re-express index relative to the new window start.
  Expr Rebase(const Expr& index) const {
    return ir::Simplify(index - offset_);
  }

  // The channel buffer variable being rewritten.
  const Variable* buf_var_;
  // Offset subtracted from every matching index.
  Expr offset_;
  // Rewrite Loads (true) or Stores (false).
  bool read_access_{true};
};

// Rewrite channel access pattern.
class ChannelAccessRewriter : public IRMutator { public: Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { Stmt ret; const AttrStmt* adv = op->body.as<AttrStmt>(); if ((op->attr_key == ir::attr::channel_read_scope && adv && adv->attr_key == ir::attr::channel_read_advance) || (op->attr_key == ir::attr::channel_write_scope && adv && adv->attr_key == ir::attr::channel_write_advance)) { RewriteEntry e; e.window = op; e.advance = adv; e.read_access = op->attr_key == ir::attr::channel_read_scope; tasks_.push_back(e); ret = IRMutator::Mutate_(op, s); if (tasks_.back().rewrite_success) { ret = ret.as<AttrStmt>()->body.as<AttrStmt>()->body; } tasks_.pop_back(); return ret; } else { return IRMutator::Mutate_(op, s); } } Stmt Mutate_(const For* op, const Stmt& s) final { std::vector<RewriteEntry> tasks; std::swap(tasks_, tasks); Stmt body = op->body; std::vector<Stmt> nest; for (RewriteEntry& e : tasks) { body = RewriteAccess(op, body, &e, &nest); } if (!body.same_as(op->body)) { body = Mutate(body); body = For::make( op->loop_var, op->min, op->extent, op->for_type, op->device_api, body); body = MergeNest(nest, body); } else { CHECK_EQ(nest.size(), 0U); body = IRMutator::Mutate_(op, s); } std::swap(tasks_, tasks); return body; } private: struct RewriteEntry { bool read_access; const AttrStmt* window; const AttrStmt* advance; bool rewrite_success{false}; }; Stmt RewriteAccess(const For* for_op, Stmt body, RewriteEntry* e, std::vector<Stmt>* outer_nest) { const AttrStmt* adv_op = e->advance; const Expr& window = e->window->value; bool read_access = e->read_access; Var var(for_op->loop_var); Channel ch(adv_op->node.node_); ChannelAccessBound acc(ch->handle_var.get(), read_access); IntSet iset = acc.Eval(for_op->body); Range r = iset.cover_range(Range::make_by_min_extent(0, window)); r = Range::make_by_min_extent( ir::Simplify(r->min), ir::Simplify(r->extent)); if (ExprUseVar(r->extent, var)) return body; Array<Expr> linear_eq = DetectLinearEquation(r->min, {var}); if 
(linear_eq.size() == 0) return body; Expr coeff = linear_eq[0]; Expr base = linear_eq[1]; if (!is_zero(base)) return body; Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); if (!can_prove(left >= 0)) return body; // rewrite access index. ChannelAccessIndexRewriter rw( ch->handle_var.get(), var * coeff, read_access); body = rw.Mutate(body); if (read_access) { body = AttrStmt::make( ch, ir::attr::channel_read_scope, r->extent, AttrStmt::make(ch, ir::attr::channel_read_advance, coeff, body)); } else { body = AttrStmt::make( ch, ir::attr::channel_write_scope, r->extent, AttrStmt::make(ch, ir::attr::channel_write_advance, coeff, body)); } if (!is_zero(left)) { Stmt no_op = Evaluate::make(0); if (read_access) { outer_nest->emplace_back( AttrStmt::make(ch, ir::attr::channel_read_advance, left, no_op)); } else { outer_nest->emplace_back( AttrStmt::make(ch, ir::attr::channel_write_advance, left, no_op)); } } e->rewrite_success = true; return body; } std::vector<RewriteEntry> tasks_; }; Stmt NarrowChannelAccess(Stmt stmt) { return ChannelAccessRewriter().Mutate(stmt); } } // namespace ir } // namespace tvm