/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file detect_linear_equation.cc
 * \brief Utilities to detect linear equations (per-variable coefficients)
 *        and clip bounds (per-variable min/max) in expressions.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace arith {

34
using namespace tir;
35 36 37

// Linear equation, the components can be undefined.
struct LinearEqEntry {
38 39
  PrimExpr base;
  PrimExpr coeff;
40 41
};

42
struct IntervalEntry {
43 44
  PrimExpr min_value;
  PrimExpr max_value;
45 46
};

47
class LinearEqDetector
48
    : public ExprFunctor<LinearEqEntry(const PrimExpr&, const PrimExpr &)> {
49 50 51 52
 public:
  explicit LinearEqDetector(Var var)
      : var_(var) {}

53
  bool Detect(const PrimExpr& e, LinearEqEntry* ret) {
54 55 56
    *ret = VisitExpr(e, e);
    if (fail_) return false;
    if (!ret->base.defined()) {
57
      ret->base = make_zero(var_.dtype());
58
    }
59
    if (!ret->coeff.defined()) {
60
      ret->coeff = make_zero(var_.dtype());
61
    }
62
    return true;
63 64
  }

65
  LinearEqEntry VisitExpr_(const AddNode* op, const PrimExpr& e) final {
66 67 68 69 70 71 72 73
    if (fail_) return LinearEqEntry();
    LinearEqEntry a = VisitExpr(op->a, op->a);
    LinearEqEntry b = VisitExpr(op->b, op->b);
    LinearEqEntry ret;
    ret.base = AddCombine(a.base, b.base);
    ret.coeff = AddCombine(a.coeff, b.coeff);
    return ret;
  }
74

75
  LinearEqEntry VisitExpr_(const SubNode* op, const PrimExpr& e) final {
76 77 78 79 80 81 82 83 84
    if (fail_) return LinearEqEntry();
    LinearEqEntry a = VisitExpr(op->a, op->a);
    LinearEqEntry b = VisitExpr(op->b, op->b);
    LinearEqEntry ret;
    ret.base = SubCombine(a.base, b.base);
    ret.coeff = SubCombine(a.coeff, b.coeff);
    return ret;
  }

85
  LinearEqEntry VisitExpr_(const MulNode* op, const PrimExpr& e) final {
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
    if (fail_) return LinearEqEntry();
    LinearEqEntry a = VisitExpr(op->a, op->a);
    LinearEqEntry b = VisitExpr(op->b, op->b);
    if (a.coeff.defined()) {
      std::swap(a, b);
    }
    if (a.coeff.defined()) {
      fail_ = true;
      return LinearEqEntry();
    }
    LinearEqEntry ret;
    ret.base = MulCombine(a.base, b.base);
    ret.coeff = MulCombine(a.base, b.coeff);
    return ret;
  }
101
  LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final {
102 103
    LinearEqEntry ret;
    if (op == var_.get()) {
104
      ret.coeff = make_const(op->dtype, 1);
105 106 107 108 109
    } else {
      ret.base = e;
    }
    return ret;
  }
110
  LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
    if (fail_) return LinearEqEntry();
    if (ExprUseVar(e, var_)) {
      fail_ = true;
      return LinearEqEntry();
    } else {
      LinearEqEntry ret;
      ret.base = e;
      return ret;
    }
  }

 private:
  Var var_;
  bool fail_{false};
  // Combine by add
126
  PrimExpr AddCombine(PrimExpr a, PrimExpr b) {
127 128
    if (!a.defined()) return b;
    if (!b.defined()) return a;
129
    return a + b;
130
  }
131
  PrimExpr SubCombine(PrimExpr a, PrimExpr b) {
132
    // Check b first in case they are both undefined
133
    if (!b.defined()) return a;
134
    if (!a.defined()) return -b;
135
    return a - b;
136
  }
137
  PrimExpr MulCombine(PrimExpr a, PrimExpr b) {
138 139
    if (!a.defined()) return a;
    if (!b.defined()) return b;
140
    return a * b;
141 142 143
  }
};

144 145 146 147
Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
                                          const Array<Var>& vars) {
  PrimExpr base = e;
  Array<PrimExpr> coeff;
148

149 150 151
  for (Var v : vars) {
    LinearEqEntry ret;
    if (!LinearEqDetector(v).Detect(base, &ret)) {
152
      return Array<PrimExpr>();
153
    }
154 155 156
    coeff.push_back(ret.coeff);
    base = std::move(ret.base);
  }
157

158
  std::unordered_set<const VarNode*> vset;
159 160 161 162
  for (size_t i = vars.size(); i > 1; --i) {
    vset.insert(vars[i - 1].get());
    // The previous coeff contains the variable
    if (ExprUseVar(coeff[i - 2], vset)) {
163
      return Array<PrimExpr>();
164 165 166 167
    }
  }
  coeff.push_back(base);
  return coeff;
168 169
}

170 171
// Detect clip condition as min max value
bool DetectClipBound(
172
    const PrimExpr& cond,
173
    std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
174 175
  int flag = 0;
  Var var;
176
  auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
177
    if (const VarNode* v = n.as<VarNode>()) {
178 179
      if (bmap->count(v)) {
        if (flag == 0) {
180
          var = Downcast<Var>(n);
181 182 183 184 185 186 187 188 189 190 191 192
          flag = 1;
        } else if (flag == 1) {
          if (!var.same_as(n)) {
            flag = -1;
          }
        }
      }
    }
  };
  PostOrderVisit(cond, fvisit);
  if (flag != 1) return false;
  // canonical form: exp >= 0
193
  PrimExpr canonical;
194
  if (const LTNode* op = cond.as<LTNode>()) {
195 196
    if (!op->a.dtype().is_int()) return false;
    canonical = op->b - op->a - make_const(op->a.dtype(), 1);
197
  } else if (const LENode* op = cond.as<LENode>()) {
198
    if (!op->a.dtype().is_int()) return false;
199
    canonical = op->b - op->a;
200
  } else if (const GTNode* op = cond.as<GTNode>()) {
201 202
    if (!op->a.dtype().is_int()) return false;
    canonical = op->a - op->b - make_const(op->a.dtype(), 1);
203
  } else if (const GENode* op = cond.as<GENode>()) {
204
    if (!op->a.dtype().is_int()) return false;
205 206 207 208 209 210 211 212
    canonical = op->a - op->b;
  } else {
    return false;
  }
  LinearEqEntry ret;
  if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
  ret.coeff = Simplify(ret.coeff);
  IntervalEntry& p = (*bmap)[var.get()];
213
  if (is_const_int(ret.coeff, 1)) {
214 215
    // var + shift >=0 -> var >= -shift
    if (p.min_value.defined()) {
216
      p.min_value = tir::MaxNode::make(p.min_value, -ret.base);
217 218 219 220 221
    } else {
      p.min_value = -ret.base;
    }
    return true;
  }
222
  if (is_const_int(ret.coeff, -1)) {
223 224
    // -var + shift >=0 -> var <= shift
    if (p.max_value.defined()) {
225
      p.max_value = tir::MinNode::make(p.max_value, ret.base);
226 227 228 229 230 231 232 233 234 235
    } else {
      p.max_value = ret.base;
    }
    return true;
  }
  return false;
}


template<typename OP>
236
void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) {
237 238 239 240 241 242 243 244 245 246
  if (const OP* op = e.as<OP>()) {
    SplitCommExpr<OP>(op->a, ret);
    SplitCommExpr<OP>(op->b, ret);
  } else {
    ret->push_back(e);
  }
}

// Detect the lower and upper bound from the expression.
// e must be connected by and.
247 248
Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
  std::vector<PrimExpr> splits;
249
  SplitCommExpr<tir::AndNode>(e, &splits);
250
  std::unordered_map<const VarNode*, IntervalEntry> rmap;
251 252 253
  for (Var v : vars) {
    rmap[v.get()] = IntervalEntry();
  }
254 255
  for (PrimExpr cond : splits) {
    if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>();
256
  }
257
  Array<PrimExpr> ret;
258 259 260 261 262 263 264 265 266 267 268 269 270 271
  for (Var v : vars) {
    IntervalEntry e = rmap[v.get()];
    if (e.min_value.defined()) {
      e.min_value = Simplify(e.min_value);
    }
    if (e.max_value.defined()) {
      e.max_value = Simplify(e.max_value);
    }
    ret.push_back(e.min_value);
    ret.push_back(e.max_value);
  }
  return ret;
}

272 273
TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);
274

275 276 277 278
TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
  return DetectClipBound(e, vars);
});
279 280
}  // namespace arith
}  // namespace tvm