Unverified Commit 75c29c6a by Tianqi Chen Committed by GitHub

[ARITH] Improve min/max/div cases in RewriteSimplify (#3463)

[PASS] Use new infra for lower warp memory
[ARITH] EvalSet recursively evaluates set in case dom_map contains set that need to be relaxed.
parent 5d068357
......@@ -305,6 +305,16 @@ class IntervalSetEvaluator :
IntervalSet Eval(const Expr& val) {
return this->VisitExpr(val);
}
// evaluate and relax the set
IntervalSet Eval(IntervalSet val) {
// avoid recursive indefinite recursive expansion.
if (static_cast<size_t>(recur_depth_) >= dom_map_.size()) return val;
++recur_depth_;
IntervalSet min_set = this->Eval(val->min_value);
IntervalSet max_set = this->Eval(val->max_value);
--recur_depth_;
return IntervalSet(min_set->min_value, max_set->max_value);
}
IntervalSet VisitExpr_(const IntImm* op) final {
return IntervalSet::SinglePoint(GetRef<Expr>(op));
......@@ -318,7 +328,14 @@ class IntervalSetEvaluator :
Var var = GetRef<Var>(op);
auto it = dom_map_.find(var);
if (it != dom_map_.end()) {
return ToIntervalSet((*it).second);
IntervalSet res = ToIntervalSet((*it).second);
if (res->min_value.same_as(var) &&
res->max_value.same_as(var)) {
return res;
}
// recursively evaluate mapped result
// in case the domain contains variables to be relaxed.
return Eval(res);
} else {
return IntervalSet::SinglePoint(var);
}
......@@ -440,6 +457,9 @@ class IntervalSetEvaluator :
return Combine<T>(analyzer_, a, b);
}
// recursive depth
int recur_depth_{0};
// analyzer
Analyzer* analyzer_;
const Map<Var, IntSet>& dom_map_;
bool eval_vec_{false};
......@@ -662,13 +682,10 @@ IntSet EvalSet(Range r,
const Map<Var, IntSet>& dom_map) {
Analyzer ana;
IntervalSetEvaluator m(&ana, dom_map);
IntervalSet min_set = m.Eval(r->min);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
Expr sum = r->min + r->extent - 1;
IntervalSet max_set = m.Eval(Simplify(sum));
if (!min_set->HasLowerBound()) return IntSet::everything();
if (!max_set->HasUpperBound()) return IntSet::everything();
return IntervalSet(min_set->min_value, max_set->max_value);
auto res = m.Eval(IntervalSet(r->min, Simplify(sum)));
return res;
}
IntSet EvalSet(Range r,
......
......@@ -147,6 +147,17 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z));
TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));
TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(y + z * c1, x) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(y + z * c1, x) + z * c2, max(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y);
TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y);
TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y);
......@@ -265,6 +276,11 @@ Mutate_(const Sub* op, const Expr& self) {
TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y));
TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y));
TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x));
TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x));
TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y));
TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y));
TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z));
TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z));
TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y));
......@@ -397,6 +413,12 @@ Mutate_(const Div* op, const Expr& self) {
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// x / 2.0 = x * 0.5
if (const FloatImm* ptr = op->b.as<FloatImm>()) {
CHECK(op->type.is_float());
return op->a * make_const(op->b.type(), 1.0 / ptr->value);
}
// Vector rules
if (op->type.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) / broadcast(y, lanes),
......
......@@ -80,8 +80,11 @@ namespace ir {
class WarpStoreCoeffFinder : private IRVisitor {
public:
WarpStoreCoeffFinder(const Variable* buffer,
Var warp_index)
: buffer_(buffer), warp_index_(warp_index) {
Var warp_index,
arith::Analyzer* analyzer)
: buffer_(buffer),
warp_index_(warp_index),
analyzer_(analyzer) {
}
// find the warp co-efficient in the statement given the warp size
int Find(const Stmt& stmt) {
......@@ -113,7 +116,7 @@ class WarpStoreCoeffFinder : private IRVisitor {
CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index;
int coeff = 0;
Expr mcoeff = ir::Simplify(m[0]);
Expr mcoeff = analyzer_->canonical_simplify(m[0]);
CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
<< "LowerWarpMemory failed due to store index=" << index
......@@ -134,6 +137,8 @@ class WarpStoreCoeffFinder : private IRVisitor {
Var warp_index_;
// the coefficient
int warp_coeff_{0};
// analyzer.
arith::Analyzer* analyzer_;
};
......@@ -184,8 +189,8 @@ class WarpIndexFinder : private IRVisitor {
// Mutator to change the read pattern
class WarpAccessRewriter : protected IRMutator {
public:
explicit WarpAccessRewriter(int warp_size)
: warp_size_(warp_size) {}
explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer)
: warp_size_(warp_size), analyzer_(analyzer) {}
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const Allocate* op, const Stmt& stmt) {
......@@ -196,7 +201,7 @@ class WarpAccessRewriter : protected IRMutator {
alloc_size *= op->type.lanes();
warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
warp_coeff_ = WarpStoreCoeffFinder(
buffer_, warp_index_).Find(op->body);
buffer_, warp_index_, analyzer_).Find(op->body);
CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
<< "Warp memory must be multiple of warp size";
warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
......@@ -258,21 +263,19 @@ class WarpAccessRewriter : protected IRMutator {
return std::make_pair(local_index, group);
}
Expr m = make_const(index.type(), warp_coeff_);
Range rng = Range::make_by_min_extent(
make_zero(index.type()), make_const(index.type(), warp_size_));
Map<Var, Range> vrange({{warp_index_, rng}});
// simple case, warp index is on the highest.
if (warp_group_ == 1) {
Expr x = Simplify(index % m, vrange);
Expr z = Simplify(index / m, vrange);
Expr x = analyzer_->canonical_simplify(index % m);
Expr z = analyzer_->canonical_simplify(index / m);
return std::make_pair(x, z);
} else {
Expr x = Simplify(index % m, vrange);
Expr x = analyzer_->canonical_simplify(index % m);
Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_);
y = y * m + x;
Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m;
return std::make_pair(Simplify(y, vrange), Simplify(z, vrange));
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
}
}
......@@ -287,6 +290,44 @@ class WarpAccessRewriter : protected IRMutator {
int warp_coeff_{0};
// the coefficient n
int warp_group_{0};
// Internal analyzer
arith::Analyzer* analyzer_;
};
// Bind bound information of variables to make analyzer more effective
// TODO(tqchen): consider a pass to inline the bound info into the expr
// so analysis can be context independent.
class BindVarBoundInfo : public IRVisitor {
public:
explicit BindVarBoundInfo(arith::Analyzer* analyzer)
: analyzer_(analyzer) {}
void Visit_(const For* op) final {
Var loop_var(op->loop_var.node_);
analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op);
}
void Visit_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
if (!var_dom_.count(iv->var.get())) {
Range dom = Range::make_by_min_extent(0, op->value);
var_dom_[iv->var.get()] = dom;
analyzer_->Bind(iv->var, dom);
}
}
IRVisitor::Visit_(op);
}
protected:
// internal analyzer.
arith::Analyzer* analyzer_;
// variable domain
std::unordered_map<const Variable*, Range> var_dom_;
};
// Mutator to change the read pattern
......@@ -298,6 +339,7 @@ class WarpMemoryRewriter : private IRMutator {
Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt;
BindVarBoundInfo(&analyzer_).Visit(stmt);
stmt = this->Mutate(stmt);
stmt = CanonicalSimplify(stmt);
return stmt;
......@@ -306,7 +348,7 @@ class WarpMemoryRewriter : private IRMutator {
private:
Stmt Mutate_(const Allocate* op, const Stmt& stmt) {
if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_);
WarpAccessRewriter rewriter(warp_size_, &analyzer_);
return rewriter.Rewrite(op, stmt);
} else {
return IRMutator::Mutate_(op, stmt);
......@@ -331,6 +373,9 @@ class WarpMemoryRewriter : private IRMutator {
int warp_size_{0};
std::unordered_set<const Variable*> warp_buffer_;
arith::Analyzer analyzer_;
// variable domain
std::unordered_map<const Variable*, Range> var_dom_;
};
LoweredFunc
......
......@@ -151,6 +151,12 @@ def test_add_index_simplify():
ck.verify(tvm.min(x, y + 2) + (-2), tvm.min(x + (-2), y));
ck.verify(tvm.min(x + 2, y + 3) + (-2), tvm.min(x, y + 1));
ck.verify(tvm.max(0, 1 - x * 4) + x * 4, tvm.max(x * 4, 1))
ck.verify(tvm.max(2 - x * 4, 0) + x * 4, tvm.max(x * 4, 2))
ck.verify(tvm.min(0, 1 - x * 4) + x * 4, tvm.min(x * 4, 1))
ck.verify(tvm.min(2 - x * 4, 0) + x * 4, tvm.min(x * 4, 2))
ck.verify(x * y + x * 10, x * (y + 10))
ck.verify(y * x + x * 10, x * (y + 10))
ck.verify(y * x + 10 * x, x * (y + 10))
......@@ -212,6 +218,11 @@ def test_sub_index_simplify():
ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y))
ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y))
ck.verify(tvm.max(x + y, z) - x, tvm.max(y, z - x))
ck.verify(tvm.max(y + x, z) - x, tvm.max(y, z - x))
ck.verify(tvm.max(z, x + y) - x, tvm.max(z - x, y))
ck.verify(tvm.max(z, y + x) - x, tvm.max(z - x, y))
ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z))
ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z))
ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y))
......
......@@ -80,7 +80,6 @@ def test_single_point_test():
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def assert_expr_equal(a, b):
print(a, b)
assert tvm.ir_pass.Simplify(a - b).value == 0
def test_copy_pad_split():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment