/*! * Copyright (c) 2018 by Contributors * \file bounds_checker.cc */ // Instrument checkers for out of the bounds access. #include <tvm/ir.h> #include <tvm/ir_mutator.h> #include <tvm/ir_pass.h> #include <tvm/ir_visitor.h> #include <vector> #include <unordered_map> #include <utility> namespace tvm { namespace ir { class BoundCollector : public IRVisitor { public: BoundCollector() {} void Visit_(const AttrStmt *op) { if (op->attr_key == ir::attr::buffer_bound) { if (const Variable *key = op->node.as<Variable>()) { mem_to_shape[key] = op->value; } } IRVisitor::Visit_(op); } // Hashtable which maps buffer_var to shape. std::unordered_map<const Variable *, Expr> mem_to_shape; }; class BoundChecker : public IRMutator { public: explicit BoundChecker( const std::unordered_map<const Variable *, Expr> &mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt Mutate_(const Allocate *op, const Stmt &s) final { // If the shape was updated we should update the hashtable. if (UpdateIsNeeded(op->buffer_var)) { Update(op->buffer_var, op->extents, op->type); } return IRMutator::Mutate_(op, s); } Expr Mutate_(const Call *op, const Expr &ex) final { if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { unsafe_rewritten_ = true; } return IRMutator::Mutate_(op, ex); } Stmt Mutate_(const Store *op, const Stmt &s) final { store_scope_bound_collector_.clear(); process_store_ = true; unsafe_rewritten_ = false; IRMutator::Mutate_(op, s); process_store_ = false; if (CanInstrument(op->index, op->buffer_var)) { Collect(op->index, op->buffer_var); } // The collector should has at least one item. if (store_scope_bound_collector_.size()) { Expr condition = MakeCondition(); if (!condition.as<StringImm>()) { Stmt nop = Evaluate::make(1); Stmt then_case = Store::make(op->buffer_var, op->value, op->index, op->predicate); Stmt else_case = AssertStmt::make(condition, StringImm::make(error_message_), nop); Stmt body = IfThenElse::make(condition, then_case, else_case); return body; } } return s; } Expr Mutate_(const Load *op, const Expr &ex) final { if (CanInstrument(op->index, op->buffer_var)) { Collect(op->index, op->buffer_var); } return IRMutator::Mutate_(op, ex); } private: bool UpdateIsNeeded(const VarExpr &buffer_var) const { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape, const Type &type) { // Sanity check at first. if (!new_shape.size()) { return; } for (size_t i = 0; i < new_shape.size(); ++i) { if (!new_shape[0].defined() || !new_shape[i].type().is_scalar() || is_negative_const(new_shape[i])) { return; } } // Scalarize the shape. Expr shape = Mul::make(make_const(UInt(64), type.lanes()), Cast::make(UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overlow at frist. shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()), Cast::make(UInt(64), new_shape[i]))); } mem_to_shape_[buffer_var.get()] = shape; } bool IndexIsValid(const Expr &index) const { if (!index.defined()) { return false; } if (const Ramp *ramp_index = index.as<Ramp>()) { return ramp_index->base.defined() && ramp_index->base.type().is_scalar() && ramp_index->stride.defined() && ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0); } return true; } bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const { return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && !unsafe_rewritten_; } void Collect(Expr index, VarExpr buffer_var) { store_scope_bound_collector_.push_back( std::make_pair(index, mem_to_shape_[buffer_var.get()])); } Expr MakeCondition() { Expr condition; for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) { std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i]; Expr index = buffer_to_mem.first; Expr upper_bound = buffer_to_mem.second; if (const Ramp *ramp_index = index.as<Ramp>()) { // In case index is base + stride * i. // Non inclusive range. index = Add::make( ramp_index->base, Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(), ramp_index->lanes - 1))); } // Try to simplify index and bound. index = ir::Simplify(index); upper_bound = ir::Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. index = Cast::make(Int(64), index); upper_bound = Cast::make(Int(64), upper_bound); // Looks like a lower bound should always be zero after normalization. Expr lower_bound = make_zero(Int(64)); Expr current_condition = And::make(GE::make(index, lower_bound), LT::make(index, upper_bound)); condition = !i ? current_condition : And::make(condition, current_condition); } return condition; } // Whether we process store value recursively. bool process_store_{false}; // Whether we face tvm_if_then_else intrinsic. bool unsafe_rewritten_{false}; // Pool which collects the pair of index and shape for specific store/load. std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_; // Error message. const char *const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. std::unordered_map<const Variable *, Expr> mem_to_shape_; }; Stmt InstrumentBoundCheckers(Stmt stmt) { BoundCollector bound_collector; // At first walk recursively and collect bound attributes. bound_collector.Visit(stmt); return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt); } } // namespace ir } // namespace tvm