/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file bound_deducer.cc
 * \brief Utility to deduce bound of expression
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "interval_set.h"

namespace tvm {
namespace arith {

37
using namespace tir;
38 39 40

// a visitor to find the path to the target variable
// from a expression.
41
class VariablePathFinder: public ExprVisitor {
42
 public:
43
  explicit VariablePathFinder(PrimExpr target) : target_(target) {}
44

45
  void VisitExpr(const PrimExpr& node) final {
46 47 48 49 50
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());

    if (!found_) path_.push_back(node.get());
    if (node.same_as(target_)) found_ = true;
51
    ExprVisitor::VisitExpr(node);
52 53 54
    if (!found_) path_.pop_back();
  }

55
  std::vector<const Object*> path_;
56 57 58

 private:
  bool found_{false};
59
  PrimExpr target_;
60
  std::unordered_set<const Object*> visited_;
61 62 63 64
};

// get the path to the variable,
// return empty vector to represent failure
65
std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) {
66
  VariablePathFinder v(target);
67
  v(expr);
68 69 70
  return v.path_;
}

// Relation between the target and the deduced result:
//   kGreater: target >= result (the result is a lower bound)
//   kLess:    target <= result (the result is an upper bound)
//   kEqual:   target == result
enum CompareOp {kGreater, kLess, kEqual};

73
// a visitor to deduce the bound of a variable from a expression
74
class BoundDeducer: public ExprVisitor {
75 76 77
 public:
  friend class BoundDeduceInputChecker;
  friend class Converter;
78
  BoundDeducer(PrimExpr target, PrimExpr expr,
79 80
               const std::unordered_map<const VarNode*, IntSet>& hint_map,
               const std::unordered_map<const VarNode*, IntSet>& relax_map)
81
  : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
82 83 84

  void Deduce();

85
  void VisitExpr(const PrimExpr& e) final {
86
    if (!success_) return;
87
    if (iter_ < path_.size() && e.get() == path_[iter_++]) {
88
      ExprVisitor::VisitExpr(e);
89
    } else {
90
      success_ = false;
91 92 93 94
      return;
    }
  }

95
  void VisitExpr_(const LTNode* op) final {
96 97 98
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

99
  void VisitExpr_(const LENode* op) final {
100 101 102
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

103
  void VisitExpr_(const GTNode* op) final {
104 105 106
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

107
  void VisitExpr_(const GENode* op) final {
108 109 110
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

111
  void VisitExpr_(const AddNode* op) final {
112
    bool left = op->a.get() == path_[iter_];
113
    result_ -= left ? op->b : op->a;
114
    this->VisitExpr(left ? op->a : op->b);
115 116
  }

117
  void VisitExpr_(const SubNode* op) final {
118 119
    bool left = op->a.get() == path_[iter_];
    if (left) {
120
      result_ += op->b;
121
    } else {
122 123
      result_ -= op->a;
      result_ = - result_;
124
      comp_op = ReverseOp(comp_op);
125
    }
126
    this->VisitExpr(left ? op->a : op->b);
127 128
  }

129
  void VisitExpr_(const MulNode* op) final {
130
    bool left = op->a.get() == path_[iter_];
131 132
    PrimExpr operand = left ? op->b : op->a;
    PrimExpr target_var = left ? op->a : op->b;
133

134
    SignType sign_operand;
135
    if (operand.dtype().is_uint()) {
136
      sign_operand = kPositive;
137
    } else {
138
      sign_operand = expr_map_[operand].sign_type();
139 140
    }

141
    if (sign_operand == SignType::kNegative) {
142
      comp_op = ReverseOp(comp_op);
143
    } else if (sign_operand == SignType::kUnknown) {
144
      // unable to get the sign of operand
145
      success_ = false;
146 147
      return;
    }
148

149
    // always use relax bound
150
    bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);
151

152
    result_ = floordiv(result_, operand);   // rounding down here
153 154

    if (!divided) {
155
      if (comp_op == kGreater) {
156 157 158 159 160
        // System will round down in all the cases, so add one for result_ for kGreater
        // (x >= 3/2 --> x >= 2)
        // (x >= -3/2 --> x >= -1)
        // (x >= 3/-2 --> x >= -1)
        // (x >= -3/-2 --> x >= 2)
161
        result_ += 1;
162
      } else if (comp_op == kEqual) {
163
        // condition unsatisfiable as with floor div, it will change the expression
164 165
        success_ = false;
        return;
166
      } else {
167 168 169 170 171
        // System rounds down in all cases, do nothing for kLess.
        // ( x <= 3/2 --> x <= 1)
        // ( x <= -3/2 --> x <= -2)
        // ( x <= 3/-2 --> x <= -2)
        // ( x <= -3/-2 --> x <= 1)
172
      }
173
    }
174
    this->VisitExpr(left ? op->a : op->b);
175 176
  }

177
  PrimExpr result_;
178
  CompareOp comp_op{kGreater};
179
  bool success_{true};
180 181

 private:
182 183 184
  void Init();
  void Transform();
  void Relax();
185
  CompareOp ReverseOp(CompareOp comp_op);
186 187
  PrimExpr target_;
  PrimExpr expr_;
188 189
  const std::unordered_map<const VarNode*, IntSet>& hint_map_;
  const std::unordered_map<const VarNode*, IntSet>& relax_map_;
190
  ExprIntSetMap expr_map_;
191
  std::vector<const Object*> path_;
192
  size_t iter_{0};
193 194
  // internal analzyer
  Analyzer analyzer_;
195 196
};

197
class BoundDeduceInputChecker: public ExprVisitor {
198 199 200
 public:
  bool Check(BoundDeducer* deducer) {
    deducer_ = deducer;
201
    this->VisitExpr(deducer_->expr_);
202 203 204
    return target_count == 1;
  }

205
  void VisitExpr(const PrimExpr& e) final {
206
    if (e.same_as(deducer_->target_)) ++target_count;
207
    ExprVisitor::VisitExpr(e);
208 209 210 211 212 213 214
  }

 private:
  BoundDeducer* deducer_;
  size_t target_count{0};
};

215
void BoundDeducer::Init() {
216
  BoundDeduceInputChecker checker;
217
  if (!checker.Check(this)) success_ = false;
218 219
  Transform();
}
220

221 222 223 224 225 226 227 228 229 230 231
CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
  switch (comp_op) {
    case kEqual: return kEqual;   // IntSet can not represent range for `NE
    case kGreater: return kLess;
    case kLess: return kGreater;
    default:
      LOG(FATAL) << "Not a valid compare op";
      return kGreater;  // return some default value
  }
}

232
void BoundDeducer::Transform() {
233
  // We will ensure to set expr_ such that it contains target_
234
  if (const LTNode* op = expr_.as<LTNode>()) {
235 236
    if (GetPath(target_, op->a).empty()) {
      // a < b -> b >= a + 1
237
      comp_op = kGreater;
238
      expr_ = op->b;
239
      result_ = op->a + 1;
240 241
    } else {
      // a < b -> a <= b - 1
242
      comp_op = kLess;
243
      expr_ = op->a;
244
      result_ = op->b - 1;
245
    }
246
  } else if (const LENode* op = expr_.as<LENode>()) {
247 248
    if (GetPath(target_, op->a).empty()) {
      // a <= b -> b >= a
249
      comp_op = kGreater;
250
      expr_ = op->b;
251
      result_ = op->a;
252
    } else {
253
      comp_op = kLess;
254
      expr_ = op->a;
255
      result_ = op->b;
256
    }
257
  } else if (const GTNode* op = expr_.as<GTNode>()) {
258 259
    if (GetPath(target_, op->a).empty()) {
      // a > b -> b <= a - 1
260
      comp_op = kLess;
261
      expr_ = op->b;
262
      result_ = op->a - 1;
263 264
    } else {
      // a > b -> a >= b + 1
265
      comp_op = kGreater;
266
      expr_ = op->a;
267
      result_ = op->b + 1;
268
    }
269
  } else if (const GENode* op = expr_.as<GENode>()) {
270 271
    if (GetPath(target_, op->a).empty()) {
      // a >= b -> b <= a
272 273 274 275 276 277 278 279
      comp_op = kLess;
      expr_ = op->b;
      result_ = op->a;
    } else {
      comp_op = kGreater;
      expr_ = op->a;
      result_ = op->b;
    }
280
  } else if (const EQNode* op = expr_.as<EQNode>()) {
281 282 283
    comp_op = kEqual;
    if (GetPath(target_, op->a).empty()) {
      // if the b == a -> a == b
284
      expr_ = op->b;
285
      result_ = op->a;
286 287
    } else {
      expr_ = op->a;
288
      result_ = op->b;
289
    }
290
  } else {
291
    success_ = false;
292 293 294 295 296
  }
}

void BoundDeducer::Deduce() {
  Init();
297
  if (!success_) return;
298

299
  Relax();
300
  if (!success_) return;
301 302
  // get the path
  path_ = GetPath(target_, expr_);
303
  if (!path_.size()) {
304
    success_ = false;
305 306
    return;
  }
307
  expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
308

309
  this->VisitExpr(expr_);
310 311
}

312
void BoundDeducer::Relax() {
313
  IntSet a = EvalSet(expr_, relax_map_);
314
  IntSet b = EvalSet(result_, relax_map_);
315
  if (a.is_everything() || b.is_everything()) {
316
    success_ = false;
317
    return;
318
  }
319 320 321 322 323 324 325 326 327 328
  // Both LHS and RHS of the EQ should behave as constants e.g.  i == j,
  // can not be resolved when either `i` or `j`  or both are variables with
  // some Range OR `i` and `j` both should be a single point in IntSet
  if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max())
     || !analyzer_.CanProve(a.min() == a.max()))) {
    success_ = false;
    return;
  }
  expr_  = (comp_op == kGreater) ? a.min() : a.max();
  result_ = (comp_op == kGreater) ? b.max() : b.min();
329 330
}

331
IntSet DeduceBound(PrimExpr v, PrimExpr e,
332 333
  const std::unordered_map<const VarNode*, IntSet>& hint_map,
  const std::unordered_map<const VarNode*, IntSet>& relax_map) {
334
  BoundDeducer d(v, e, hint_map, relax_map);
335
  d.Deduce();
336
  if (!d.success_) return IntSet::nothing();
337
  PrimExpr min = neg_inf(), max = pos_inf();
338 339 340 341
  if (d.comp_op == kEqual) {
    min = d.result_;
    max = d.result_;
  } else if (d.comp_op == kGreater) {
342
    min = d.result_;
343
  } else {
344
    max = d.result_;
345 346 347 348
  }
  return IntSet::interval(min, max);
}

349 350
// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
351
IntSet DeduceBound(PrimExpr v, PrimExpr e,
352 353
                   const Map<Var, IntSet>& hint_map,
                   const Map<Var, IntSet>& relax_map) {
354
  std::unordered_map<const VarNode*, IntSet> hmap;
355 356 357
  for (auto kv : hint_map) {
    hmap[kv.first.get()] = kv.second;
  }
358
  std::unordered_map<const VarNode*, IntSet> rmap;
359 360 361 362 363 364
  for (auto kv : relax_map) {
    rmap[kv.first.get()] = kv.second;
  }
  return DeduceBound(v, e, hmap, rmap);
}

}  // namespace arith
}  // namespace tvm