/*! * Copyright (c) 2016 by Contributors * \file buffer.cc */ #include <tvm/buffer.h> #include <tvm/runtime/device_api.h> #include <tvm/ir.h> #include <tvm/ir_pass.h> #include <iterator> #include "../arithmetic/compute_expr.h" namespace tvm { Array<Expr> SimplifyArray(Array<Expr> array) { for (size_t i = 0; i < array.size(); ++i) { array.Set(i, ir::Simplify(array[i])); } return array; } Buffer decl_buffer(Array<Expr> shape, Type dtype, std::string name) { return BufferNode::make( Var(name, Handle()), dtype, shape, Array<Expr>(), Expr(), name, "", 0, 0); } // Split the given expression w.r.t the add operator inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) { using namespace ir; std::vector<const Expr*> ret; std::stack<const Expr*> split_buffer; split_buffer.push(&expr); while (!split_buffer.empty()) { const Expr* top_ele = split_buffer.top(); split_buffer.pop(); auto expr_add_match = top_ele->as<Add>(); if (expr_add_match) { split_buffer.push(&expr_add_match->b); split_buffer.push(&expr_add_match->a); } else { ret.emplace_back(top_ele); } } return ret; } // Searches for the following types of expr: // mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki // mod_l_expr = c // mod_r_expr = k1 * k2 * ... * ki // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr, const Expr &mod_l_expr, const Expr &mod_r_expr) { using namespace ir; const Mul* mult_ptr = mult_expr.as<Mul>(); if (!mult_ptr) return std::make_pair(false, Expr()); Expr mult_outer = mult_ptr->b; const Expr* inner = &(mult_ptr->a); // 1. Calculate the outer multiplier while (true) { mult_ptr = inner->as<Mul>(); if (mult_ptr) { inner = &(mult_ptr->a); mult_outer = mult_ptr->b * mult_outer; } else { break; } } // 2. Search for the pattern c / (...) * (...) + c % (...) // We match the search element with Add, Mul and Div. // If Add is found, we need to continue our search for the rhs // If Mult is found, we will expand the inner multiplication factor // If Div is found, we will go on testing whether lhs matches the lhs of mod expr // and returns the optimization result. const Expr* search_ptr = inner; Expr mult_inner; // The inner multiplication factor Expr no_opt_sum; // Sum of the exprs that cannot be optimized while (true) { auto inner_div_ptr = search_ptr->as<Div>(); auto inner_mult_ptr = search_ptr->as<Mul>(); auto inner_add_ptr = search_ptr->as<Add>(); if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) { return std::make_pair(false, Expr()); } else if (inner_div_ptr) { Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; if (Equal(overall_mult, inner_div_ptr->b) && Equal(overall_mult, mod_r_expr) && Equal(inner_div_ptr->a, mod_l_expr)) { // Found! Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; return std::make_pair(true, ret); } else { return std::make_pair(false, Expr()); } } else if (inner_mult_ptr) { mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b; search_ptr = &(inner_mult_ptr->a); } else if (inner_add_ptr) { if (mult_inner.get()) { return std::make_pair(false, Expr()); } no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a; search_ptr = &(inner_add_ptr->b); } else { LOG(FATAL) << "Unexpected search result!"; break; } } return std::make_pair(false, Expr()); } // Insert the elements into the corresponding mult_exprs and mod_exprs. // If the element is found to match Mul, it will be pushed to the mult_exprs. // If the element it found to match Mod, it will be pused to the mod_exprs. // Otherwise, the elements will be added to the no_opt_sum variable inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles, std::list<Expr>* mult_exprs, std::list<std::pair<Expr, Expr> >* mod_exprs, Expr* no_opt_sum, bool* has_mult, bool* has_mod) { using namespace ir; *has_mult = false; *has_mod = false; for (const Expr* ele : eles) { auto mod_ptr = ele->as<Mod>(); auto mult_ptr = ele->as<Mul>(); if (mod_ptr) { *has_mod = true; mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b))); } else if (mult_ptr) { *has_mult = true; mult_exprs->emplace_back(*ele); } else { *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele; } } } // Searches for this types of expr: // (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki // + c % (k1 * k2 * ... * ki) // and simplifies to (a1 + a2 + ... + aj) * kt * ... * ki + c // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized inline Expr MergeMulMod(const Expr &base) { using namespace ir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and // a list that contain all the elements that match Mod. // The elements in the Mod will be used to match against the elements in Mul. // The result will then be split and pushed back to these two lists. Expr simplified_base = Simplify(base); std::vector<const Expr*> eles = ExprSplitAddition(simplified_base); std::list<Expr> mult_exprs; std::list<std::pair<Expr, Expr> > mod_exprs; Expr no_opt_sum; bool has_mult; bool has_mod; MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); bool find_opt = false; std::list<std::pair<Expr, Expr> >::iterator search_mod_it = mod_exprs.begin(); // 2. Exhaustive Search while (search_mod_it != mod_exprs.end()) { std::list<Expr>::iterator mult_it = mult_exprs.begin(); bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { std::pair<bool, Expr> ret = MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { inner_find_opt = true; auto temp_mod_it = search_mod_it; ++search_mod_it; mod_exprs.erase(temp_mod_it); mult_exprs.erase(mult_it); std::vector<const Expr*> ret_eles = ExprSplitAddition(ret.second); MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); if (has_mult) { search_mod_it = mod_exprs.begin(); } else if (has_mod && search_mod_it == mod_exprs.end()) { search_mod_it--; } break; } else { ++mult_it; } } find_opt = find_opt || inner_find_opt; if (!inner_find_opt) { ++search_mod_it; } } if (!find_opt) { return simplified_base; } for (std::list<Expr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) { no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it; } for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin(); it != mod_exprs.end(); ++it) { no_opt_sum = no_opt_sum.get() ? no_opt_sum + it->first % it->second : it->first % it->second; } return no_opt_sum; } // The buffer offset in convention of number of elements of // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) { Expr base = n->elem_offset; if (n->strides.size() == 0) { CHECK_EQ(n->shape.size(), index.size()); if (n->shape.size() != 0) { if (is_zero(base)) { base = index[0]; } else { base = base + index[0]; } } base = MergeMulMod(base); for (size_t i = 1; i < index.size(); ++i) { base = MergeMulMod(base * n->shape[i] + index[i]); } } else { CHECK_EQ(n->strides.size(), index.size()); if (is_zero(base)) { base = MergeMulMod(index[0] * n->strides[0]); } else { base = MergeMulMod(base + index[0] * n->strides[0]); } for (size_t i = 1; i < index.size(); ++i) { base = MergeMulMod(base + index[i] * n->strides[i]); } } return base; } inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) { Expr offset = ElemOffset(n, index); if (n->dtype.lanes() != 1) { offset = offset * make_const(offset.type(), dtype.lanes()); } if (dtype.lanes() != 1) { return ir::Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes()); } else { return offset; } } Expr Buffer::vload(Array<Expr> begin, Type dtype) const { const BufferNode* n = operator->(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; return ir::Load::make( dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } Stmt Buffer::vstore(Array<Expr> begin, Expr value) const { const BufferNode* n = operator->(); Type dtype = value.type(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; std::vector<Expr> temp; auto n = std::make_shared<BufferNode>(*operator->()); Expr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0 ; --i) { temp.push_back(acc); acc = acc * n->shape[i - 1]; } for (size_t i = temp.size(); i != 0; --i) { n->strides.push_back(temp[i - 1]); } return Buffer(n); } Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { const BufferNode* n = operator->(); begins = SimplifyArray(begins); Expr elem_offset = ir::Simplify(ElemOffset(n, begins)); Array<Expr> strides = n->strides; if (strides.size() == 0) { bool can_relax = true; bool need_stride = false; // check if stride is needed. for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { if (!is_zero(begins[i]) || !is_zero(ir::Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } if (!is_one(extents[i])) can_relax = false; } // make stride. if (need_stride) { return MakeStrideView().MakeSlice(begins, extents); } } return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, n->data_alignment, 0); } Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const { const BufferNode* self = operator->(); Expr e_dtype; Expr extent; if (self->shape.size() == 0) { extent = make_const(self->DefaultIndexType(), 1); } else if (self->strides.size() == self->shape.size()) { int highest_dim = 0; extent = arith::ComputeExpr<ir::Mul>( self->strides[highest_dim], self->shape[highest_dim]); } else { extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()); } Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { e_dtype = make_zero(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.type(), content_lanes); elem_offset = self->elem_offset / make_const(self->elem_offset.type(), content_lanes); } else { e_dtype = make_zero(self->dtype); } Array<Expr> acc_args{ e_dtype, self->data, elem_offset, extent, make_const(Int(32), access_mask)}; return ir::Call::make( ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic); } Buffer BufferNode::make(Var data, Type dtype, Array<Expr> shape, Array<Expr> strides, Expr elem_offset, std::string name, std::string scope, int data_alignment, int offset_factor) { auto n = std::make_shared<BufferNode>(); n->data = std::move(data); n->dtype = dtype; n->shape = std::move(shape); n->strides = std::move(strides); n->name = std::move(name); if (scope.length() == 0) { scope = "global"; } n->scope = std::move(scope); if (!elem_offset.defined()) { elem_offset = make_const(n->DefaultIndexType(), 0); } if (data_alignment <= 0) { data_alignment = runtime::kAllocAlignment; } if (offset_factor == 0) { offset_factor = 1; } n->elem_offset = std::move(elem_offset); n->data_alignment = data_alignment; n->offset_factor = offset_factor; return Buffer(n); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) { p->stream << "buffer(" << op->name << ", " << op << ")"; }); TVM_REGISTER_NODE_TYPE(BufferNode); } // namespace tvm