/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file bounds_checker.cc
 */
// Instrument checkers for out of the bounds access.

#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <vector>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace tir {

class BoundCollector : public StmtVisitor {
39 40 41
 public:
  BoundCollector() {}

42
  void VisitStmt_(const AttrStmtNode* op) final {
43
    if (op->attr_key == tir::attr::buffer_bound) {
44
      if (const VarNode *key = op->node.as<VarNode>()) {
45 46 47
        mem_to_shape[key] = op->value;
      }
    }
48
    StmtVisitor::VisitStmt_(op);
49 50
  }
  // Hashtable which maps buffer_var to shape.
51
  std::unordered_map<const VarNode *, PrimExpr> mem_to_shape;
52 53
};

54
class BoundChecker : public StmtExprMutator {
55 56
 public:
  explicit BoundChecker(
57
      const std::unordered_map<const VarNode *, PrimExpr> &mem_to_shape)
58 59
      : mem_to_shape_(mem_to_shape) {}

60
  Stmt VisitStmt_(const AllocateNode* op) final {
61 62
    // If the shape was updated we should update the hashtable.
    if (UpdateIsNeeded(op->buffer_var)) {
63
      Update(op->buffer_var, op->extents, op->dtype);
64
    }
65
    return StmtExprMutator::VisitStmt_(op);
66 67
  }

68
  PrimExpr VisitExpr_(const CallNode* op) final {
69 70 71
    if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
      unsafe_rewritten_ = true;
    }
72
    return StmtExprMutator::VisitExpr_(op);
73 74
  }

75
  Stmt VisitStmt_(const StoreNode* op) final {
76 77 78
    store_scope_bound_collector_.clear();
    process_store_ = true;
    unsafe_rewritten_ = false;
79
    StmtExprMutator::VisitStmt_(op);
80 81 82 83 84 85
    process_store_ = false;
    if (CanInstrument(op->index, op->buffer_var)) {
      Collect(op->index, op->buffer_var);
    }
    // The collector should has at least one item.
    if (store_scope_bound_collector_.size()) {
86
      PrimExpr condition = MakeCondition();
87 88
      if (!condition.as<StringImmNode>()) {
        Stmt nop = EvaluateNode::make(1);
89
        Stmt then_case =
90
            StoreNode::make(op->buffer_var, op->value, op->index, op->predicate);
91
        Stmt else_case =
92 93
            AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop);
        Stmt body = IfThenElseNode::make(condition, then_case, else_case);
94 95 96
        return body;
      }
    }
97
    return GetRef<Stmt>(op);
98 99
  }

100
  PrimExpr VisitExpr_(const LoadNode* op) final {
101 102 103
    if (CanInstrument(op->index, op->buffer_var)) {
      Collect(op->index, op->buffer_var);
    }
104
    return StmtExprMutator::VisitExpr_(op);
105 106 107
  }

 private:
108
  bool UpdateIsNeeded(const Var& buffer_var) const {
109 110 111
    return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
  }

112 113
  void Update(const Var& buffer_var,
              const Array<PrimExpr>& new_shape,
114
              const DataType& type) {
115 116 117 118 119 120
    // Sanity check at first.
    if (!new_shape.size()) {
      return;
    }

    for (size_t i = 0; i < new_shape.size(); ++i) {
121
      if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() ||
122 123 124 125 126 127
          is_negative_const(new_shape[i])) {
        return;
      }
    }

    // Scalarize the shape.
128
    PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
129
                           CastNode::make(DataType::UInt(64), new_shape[0]));
130 131
    for (size_t i = 1; i < new_shape.size(); ++i) {
      // Cast to unsigned to avoid integer overlow at frist.
132 133
      shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()),
                                         CastNode::make(DataType::UInt(64), new_shape[i])));
134 135 136 137
    }
    mem_to_shape_[buffer_var.get()] = shape;
  }

138
  bool IndexIsValid(const PrimExpr& index) const {
139 140 141 142
    if (!index.defined()) {
      return false;
    }

143
    if (const RampNode *ramp_index = index.as<RampNode>()) {
144
      return ramp_index->base.defined() &&
145
             ramp_index->base.dtype().is_scalar() &&
146
             ramp_index->stride.defined() &&
147
             ramp_index->stride.dtype().is_scalar() && (ramp_index->lanes > 0);
148 149 150 151
    }
    return true;
  }

152
  bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const {
153 154 155 156
    return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
           IndexIsValid(index) && !unsafe_rewritten_;
  }

157
  void Collect(PrimExpr index, Var buffer_var) {
158 159 160 161
    store_scope_bound_collector_.push_back(
        std::make_pair(index, mem_to_shape_[buffer_var.get()]));
  }

162 163
  PrimExpr MakeCondition() {
    PrimExpr condition;
164
    for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
165 166 167
      std::pair<PrimExpr, PrimExpr> buffer_to_mem = store_scope_bound_collector_[i];
      PrimExpr index = buffer_to_mem.first;
      PrimExpr upper_bound = buffer_to_mem.second;
168

169
      if (const RampNode *ramp_index = index.as<RampNode>()) {
170 171
        // In case index is base + stride * i.
        // Non inclusive range.
172
        index = AddNode::make(
173
            ramp_index->base,
174
            MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(),
175 176 177 178
                                                     ramp_index->lanes - 1)));
      }

      // Try to simplify index and bound.
179 180
      index = analyzer_.Simplify(index);
      upper_bound = analyzer_.Simplify(upper_bound);
181 182

      // Cast to the same type - signed, to be able to check lower bound.
183 184
      index = CastNode::make(DataType::Int(64), index);
      upper_bound = CastNode::make(DataType::Int(64), upper_bound);
185 186

      // Looks like a lower bound should always be zero after normalization.
187
      PrimExpr lower_bound = make_zero(DataType::Int(64));
188

189
      PrimExpr current_condition =
190
          AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound));
191
      condition =
192
          !i ? current_condition : AndNode::make(condition, current_condition);
193 194 195 196 197 198 199 200 201
    }
    return condition;
  }

  // Whether we process store value recursively.
  bool process_store_{false};
  // Whether we face tvm_if_then_else intrinsic.
  bool unsafe_rewritten_{false};
  // Pool which collects the pair of index and shape for specific store/load.
202
  std::vector<std::pair<PrimExpr, PrimExpr>> store_scope_bound_collector_;
203 204 205
  // Error message.
  const char *const error_message_ = "OUT OF THE BOUNDS";
  // Hashtable which maps buffer_var to shape.
206
  std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
207 208
  // internal analyzer
  arith::Analyzer analyzer_;
209 210 211 212 213
};

/*!
 * \brief Instrument bound checkers into a statement.
 * \param stmt The statement to instrument.
 * \return The statement with load/store bound checks inserted.
 */
Stmt InstrumentBoundCheckers(Stmt stmt) {
  BoundCollector bound_collector;
  // At first walk recursively and collect bound attributes.
  bound_collector(stmt);
  return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
}

namespace transform {

/*!
 * \brief Create the pass that instruments out-of-bounds checks.
 * \return The pass, registered as "tir.InstrumentBoundCheckers".
 */
Pass InstrumentBoundCheckers() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* fptr = f.CopyOnWrite();
    // Phase one: gather the buffer_bound attributes from the body.
    BoundCollector collector;
    collector(fptr->body);
    // Phase two: rewrite the body with the checks inserted.
    fptr->body = BoundChecker(collector.mem_to_shape)(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
}

TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
.set_body_typed(InstrumentBoundCheckers);

}  // namespace transform

}  // namespace tir
}  // namespace tvm