/*!
 *  Copyright (c) 2017 by Contributors
 *  Vectorize the loop
 * \file vectorize_loop.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <algorithm>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

/*!
 * \brief Widen an expression to the requested number of lanes.
 * \param e The expression; it must either already have `lanes` lanes or be scalar.
 * \param lanes The target lane count.
 * \return `e` unchanged if it already has `lanes` lanes, otherwise Broadcast(e, lanes).
 *         Fails a CHECK if `e` is a vector of a different width.
 */
inline Expr BroadcastTo(Expr e, int lanes) {
  if (e.type().lanes() == lanes) return e;
  CHECK_EQ(e.type().lanes(), 1)
      << "Cannot broadcast lane=" << e.type().lanes()
      << " to " << lanes;
  return Broadcast::make(e, lanes);
}

// Rewrite vectorized allocation access
// s[i] = s[i * lanes + var]
//
// Every Load/Store that touches `buf_` gets its index rewritten to
// `index * var_lanes_ + var_`, so that each vector lane addresses its own
// slot in the lane-expanded allocation.
class VecAllocAccess : public IRMutator {
 public:
  VecAllocAccess(const Variable* buf, Var var, int var_lanes)
      : buf_(buf), var_(var), var_lanes_(var_lanes) {}
  // Load
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    if (op->buffer_var.get() == buf_) {
      return Load::make(op->type, op->buffer_var,
                        op->index * var_lanes_ + var_,
                        op->predicate);
    } else {
      return expr;
    }
  }
  // Store
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    if (op->buffer_var.get() == buf_) {
      return Store::make(op->buffer_var,
                         op->value,
                         op->index * var_lanes_ + var_,
                         op->predicate);
    } else {
      return stmt;
    }
  }

 private:
  // buffer var
  const Variable* buf_;
  // variable to be replaced
  Var var_;
  // the lanes.
  int var_lanes_;
};

/*!
 * \brief Rewrites the body of a vectorized loop into vector IR.
 *
 * The loop variable `var_` is replaced by Ramp(0, 1, var_lanes_); operands of
 * mixed width are broadcast to a common lane count. Constructs that cannot be
 * vectorized fall back to Scalarize(), which wraps the statement in a serial
 * loop over a fresh copy of the loop variable.
 */
class Vectorizer : public IRMutator {
 public:
  /*!
   * \param var The loop variable being vectorized.
   * \param var_lanes The number of vector lanes.
   */
  Vectorizer(Var var, int var_lanes)
      : var_(var), var_lanes_(var_lanes) {
    ramp_ = Ramp::make(0, 1, var_lanes);
  }
  // user mutate from parent.
  using IRMutator::Mutate;
  // Add/Sub may fold a scalar into an existing Ramp (see AddSubVec).
  Expr Mutate_(const Add* op, const Expr &e) final {
    return AddSubVec(op, e);
  }
  Expr Mutate_(const Sub* op, const Expr &e) final {
    return AddSubVec(op, e);
  }
  // Remaining binary operators: broadcast both operands to a common width.
  Expr Mutate_(const Mul* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const Div* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const Mod* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const Min* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const Max* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const EQ* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const NE* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const LT* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  // LE must be handled like the other comparisons; otherwise a mixed
  // vector/scalar `<=` would be rebuilt without broadcasting its operands.
  Expr Mutate_(const LE* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const GT* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const GE* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const And* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  Expr Mutate_(const Or* op, const Expr &e) final {
    return BinaryVec(op, e);
  }
  // Select: true/false values are broadcast to the widest of the three
  // operands; the condition is passed through as-is.
  Expr Mutate_(const Select *op, const Expr& e) final {
    Expr cond = this->Mutate(op->condition);
    Expr t = this->Mutate(op->true_value);
    Expr f = this->Mutate(op->false_value);
    if (cond.same_as(op->condition) &&
        t.same_as(op->true_value) &&
        f.same_as(op->false_value)) {
      return e;
    } else {
      int lanes = std::max(std::max(
          cond.type().lanes(),
          t.type().lanes()), f.type().lanes());
      return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
    }
  }
  // Cast: keep the target element type, adopt the operand's lane count.
  Expr Mutate_(const Cast *op, const Expr& e) final {
    Expr value = this->Mutate(op->value);
    if (value.same_as(op->value)) {
      return e;
    } else {
      return Cast::make(op->type.with_lanes(value.type().lanes()), value);
    }
  }
  // Variable: the loop var becomes the ramp; Let-bound vars that were
  // re-typed to vectors are redirected to their replacement.
  Expr Mutate_(const Variable* v, const Expr& e) final {
    if (v == var_.get()) {
      return ramp_;
    }
    auto it = lets_.find(v);
    if (it != lets_.end()) {
      return it->second;
    }
    return e;
  }
  // Call: arguments are broadcast to a common width, which also becomes the
  // lane count of the call's result type.
  Expr Mutate_(const Call* op, const Expr& e) final {
    int lane = 0;
    Array<Expr> new_args = MutateArray(op->args, &lane);

    if (op->args.same_as(new_args)) {
      return e;
    } else {
      return Call::make(
          op->type.with_lanes(lane), op->name, new_args,
          op->call_type, op->func, op->value_index);
    }
  }
  // Load
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr index = this->Mutate(op->index);
    Expr pred = this->Mutate(op->predicate);
    if (index.same_as(op->index) && pred.same_as(op->predicate)) {
      return e;
    } else {
      int lanes = std::max(index.type().lanes(), pred.type().lanes());
      return Load::make(
          op->type.with_lanes(lanes),
          op->buffer_var,
          BroadcastTo(index, lanes),
          BroadcastTo(pred, lanes));
    }
  }
  // Let: if the bound value changes width, bind a fresh variable of the new
  // type and record the substitution so uses in the body are redirected.
  Expr Mutate_(const Let* op, const Expr& e) final {
    Expr value = this->Mutate(op->value);
    CHECK(!lets_.count(op->var.get())) << "not SSA";
    if (value.type().lanes() != op->value.type().lanes()) {
      Var v(op->var->name_hint, value.type());
      lets_[op->var.get()] = v;
      return Let::make(v, value, Mutate(op->body));
    } else {
      Expr body = this->Mutate(op->body);
      if (value.same_as(op->value) &&
          body.same_as(op->body)) {
        return e;
      } else {
        return Let::make(op->var, value, body);
      }
    }
  }
  // Provide
  Stmt Mutate_(const Provide* op, const Stmt& s) final {
    Expr new_value = this->Mutate(op->value);
    int lane = new_value.type().lanes();
    Array<Expr> new_args = MutateArray(op->args, &lane);

    if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
      return s;
    } else {
      new_value = BroadcastTo(new_value, lane);
      return Provide::make(op->func, op->value_index, new_value, new_args);
    }
  }
  // Store
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Expr value = this->Mutate(op->value);
    Expr index = this->Mutate(op->index);
    Expr pred = this->Mutate(op->predicate);
    // The predicate must be part of the "unchanged" test (as in Load):
    // otherwise a predicate that depends on the loop var would be left
    // referring to a variable that no longer exists after vectorization.
    if (value.same_as(op->value) &&
        index.same_as(op->index) &&
        pred.same_as(op->predicate)) {
      return s;
    } else {
      int lanes = std::max(value.type().lanes(), index.type().lanes());
      lanes = std::max(lanes, pred.type().lanes());
      return Store::make(op->buffer_var,
                         BroadcastTo(value, lanes),
                         BroadcastTo(index, lanes),
                         BroadcastTo(pred, lanes));
    }
  }
  // For: inner loops stay serial; a vector extent forces scalarization.
  Stmt Mutate_(const For* op, const Stmt& s) final {
    if (op->for_type == ForType::Vectorized) {
      LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
    }
    CHECK(is_zero(op->min));
    CHECK(!op->extent.type().is_vector());
    Expr extent = Mutate(op->extent);
    if (extent.type().is_vector()) {
      LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
      return Scalarize(s);
    }
    Stmt body = Mutate(op->body);
    if (extent.same_as(op->extent) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return For::make(
          op->loop_var, op->min, extent,
          op->for_type, op->device_api, body);
    }
  }
  // IfThenElse: a condition that becomes a vector forces scalarization.
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
    CHECK(!op->condition.type().is_vector());
    Expr condition = this->Mutate(op->condition);
    if (condition.type().is_vector()) {
      LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
      return Scalarize(s);
    }
    Stmt then_case = this->Mutate(op->then_case);
    Stmt else_case;
    if (op->else_case.defined()) {
      else_case = this->Mutate(op->else_case);
    }
    if (condition.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return s;
    } else {
      return IfThenElse::make(condition, then_case, else_case);
    }
  }
  // LetStmt
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize";
    return Scalarize(s);
  }
  // Allocate: the allocation gains a trailing dimension of size var_lanes_
  // and accesses to it are rewritten by VecAllocAccess before vectorizing.
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    if (op->new_expr.defined()) {
      LOG(WARNING) << "Cannot vectorize with new expr";
      return Scalarize(s);
    }
    Expr condition = Mutate(op->condition);
    if (condition.type().is_vector()) {
      LOG(WARNING) << "Cannot handle vector condition in alloc ";
      return Scalarize(s);
    }
    Array<Expr> extents;
    for (size_t i = 0; i < op->extents.size(); i++) {
      Expr new_ext = Mutate(op->extents[i]);
      if (new_ext.type().is_vector()) {
        LOG(WARNING) << "Cannot handle vector extent in alloc ";
        return Scalarize(s);
      }
      extents.push_back(new_ext);
    }
    // place the vector lanes in least significant dimension.
    extents.push_back(var_lanes_);
    // rewrite access to buffer internally.
    Stmt body = VecAllocAccess(
        op->buffer_var.get(), var_, var_lanes_).Mutate(op->body);
    body = Mutate(body);
    return Allocate::make(
        op->buffer_var, op->type,
        extents, condition, body,
        op->new_expr, op->free_function);
  }
  // scalarize the statement: run it in a serial loop over a fresh copy of
  // the loop variable, named "<var>.s", with extent var_lanes_.
  Stmt Scalarize(Stmt stmt) {
    Var idx(var_->name_hint + ".s", var_->type);
    stmt = Substitute(stmt, {{var_, idx}});
    return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
  }

 private:
  // variable to be replaced
  Var var_;
  // the lanes.
  int var_lanes_;
  // ramp representing the var.
  Expr ramp_;
  // The lets
  std::unordered_map<const Variable*, Expr> lets_;
  // mutate array, with given lane requirement
  // when finished, p_lane updates the lane requirement.
  // Every element is broadcast to the final lane count; the original array
  // is returned (identity preserved) when nothing changed.
  Array<Expr> MutateArray(Array<Expr> arr, int* p_lanes) {
    if (arr.size() == 0) return arr;
    int& lanes = *p_lanes;
    bool changed = false;
    std::vector<Expr> new_arr(arr.size());
    for (size_t i = 0; i < arr.size(); i++) {
      Expr old_elem = arr[i];
      Expr new_elem = this->Mutate(old_elem);
      if (!new_elem.same_as(old_elem)) changed = true;
      new_arr[i] = new_elem;
      lanes = std::max(lanes, new_elem.type().lanes());
    }

    for (size_t i = 0; i < arr.size(); ++i) {
      if (new_arr[i].type().lanes() != lanes) {
        new_arr[i] = BroadcastTo(new_arr[i], lanes);
        changed = true;
      }
    }
    if (!changed) return arr;
    return Array<Expr>(new_arr);
  }
  // Generic binary op: mutate both sides and broadcast to the wider width.
  template<typename T>
  Expr BinaryVec(const T* op, const Expr& e) {
    Expr a = this->Mutate(op->a);
    Expr b = this->Mutate(op->b);
    if (a.same_as(op->a) &&
        b.same_as(op->b)) {
      return e;
    } else {
      int lanes = std::max(a.type().lanes(), b.type().lanes());
      return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
    }
  }
  // Add/Sub: when one side is scalar and the other a Ramp, fold the scalar
  // into the ramp (scalar op ramp negates/keeps the stride as T requires)
  // instead of broadcasting.
  template<typename T>
  Expr AddSubVec(const T* op, const Expr& e) {
    Expr a = this->Mutate(op->a);
    Expr b = this->Mutate(op->b);
    if (a.same_as(op->a) &&
        b.same_as(op->b)) {
      return e;
    } else {
      int lanes = std::max(a.type().lanes(), b.type().lanes());
      if (lanes != 1) {
        const Ramp* b_ramp = b.as<Ramp>();
        const Ramp* a_ramp = a.as<Ramp>();
        if (a.type().lanes() == 1 && b_ramp) {
          return Ramp::make(
              arith::ComputeExpr<T>(a, b_ramp->base),
              arith::ComputeExpr<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
              b_ramp->lanes);
        }
        if (b.type().lanes() == 1 && a_ramp) {
          return Ramp::make(
              arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
        }
      }
      return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
    }
  }
};

/*!
 * \brief Finds loops marked ForType::Vectorized and replaces each one with
 *        its body rewritten by Vectorizer. The extent must be a positive
 *        integer constant and the min must be zero.
 */
class LoopVectorizer : public IRMutator {
 public:
  Stmt Mutate_(const For* op, const Stmt& s) final {
    if (op->for_type == ForType::Vectorized) {
      CHECK(is_zero(op->min));
      CHECK(is_positive_const(op->extent));
      int lanes = 0;
      bool succ = arith::GetConstInt(op->extent, &lanes);
      if (!succ || lanes < 1) {
        LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
      }
      Var var(op->loop_var.node_);
      return Vectorizer(var, lanes).Mutate(op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
};

/*!
 * \brief Vectorize every loop in `stmt` whose for_type is Vectorized.
 * \param stmt The statement to transform.
 * \return The transformed statement.
 */
Stmt VectorizeLoop(Stmt stmt) {
  return LoopVectorizer().Mutate(stmt);
}

}  // namespace ir
}  // namespace tvm