/*!
 *  Copyright (c) 2017 by Contributors
 *  Loop unrolling.
 * \file unroll_loop.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <cstdint>
#include <limits>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

/*!
 * \brief Mutator that unrolls For loops with a constant extent.
 *
 *  A Serial loop is selected for unrolling when its simplified extent is a
 *  constant in [0, auto_max_step] and it is nested at least auto_min_depth
 *  loops deep. A loop already marked ForType::Unrolled is always selected
 *  and must have a constant, non-negative extent.
 *
 *  When explicit_unroll is true, a selected loop is replaced by copies of
 *  its body with the loop variable substituted; otherwise the loop is kept
 *  and only re-tagged as ForType::Unrolled.
 */
class LoopUnroller : public IRMutator {
 public:
  /*!
   * \param auto_max_step Largest constant extent that is unrolled automatically.
   * \param auto_min_depth Minimum loop nesting depth required for automatic unrolling.
   * \param explicit_unroll Whether to expand the loop body instead of only tagging the loop.
   */
  explicit LoopUnroller(int auto_max_step,
                        int auto_min_depth,
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
        auto_min_depth_(auto_min_depth),
        explicit_unroll_(explicit_unroll) {
  }

  /*!
   * \brief Visit a For loop and unroll (or tag) it when it qualifies.
   * \param op The loop node.
   * \param s The statement holding the loop.
   * \return The rewritten statement.
   */
  Stmt Mutate_(const For* op, const Stmt& s) override {
    // Keep the full 64-bit extent: narrowing to int could turn a huge trip
    // count into a small one and unroll the wrong number of iterations.
    const int64_t value = ConstExtent(op);
    bool auto_unroll = (op->for_type == ForType::Serial &&
                        value >= 0 && value <= auto_max_step_ &&
                        loop_depth_ >= auto_min_depth_);

    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
      auto_unroll = true;
    }

    if (auto_unroll && explicit_unroll_) {
      using arith::ComputeExpr;
      // A zero-trip loop executes nothing.
      if (value == 0) return Evaluate::make(0);
      Stmt body = op->body;
      Map<Var, Expr> vmap;
      Stmt unrolled;
      // The substituted variable is the same for every iteration.
      Var lv(op->loop_var.node_);
      for (int64_t i = 0; i < value; ++i) {
        // Iteration i sees loop_var == min + i.
        vmap.Set(lv,
                 ComputeExpr<Add>(
                     op->min, make_const(op->loop_var.type(), i)));
        Stmt step = Substitute(body, vmap);
        if (unrolled.defined()) {
          unrolled = Block::make(unrolled, step);
        } else {
          unrolled = step;
        }
      }
      // Recurse into the expanded body so nested loops are processed too.
      ++loop_depth_;
      Stmt ret = this->Mutate(unrolled);
      --loop_depth_;
      return ret;
    } else {
      ++loop_depth_;
      Stmt ret = IRMutator::Mutate_(op, s);
      if (auto_unroll) {
        // Keep the loop but tag it, using the already-mutated children.
        op = ret.as<For>();
        if (op->for_type != ForType::Unrolled) {
          ret = For::make(
              op->loop_var, op->min, op->extent,
              ForType::Unrolled, op->device_api, op->body);
        }
      }
      --loop_depth_;
      return ret;
    }
  }

 private:
  /*!
   * \brief Constant-fold the extent of a loop.
   * \param op The loop node.
   * \return The extent when it simplifies to an integer immediate
   *  (unsigned values above the int64_t range saturate to the maximum),
   *  or -1 when it is not a constant. A negative signed immediate is
   *  returned unchanged.
   */
  static int64_t ConstExtent(const For* op) {
    Expr extent = ir::Simplify(op->extent);
    if (const IntImm* imm = extent.as<IntImm>()) {
      return imm->value;
    }
    if (const UIntImm* uimm = extent.as<UIntImm>()) {
      const int64_t kMax = std::numeric_limits<int64_t>::max();
      return uimm->value > static_cast<uint64_t>(kMax)
          ? kMax : static_cast<int64_t>(uimm->value);
    }
    return -1;
  }

  // Maximum constant extent for which a serial loop is auto-unrolled.
  int auto_max_step_;
  // Minimum nesting depth a serial loop must have to be auto-unrolled.
  int auto_min_depth_;
  // If true, expand loop bodies; otherwise only tag loops as Unrolled.
  bool explicit_unroll_;
  // Number of For loops currently being visited around the current node.
  int loop_depth_{0};
};

/*!
 * \brief Run LoopUnroller over a statement.
 * \param stmt The statement to transform.
 * \param auto_max_step Largest constant extent that is unrolled automatically.
 * \param auto_min_depth Minimum loop nesting depth required for automatic unrolling.
 * \param explicit_unroll Whether to expand loop bodies instead of only tagging loops.
 * \return The transformed statement; passed through ConvertSSA when the
 *  mutator changed anything (presumably because duplicated bodies can
 *  repeat variable definitions).
 */
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_min_depth,
                bool explicit_unroll) {
  Stmt ret = LoopUnroller(
      auto_max_step,
      auto_min_depth,
      explicit_unroll).Mutate(stmt);
  if (!ret.same_as(stmt)) {
    return ConvertSSA(ret);
  } else {
    return ret;
  }
}

}  // namespace ir
}  // namespace tvm