bound_deducer.cc 8.13 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27
/*!
 *  Copyright (c) 2017 by Contributors
 * \file bound_deducer.cc
 * \brief Utility to deduce bound of expression
 */
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
28
#include <tvm/arithmetic.h>
29
#include <tvm/api_registry.h>
30

31 32 33 34 35 36 37
#include <unordered_set>
#include <unordered_map>

namespace tvm {
namespace arith {

using namespace ir;
38
using HalideIR::Internal::Interval;
39 40 41 42 43

// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
 public:
44
  explicit VariablePathFinder(Expr target) : target_(target) {}
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59

  void Visit(const NodeRef& node) final {
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());

    if (!found_) path_.push_back(node.get());
    if (node.same_as(target_)) found_ = true;
    IRVisitor::Visit(node);
    if (!found_) path_.pop_back();
  }

  std::vector<const Node*> path_;

 private:
  bool found_{false};
60
  Expr target_;
61 62 63 64 65
  std::unordered_set<const Node*> visited_;
};

// get the path to the variable,
// return empty vector to represent failure
66
std::vector<const Node*> GetPath(Expr target, Expr expr) {
67 68 69 70 71 72 73 74 75 76 77 78
  VariablePathFinder v(target);
  v.Visit(expr);
  return v.path_;
}

class BoundDeduceIntputChecker;

// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
 public:
  friend class BoundDeduceInputChecker;
  friend class Converter;
79 80 81 82
  BoundDeducer(Expr target, Expr expr,
               const std::unordered_map<const Variable*, IntSet>& hint_map,
               const std::unordered_map<const Variable*, IntSet>& relax_map)
  : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149

  void Deduce();

  void Visit(const NodeRef& e) final {
    if (!success) return;
    if (e.get() == path_[iter_++]) {
      IRVisitor::Visit(e);
    } else {
      success = false;
      return;
    }
  }

  void Visit_(const LT* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const LE* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const GT* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const GE* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const Add* op) final {
    bool left = op->a.get() == path_[iter_];
    result -= left ? op->b : op->a;
    Visit(left ? op->a : op->b);
  }

  void Visit_(const Sub* op) final {
    bool left = op->a.get() == path_[iter_];
    if (left) {
      result += op->b;
    } else {
      result -= op->a;
      result = - result;
      is_greater = !is_greater;
    }
    Visit(left ? op->a : op->b);
  }

  void Visit_(const Mul* op) final {
    bool left = op->a.get() == path_[iter_];
    Expr operand = left ? op->b : op->a;

    SignType sign;
    if (operand.type().is_uint()) {
      sign = kPositive;
    } else {
      sign = expr_map_[operand].sign_type();
    }

    if (sign == SignType::kNegative) {
      is_greater = !is_greater;
    } else if (sign == SignType::kUnknown) {
      // unable to get the sign of operand
      success = false;
      return;
    }

    // always use relax bound
150 151 152 153 154 155 156 157 158 159 160 161 162 163
    bool divided = can_prove(result % operand == 0);
    result = result / operand;
    // since system will round down when not divided
    // eg. 2/4 -> 0; -2/4 -> -1
    // no need fix for !is_greater:
    // eg. a <= 2/4 -> a <= 0
    // eg. a <= 0/4 -> a <= 0
    // so just fix for not divided and is_greater
    // eg. a >= 2/4 -> a >= 0 + 1
    // eg. a >= 0/4 -> a >= 0
    if (is_greater && !divided) {
       result += 1;
    }

164 165 166 167 168 169 170 171
    Visit(left ? op->a : op->b);
  }

  Expr result;
  bool is_greater{true};
  bool success{true};

 private:
172 173 174 175 176
  void Init();
  void Transform();
  void Relax();

  Expr target_;
177
  Expr expr_;
178 179
  const std::unordered_map<const Variable*, IntSet>& hint_map_;
  const std::unordered_map<const Variable*, IntSet>& relax_map_;
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
  ExprIntSetMap expr_map_;
  std::vector<const Node*> path_;
  size_t iter_{0};
};

// Counts how many visited nodes are identical (same_as) to the deducer's
// target_ while walking its expr_; Check() succeeds only for exactly one.
class BoundDeduceInputChecker: public IRVisitor {
 public:
  bool Check(BoundDeducer* deducer) {
    deducer_ = deducer;
    Visit(deducer_->expr_);
    return target_count == 1;
  }

  void Visit(const NodeRef& e) final {
    const bool is_target = e.same_as(deducer_->target_);
    if (is_target) {
      target_count += 1;
    }
    IRVisitor::Visit(e);
  }

 private:
  // Deducer under inspection; assigned by Check() before any visit.
  BoundDeducer* deducer_;
  // Occurrences of target_ seen so far.
  size_t target_count{0};
};

void BoundDeducer::Init() {
204 205
  BoundDeduceInputChecker checker;
  if (!checker.Check(this)) success = false;
206 207
  Transform();
}
208

209
void BoundDeducer::Transform() {
210
  // We will ensure to set expr_ such that it contains target_
211
  if (const LT* op = expr_.as<LT>()) {
212 213 214 215 216 217 218 219 220 221 222
    if (GetPath(target_, op->a).empty()) {
      // a < b -> b >= a + 1
      is_greater = true;
      expr_ = op->b;
      result = op->a + 1;
    } else {
      // a < b -> a <= b - 1
      is_greater = false;
      expr_ = op->a;
      result = op->b - 1;
    }
223
  } else if (const LE* op = expr_.as<LE>()) {
224 225 226 227 228 229 230 231 232 233
    if (GetPath(target_, op->a).empty()) {
      // a <= b -> b >= a
      is_greater = true;
      expr_ = op->b;
      result = op->a;
    } else {
      is_greater = false;
      expr_ = op->a;
      result = op->b;
    }
234
  } else if (const GT* op = expr_.as<GT>()) {
235 236 237 238 239 240 241 242 243 244 245
    if (GetPath(target_, op->a).empty()) {
      // a > b -> b <= a - 1
      is_greater = false;
      expr_ = op->b;
      result = op->a - 1;
    } else {
      // a > b -> a >= b + 1
      is_greater = true;
      expr_ = op->a;
      result = op->b + 1;
    }
246
  } else if (const GE* op = expr_.as<GE>()) {
247 248 249 250 251 252 253 254 255 256
    if (GetPath(target_, op->a).empty()) {
      // a >= b -> b <= a
      is_greater = false;
      expr_ = op->b;
      result = op->a;
    } else {
      is_greater = true;
      expr_ = op->a;
      result = op->b;
    }
257 258 259 260 261 262 263 264
  } else {
    success = false;
  }
}

void BoundDeducer::Deduce() {
  Init();
  if (!success) return;
265
  Relax();
266
  if (!success) return;
267 268
  // get the path
  path_ = GetPath(target_, expr_);
269 270 271 272
  if (!path_.size()) {
    success = false;
    return;
  }
273

274
  expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
275 276 277 278

  Visit(expr_);
}

279
void BoundDeducer::Relax() {
280 281 282 283 284
  IntSet a = EvalSet(expr_, relax_map_);
  IntSet b = EvalSet(result, relax_map_);
  if (a.is_everything() || b.is_everything()) {
    success = false;
    return;
285
  }
286 287
  expr_  = is_greater ? a.min() : a.max();
  result = is_greater ? b.max() : b.min();
288 289 290 291 292 293
}

// Deduces a one-sided interval for `v` from the comparison `e`.
// Returns IntSet::nothing() when the deduction fails; otherwise an interval
// whose other end is infinite.
IntSet DeduceBound(Expr v, Expr e,
  const std::unordered_map<const Variable*, IntSet>& hint_map,
  const std::unordered_map<const Variable*, IntSet>& relax_map) {
  BoundDeducer deducer(v, e, hint_map, relax_map);
  deducer.Deduce();
  if (!deducer.success) return IntSet::nothing();
  if (deducer.is_greater) {
    // v >= result
    return IntSet::interval(deducer.result, Interval::pos_inf);
  }
  // v <= result
  return IntSet::interval(Interval::neg_inf, deducer.result);
}

// Deduces a one-sided bound of `v` from `e`, which is expected to be a
// comparison (<, <=, >, >=) in which `v` occurs exactly once.
// Returns IntSet::nothing() to signal that the deduction failed.
IntSet DeduceBound(Expr v, Expr e,
                   const Map<Var, IntSet>& hint_map,
                   const Map<Var, IntSet>& relax_map) {
  // Re-key a Map by the raw Variable pointer, as the deducer expects.
  auto to_raw_map = [](const Map<Var, IntSet>& vars) {
    std::unordered_map<const Variable*, IntSet> out;
    for (auto kv : vars) {
      out[kv.first.get()] = kv.second;
    }
    return out;
  };
  return DeduceBound(v, e, to_raw_map(hint_map), to_raw_map(relax_map));
}

}  // namespace arith
}  // namespace tvm