Commit 86e56824 by Tianqi Chen Committed by Haichen Shen

[ARITH] More aggressive CSE during canonical simplify (#166)

parent 67a314c1
...@@ -93,6 +93,26 @@ struct ComExpr { ...@@ -93,6 +93,26 @@ struct ComExpr {
std::shared_ptr<ComExprNode> ptr_; std::shared_ptr<ComExprNode> ptr_;
}; };
// binary comparison op.
struct BinaryExpr {
int kind;
Expr lhs, rhs;
// comparator
bool operator<(const BinaryExpr& b) const {
if (kind < b.kind) return true;
if (kind > b.kind) return false;
if (lhs.get() < b.lhs.get()) return true;
if (lhs.get() > b.lhs.get()) return false;
return rhs.get() < b.rhs.get();
}
// equality
bool operator==(const BinaryExpr& b) const {
return kind == b.kind &&
lhs.same_as(b.lhs) &&
rhs.same_as(b.rhs);
}
};
template<typename T> template<typename T>
inline Expr Binary_(const T* op, inline Expr Binary_(const T* op,
const Expr& e, const Expr& e,
...@@ -104,12 +124,6 @@ inline Expr Binary_(const T* op, ...@@ -104,12 +124,6 @@ inline Expr Binary_(const T* op,
} }
} }
template<typename T>
inline Expr Binary(
const T* op, const Expr& e, IRMutator* m) {
return Binary_(op, e, m->Mutate(op->a), m->Mutate(op->b));
}
// internal of canonical engine. // internal of canonical engine.
class Canonical::Internal : public IRMutator { class Canonical::Internal : public IRMutator {
public: public:
...@@ -200,7 +214,7 @@ class Canonical::Internal : public IRMutator { ...@@ -200,7 +214,7 @@ class Canonical::Internal : public IRMutator {
// Add // Add
Expr Mutate_(const Add* op, const Expr& e) final { Expr Mutate_(const Add* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e);
} }
CacheEntry a = Produce(op->a); CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b); CacheEntry b = Produce(op->b);
...@@ -212,7 +226,7 @@ class Canonical::Internal : public IRMutator { ...@@ -212,7 +226,7 @@ class Canonical::Internal : public IRMutator {
// Sub // Sub
Expr Mutate_(const Sub* op, const Expr& e) final { Expr Mutate_(const Sub* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e);
} }
CacheEntry a = Produce(op->a); CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b); CacheEntry b = Produce(op->b);
...@@ -224,7 +238,7 @@ class Canonical::Internal : public IRMutator { ...@@ -224,7 +238,7 @@ class Canonical::Internal : public IRMutator {
// Mul // Mul
Expr Mutate_(const Mul* op, const Expr& e) final { Expr Mutate_(const Mul* op, const Expr& e) final {
if (!EnableOpt(op->type)) { if (!EnableOpt(op->type)) {
return Binary(op, e, this); return Binary(op, e);
} }
CacheEntry a = Produce(op->a); CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b); CacheEntry b = Produce(op->b);
...@@ -252,7 +266,7 @@ class Canonical::Internal : public IRMutator { ...@@ -252,7 +266,7 @@ class Canonical::Internal : public IRMutator {
// comparison // comparison
Expr Mutate_(const LT* op, const Expr& e) { Expr Mutate_(const LT* op, const Expr& e) {
if (!EnableOpt(op->a.type())) { if (!EnableOpt(op->a.type())) {
return Binary(op, e, this); return Binary(op, e);
} }
CacheEntry a = Produce(op->a); CacheEntry a = Produce(op->a);
CacheEntry b = Produce(op->b); CacheEntry b = Produce(op->b);
...@@ -266,6 +280,23 @@ class Canonical::Internal : public IRMutator { ...@@ -266,6 +280,23 @@ class Canonical::Internal : public IRMutator {
return Binary_(op, e, a.value, b.value); return Binary_(op, e, a.value, b.value);
} }
} }
// IntImm
Expr Mutate_(const IntImm* op, const Expr& e) final {
auto it = cache_intimm_.find(op->value);
if (it != cache_intimm_.end()) {
return it->second;
} else {
cache_intimm_[op->value] = e;
return e;
}
}
// binary ops
Expr Mutate_(const Div* op, const Expr& e) final {
return Binary(op, e);
}
Expr Mutate_(const Mod* op, const Expr& e) final {
return Binary(op, e);
}
// Call // Call
Expr Mutate_(const Call* op, const Expr& e) final { Expr Mutate_(const Call* op, const Expr& e) final {
if (!op->is_pure()) { if (!op->is_pure()) {
...@@ -309,12 +340,30 @@ class Canonical::Internal : public IRMutator { ...@@ -309,12 +340,30 @@ class Canonical::Internal : public IRMutator {
} }
private: private:
template<typename T>
Expr Binary(const T* op, const Expr& e) {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
BinaryExpr key{static_cast<int>(T::_type_info), a, b};
auto it = cache_binary_.find(key);
if (it != cache_binary_.end()) {
return it->second;
} else {
Expr ret = Binary_(op, e, a, b);
cache_binary_[key] = ret;
return ret;
}
}
// return entry // return entry
CacheEntry ret_entry_; CacheEntry ret_entry_;
// internal information stack // internal information stack
std::vector<StackEntry> stack_; std::vector<StackEntry> stack_;
// cache sum // cache sum
std::map<ComExpr, CacheEntry> cache_sum_; std::map<ComExpr, CacheEntry> cache_sum_;
// cache of normal binary op
std::map<BinaryExpr, Expr> cache_binary_;
// cache of int constant
std::unordered_map<int64_t, Expr> cache_intimm_;
// range of each var // range of each var
std::unordered_map<const Variable*, IntSet> var_range_; std::unordered_map<const Variable*, IntSet> var_range_;
// level of each var // level of each var
......
import tvm
def csimplify(z):
return tvm.ir_pass.CanonicalSimplify(
tvm.make.Evaluate(z)).value
def test_simplify():
x = tvm.var('n')
z = x * 4 - x * 2
zz = csimplify(z)
assert zz.b.value == 2
z = (x / 4) * 2 - (x / 4)
zz = csimplify(z)
assert zz.a == x and zz.b.value == 4
z = (x % 4) * 3 + (x % 4)
zz = csimplify(z)
assert zz.b.value == 4
zz = zz.a
assert zz.a == x and zz.b.value == 4
if __name__ == "__main__":
test_simplify()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment