Unverified Commit 4273e461 by Tianqi Chen Committed by GitHub

Migrate simplifier to new infra. (#3368)

parent f2a6851a
...@@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS ...@@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS file(GLOB TOPI_SRCS
topi/src/*.cc topi/src/*.cc
) )
file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp) file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS}) list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS file(GLOB RUNTIME_SRCS
src/runtime/*.cc src/runtime/*.cc
......
...@@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets); ...@@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets);
* give the domain of each variables. Return undefined IntSet to * give the domain of each variables. Return undefined IntSet to
* represent failure. * represent failure.
* *
* \note The returned set may be smaller than the set that
* contains all possible values of v that satisfy the bound.
*
* \param v The target variable to be deduced. * \param v The target variable to be deduced.
* \param cond The conditional expression. * \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce. * \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain, * \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map * The deduced bound must imply e for all values in relax_map
* \return An integer set that can cover all the possible values. * \return An integer set that always satisfies the condition.
*/ */
IntSet DeduceBound(Expr v, Expr cond, IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map, const Map<Var, IntSet>& hint_map,
...@@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond, ...@@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* \param hint_map The domain of variable, used to help deduce. * \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain, * \param relax_map The domain of each variable, used to relax the domain,
* The deduced bound must imply e for all values in relax_map * The deduced bound must imply e for all values in relax_map
* \return An integer set that can cover all the possible values. * \return An integer set that always satisfies the condition.
*/ */
IntSet DeduceBound(Expr v, Expr cond, IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map, const std::unordered_map<const Variable*, IntSet>& hint_map,
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#ifndef TVM_IR_PASS_H_ #ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_ #define TVM_IR_PASS_H_
#include <arithmetic/Simplify.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
......
...@@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) { ...@@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) {
Expr Analyzer::Simplify(const Expr& expr) { Expr Analyzer::Simplify(const Expr& expr) {
if (is_const(expr)) return expr; if (is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr); auto res = this->rewrite_simplify(expr);
if (is_const(res)) return res;
res = this->canonical_simplify(res); res = this->canonical_simplify(res);
return res; return res;
} }
......
...@@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor { ...@@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
void Deduce(); void Deduce();
void Visit(const NodeRef& e) final { void Visit(const NodeRef& e) final {
if (!success) return; if (!success_) return;
if (e.get() == path_[iter_++]) { if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e); IRVisitor::Visit(e);
} else { } else {
success = false; success_ = false;
return; return;
} }
} }
...@@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor { ...@@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor {
void Visit_(const Add* op) final { void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a; result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b); Visit(left ? op->a : op->b);
} }
void Visit_(const Sub* op) final { void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
if (left) { if (left) {
result += op->b; result_ += op->b;
} else { } else {
result -= op->a; result_ -= op->a;
result = - result; result_ = - result_;
is_greater = !is_greater; is_greater_ = !is_greater_;
} }
Visit(left ? op->a : op->b); Visit(left ? op->a : op->b);
} }
...@@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor { ...@@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor {
void Visit_(const Mul* op) final { void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_]; bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a; Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;
SignType sign; SignType sign_operand;
if (operand.type().is_uint()) { if (operand.type().is_uint()) {
sign = kPositive; sign_operand = kPositive;
} else { } else {
sign = expr_map_[operand].sign_type(); sign_operand = expr_map_[operand].sign_type();
} }
if (sign == SignType::kNegative) { if (sign_operand == SignType::kNegative) {
is_greater = !is_greater; is_greater_ = !is_greater_;
} else if (sign == SignType::kUnknown) { } else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand // unable to get the sign of operand
success = false; success_ = false;
return; return;
} }
// always use relax bound // always use relax bound
bool divided = can_prove(result % operand == 0); bool divided = analyzer_.CanProve(result_ % operand == 0);
result = result / operand;
// since system will round down when not divided result_ = result_ / operand;
// eg. 2/4 -> 0; -2/4 -> -1
// no need fix for !is_greater:
// eg. a <= 2/4 -> a <= 0
// eg. a <= 0/4 -> a <= 0
// so just fix for not divided and is_greater
// eg. a >= 2/4 -> a >= 0 + 1
// eg. a >= 0/4 -> a >= 0
if (is_greater && !divided) {
result += 1;
}
if (!divided) {
// Handle non-divisible case
// NOTE: this accounts for truncated division (trunc div) behavior.
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
if (is_greater_) {
result_ += 1;
} else {
// NOTE: this is a somewhat subtle hack.
//
// condition:
// - x * operand <= result
// - operand > 0
// - x >= 0
//
// Then it is fine to deduce that x <= result / operand.
// - if result > 0, this division rounds down
// - if result < 0, (result / operand) rounds up and may violate the constraint
// however, given that x is always non-negative,
// it is fine to have this relaxed bound, given that the user of deduce bound
// will respect the bound of x
//
// TODO(tvm-team): think about a better API to incorporate constraint of x.
// e.g. specify an interval of x and return a bound
// that is in the interval and satisfies the condition.
if (target_is_non_neg && sign_operand == kPositive) {
// do nothing
} else {
result_ -= 1;
}
}
}
Visit(left ? op->a : op->b); Visit(left ? op->a : op->b);
} }
Expr result; Expr result_;
bool is_greater{true}; bool is_greater_{true};
bool success{true}; bool success_{true};
private: private:
void Init(); void Init();
...@@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor { ...@@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor {
ExprIntSetMap expr_map_; ExprIntSetMap expr_map_;
std::vector<const Node*> path_; std::vector<const Node*> path_;
size_t iter_{0}; size_t iter_{0};
// internal analyzer
Analyzer analyzer_;
}; };
class BoundDeduceInputChecker: public IRVisitor { class BoundDeduceInputChecker: public IRVisitor {
...@@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor { ...@@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor {
void BoundDeducer::Init() { void BoundDeducer::Init() {
BoundDeduceInputChecker checker; BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false; if (!checker.Check(this)) success_ = false;
Transform(); Transform();
} }
...@@ -211,66 +235,65 @@ void BoundDeducer::Transform() { ...@@ -211,66 +235,65 @@ void BoundDeducer::Transform() {
if (const LT* op = expr_.as<LT>()) { if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1 // a < b -> b >= a + 1
is_greater = true; is_greater_ = true;
expr_ = op->b; expr_ = op->b;
result = op->a + 1; result_ = op->a + 1;
} else { } else {
// a < b -> a <= b - 1 // a < b -> a <= b - 1
is_greater = false; is_greater_ = false;
expr_ = op->a; expr_ = op->a;
result = op->b - 1; result_ = op->b - 1;
} }
} else if (const LE* op = expr_.as<LE>()) { } else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a // a <= b -> b >= a
is_greater = true; is_greater_ = true;
expr_ = op->b; expr_ = op->b;
result = op->a; result_ = op->a;
} else { } else {
is_greater = false; is_greater_ = false;
expr_ = op->a; expr_ = op->a;
result = op->b; result_ = op->b;
} }
} else if (const GT* op = expr_.as<GT>()) { } else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1 // a > b -> b <= a - 1
is_greater = false; is_greater_ = false;
expr_ = op->b; expr_ = op->b;
result = op->a - 1; result_ = op->a - 1;
} else { } else {
// a > b -> a >= b + 1 // a > b -> a >= b + 1
is_greater = true; is_greater_ = true;
expr_ = op->a; expr_ = op->a;
result = op->b + 1; result_ = op->b + 1;
} }
} else if (const GE* op = expr_.as<GE>()) { } else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) { if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a // a >= b -> b <= a
is_greater = false; is_greater_ = false;
expr_ = op->b; expr_ = op->b;
result = op->a; result_ = op->a;
} else { } else {
is_greater = true; is_greater_ = true;
expr_ = op->a; expr_ = op->a;
result = op->b; result_ = op->b;
} }
} else { } else {
success = false; success_ = false;
} }
} }
void BoundDeducer::Deduce() { void BoundDeducer::Deduce() {
Init(); Init();
if (!success) return; if (!success_) return;
Relax(); Relax();
if (!success) return; if (!success_) return;
// get the path // get the path
path_ = GetPath(target_, expr_); path_ = GetPath(target_, expr_);
if (!path_.size()) { if (!path_.size()) {
success = false; success_ = false;
return; return;
} }
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_); Visit(expr_);
...@@ -278,13 +301,13 @@ void BoundDeducer::Deduce() { ...@@ -278,13 +301,13 @@ void BoundDeducer::Deduce() {
void BoundDeducer::Relax() { void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_); IntSet a = EvalSet(expr_, relax_map_);
IntSet b = EvalSet(result, relax_map_); IntSet b = EvalSet(result_, relax_map_);
if (a.is_everything() || b.is_everything()) { if (a.is_everything() || b.is_everything()) {
success = false; success_ = false;
return; return;
} }
expr_ = is_greater ? a.min() : a.max(); expr_ = is_greater_ ? a.min() : a.max();
result = is_greater ? b.max() : b.min(); result_ = is_greater_ ? b.max() : b.min();
} }
IntSet DeduceBound(Expr v, Expr e, IntSet DeduceBound(Expr v, Expr e,
...@@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e, ...@@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& relax_map) { const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map); BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce(); d.Deduce();
if (!d.success) return IntSet::nothing(); if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf(); Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) { if (d.is_greater_) {
min = d.result; min = d.result_;
} else { } else {
max = d.result; max = d.result_;
} }
return IntSet::interval(min, max); return IntSet::interval(min, max);
} }
......
...@@ -155,9 +155,10 @@ template<> ...@@ -155,9 +155,10 @@ template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) { inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({ TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type(); const Type& rtype = a.type();
if (pa && pb) {
// due to division and mod can have different modes // due to division and mod can have different modes
// only constant fold positive number where rule is fixed. // NOTE: this assumes trunc div.
if (pa && pb && pa->value >= 0 && pb->value > 0) { CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, pa->value / pb->value); return IntImm::make(rtype, pa->value / pb->value);
} }
if (pa) { if (pa) {
......
...@@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) { ...@@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));
TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value); c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include "arithmetic/Simplify.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) { ...@@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
return analyzer.canonical_simplify(expr); return analyzer.canonical_simplify(expr);
} }
template<typename T> Expr Simplify(Expr expr, Map<Var, Range> vrange) {
T Simplify_(T a, Map<Var, Range> vrange) { arith::Analyzer analyzer;
using namespace HalideIR::Internal;
Scope<Interval> rscope;
for (auto kv : vrange) { for (auto kv : vrange) {
Range r = kv.second; analyzer.Bind(kv.first, kv.second);
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return HalideIR::Internal::simplify(a, true, rscope);
}
Expr Simplify(Expr a, Map<Var, Range> vrange) {
// Simplify top level reduce.
if (const Reduce* r = a.as<Reduce>()) {
Array<Expr> new_source;
for (auto& e : r->source) {
new_source.push_back(Simplify_(e, vrange));
}
Expr new_condition = Simplify_(r->condition, vrange);
if (r->source.same_as(new_source) &&
r->condition.same_as(new_condition)) {
return a;
} else {
return Reduce::make(
r->combiner, new_source, r->axis, new_condition, r->value_index);
}
} }
return Simplify_(a, vrange); expr = analyzer.Simplify(expr);
return expr;
} }
Stmt Simplify(Stmt a, Map<Var, Range> vrange) { Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return Simplify_(a, vrange); return arith::CanonicalStmtSimplifier().CanonicalSimplify(
stmt, vrange);
} }
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <iterator> #include <iterator>
#include <stack>
#include "../arithmetic/compute_expr.h" #include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
......
...@@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name, ...@@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name,
for (size_t i = 0; i < init.size(); ++i) { for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype); CHECK_EQ(init[i]->dtype, update[i]->dtype);
CHECK(can_prove(init[i]->shape[0] == axis->dom->min)) CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min"; << "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal( CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
......
...@@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Stmt body, Stmt body,
bool partition_thread_scope) { bool partition_thread_scope) {
using namespace arith; using namespace arith;
// include hint of var.
hint_map_.insert({var.get(), IntSet::interval(min, max)});
PartitionFinder finder(var, hint_map_, relax_map_); PartitionFinder finder(var, hint_map_, relax_map_);
finder.Visit(body); finder.Visit(body);
hint_map_.erase(var.get());
if (finder.partitions.empty()) return Stmt(); if (finder.partitions.empty()) return Stmt();
arith::IntervalSet for_interval(min, max); arith::IntervalSet for_interval(min, max);
...@@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool pre_stmt_recurse = true; bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) { if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min()); body_begin = ir::Simplify(middle_interval.min());
if (!can_prove(body_begin == min)) { if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0); Expr cond = (body_begin - min >= 0);
if (!can_prove(cond)) { if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop"; << ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min); body_begin = Max::make(body_begin, min);
...@@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
bool post_stmt_recurse = true; bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) { if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1); post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!can_prove(middle_interval.max() == max)) { if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative // require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0); Expr cond = (max - post_doubt_begin + 1 >= 0);
if (!can_prove(cond)) { if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop"; << ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max); post_doubt_begin = Min::make(post_doubt_begin, max);
...@@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
// Generating code for middle subrange // Generating code for middle subrange
if (!partition_thread_scope) { if (!partition_thread_scope) {
Stmt mid_stmt; Stmt mid_stmt;
if (!can_prove(body_begin >= post_doubt_begin)) { if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
// [body_begin, post_doubt_begin) // [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body); Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
...@@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
s = AppendStmts(s, post_stmt); s = AppendStmts(s, post_stmt);
} else { } else {
Expr cond = const_true(); Expr cond = const_true();
if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt); s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt);
} }
s = ConvertSSA(s); s = ConvertSSA(s);
...@@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, ...@@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
const For *for_node = static_cast<const For*>(node); const For *for_node = static_cast<const For*>(node);
CHECK(for_node); CHECK(for_node);
if (can_prove(extent == make_const(Int(32), 1))) { if (analyzer_.CanProve(extent == make_const(Int(32), 1))) {
// If the loop extent is 1, do not create the loop anymore // If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}}); return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}});
} else { } else {
......
...@@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator { ...@@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator {
Expr base = linear_eq[1]; Expr base = linear_eq[1];
if (!is_zero(base)) return body; if (!is_zero(base)) return body;
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
if (!can_prove(left >= 0)) return body; if (!analyzer_.CanProve(left >= 0)) return body;
// rewrite access index. // rewrite access index.
ChannelAccessIndexRewriter rw( ChannelAccessIndexRewriter rw(
ch->handle_var.get(), var * coeff, read_access); ch->handle_var.get(), var * coeff, read_access);
...@@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator { ...@@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator {
return body; return body;
} }
arith::Analyzer analyzer_;
std::vector<RewriteEntry> tasks_; std::vector<RewriteEntry> tasks_;
}; };
......
...@@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator {
} }
// transform to alloc bytes // transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes(); auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided = can_prove(combo_size % type_bits == 0); bool divided = analyzer_.CanProve(combo_size % type_bits == 0);
combo_size = combo_size / type_bits; combo_size = combo_size / type_bits;
// round up when it cannot be divided evenly // round up when it cannot be divided evenly
if (!divided) { if (!divided) {
...@@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator { ...@@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator {
std::unordered_map<const Variable*, StorageEntry*> alloc_map_; std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
// The allocations // The allocations
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_; std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
// analyzer
arith::Analyzer analyzer_;
}; };
// Turn alloc into vector alloc // Turn alloc into vector alloc
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/arithmetic.h>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -132,11 +133,11 @@ class Vectorizer : public IRMutator { ...@@ -132,11 +133,11 @@ class Vectorizer : public IRMutator {
if (lanes != 1) { if (lanes != 1) {
const Ramp* b_ramp = b.as<Ramp>(); const Ramp* b_ramp = b.as<Ramp>();
const Ramp* a_ramp = a.as<Ramp>(); const Ramp* a_ramp = a.as<Ramp>();
if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) { if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) {
return Ramp::make( return Ramp::make(
a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
} }
if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) { if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) {
return Ramp::make( return Ramp::make(
b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
} }
...@@ -186,7 +187,7 @@ class Vectorizer : public IRMutator { ...@@ -186,7 +187,7 @@ class Vectorizer : public IRMutator {
Expr stride = this->Mutate(op->stride); Expr stride = this->Mutate(op->stride);
if (base.type().lanes() > 1 && stride.type().lanes() == 1) { if (base.type().lanes() > 1 && stride.type().lanes() == 1) {
const Ramp* base_ramp = base.as<Ramp>(); const Ramp* base_ramp = base.as<Ramp>();
if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes); return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
} }
} }
...@@ -423,6 +424,8 @@ class Vectorizer : public IRMutator { ...@@ -423,6 +424,8 @@ class Vectorizer : public IRMutator {
} }
private: private:
// analyzer
arith::Analyzer analyzer_;
// variable to be replaced // variable to be replaced
Var var_; Var var_;
// the lanes. // the lanes.
......
...@@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage, ...@@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage,
*/ */
void PassUpBoundCheck(const Stage& s, void PassUpBoundCheck(const Stage& s,
const Map<IterVar, Range>& dom_map, const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, bool>* p_state) { std::unordered_map<IterVar, bool>* p_state,
arith::Analyzer* analyzer) {
auto& state = *p_state; auto& state = *p_state;
using HalideIR::Internal::can_prove;
for (size_t i = s->relations.size(); i != 0; --i) { for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1]; IterVarRelation rel = s->relations[i - 1];
if (const SplitNode* s = rel.as<SplitNode>()) { if (const SplitNode* s = rel.as<SplitNode>()) {
...@@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s, ...@@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s,
if (outer || inner) { if (outer || inner) {
state[s->parent] = true; state[s->parent] = true;
} else { } else {
if (can_prove(dom_map.at(s->parent)->extent == factor * step)) { if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
state[s->parent] = false; state[s->parent] = false;
} else { } else {
state[s->parent] = true; state[s->parent] = true;
...@@ -476,11 +476,13 @@ std::vector<Expr> MakeBoundCheck( ...@@ -476,11 +476,13 @@ std::vector<Expr> MakeBoundCheck(
const std::unordered_map<IterVar, Expr>& value_map, const std::unordered_map<IterVar, Expr>& value_map,
bool skip_ivar_domain, bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) { const std::unordered_set<IterVar>& skip_iter) {
Analyzer analyzer;
std::unordered_map<IterVar, bool> bound_state; std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) { for (IterVar iv : stage->leaf_iter_vars) {
bound_state[iv] = false; bound_state[iv] = false;
} }
PassUpBoundCheck(stage, dom_map, &bound_state); PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<Expr> preds; std::vector<Expr> preds;
std::unordered_map<const Variable*, IntSet> iset_dmap; std::unordered_map<const Variable*, IntSet> iset_dmap;
...@@ -496,7 +498,7 @@ std::vector<Expr> MakeBoundCheck( ...@@ -496,7 +498,7 @@ std::vector<Expr> MakeBoundCheck(
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min); Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
Expr vmax = EvalSet(value, iset_dmap).max(); Expr vmax = EvalSet(value, iset_dmap).max();
if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) { if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent); preds.emplace_back(value < dom->extent);
} }
} }
...@@ -511,10 +513,10 @@ std::vector<Expr> MakeBoundCheck( ...@@ -511,10 +513,10 @@ std::vector<Expr> MakeBoundCheck(
Expr vmin = s.min(); Expr vmin = s.min();
Expr vmax = s.max(); Expr vmax = s.max();
// The range of `value` resides in [vmin, vmax] // The range of `value` resides in [vmin, vmax]
if (vmin.type() != value.type() || !can_prove(vmin >= 0)) { if (vmin.type() != value.type() || !analyzer.CanProve(vmin >= 0)) {
preds.emplace_back(value >= 0); preds.emplace_back(value >= 0);
} }
if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { if (vmax.type() != value.type() || !analyzer.CanProve(vmax < iv->dom->extent)) {
preds.emplace_back(value < iv->dom->extent); preds.emplace_back(value < iv->dom->extent);
} }
} }
......
...@@ -740,7 +740,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, ...@@ -740,7 +740,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const Reduce* reduce = compute_op->body[idx].as<Reduce>(); const Reduce* reduce = compute_op->body[idx].as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions"; CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition); predicates.push_back(reduce->condition);
Expr predicate = likely(simplify(arith::ComputeReduce<ir::And>(predicates, Expr()))); Expr predicate = likely(arith::ComputeReduce<ir::And>(predicates, Expr()));
std::unordered_map<const Variable*, Expr> vsub; std::unordered_map<const Variable*, Expr> vsub;
......
...@@ -21,12 +21,6 @@ ...@@ -21,12 +21,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <arithmetic/Simplify.h>
TEST(IRSIMPLIFY, Basic) {
using namespace HalideIR::Internal;
simplify_test();
}
TEST(IRSIMPLIFY, MinMax) { TEST(IRSIMPLIFY, MinMax) {
auto x = tvm::var("x"); auto x = tvm::var("x");
......
...@@ -16,6 +16,14 @@ ...@@ -16,6 +16,14 @@
# under the License. # under the License.
import tvm import tvm
def assert_expr_equal(a, b):
res = tvm.ir_pass.Simplify(a - b)
equal = isinstance(res, tvm.expr.IntImm) and res.value == 0
if not equal:
raise ValueError("{} and {} are not equal".format(a, b))
def test_deduce(): def test_deduce():
a = tvm.var('a') a = tvm.var('a')
b = tvm.var('b') b = tvm.var('b')
...@@ -29,31 +37,34 @@ def test_deduce(): ...@@ -29,31 +37,34 @@ def test_deduce():
e0 = (-b)*a+c-d e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((d - c) /(b*-1)) ans0 = ((d - c) /(b*-1) + (-1))
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs # expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) assert_expr_equal(res0.max_value, ans0)
e0 = d*a+c-d e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = ((0-c)/d + 1) ans0 = ((d-c)/d - 1)
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs # expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) assert_expr_equal(res0.max_value, ans0)
e1 = (a*4+b < c) e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (((c - b) + -1)/4) ans1 = (((c - b) + -1)/4 -1)
assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) assert_expr_equal(res1.max_value, ans1)
# expression containing variable a is on rhs # expression containing variable a is on rhs
e1 = (c > a*4+b) e1 = (c > a*4+b)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) assert_expr_equal(res1.max_value, ans1)
e2 = (tvm.max(5, a * 4) < 0) e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
...@@ -66,7 +77,6 @@ def test_deduce(): ...@@ -66,7 +77,6 @@ def test_deduce():
assert str(res2.max_value) == "neg_inf" assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf" assert str(res2.min_value) == "pos_inf"
e3 = (-b)+a*c-d e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1 ans3 = 2/c+1
...@@ -75,6 +85,7 @@ def test_deduce(): ...@@ -75,6 +85,7 @@ def test_deduce():
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
def test_check(): def test_check():
a = tvm.var('a') a = tvm.var('a')
b = tvm.var('b') b = tvm.var('b')
......
...@@ -24,9 +24,6 @@ def test_simplify(): ...@@ -24,9 +24,6 @@ def test_simplify():
assert(tvm.ir_pass.Equal(e2, x * 8)) assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
let = tvm.make.Let(x, 1, x + 3)
e4 = tvm.ir_pass.Simplify(let)
assert(tvm.ir_pass.Equal(e4, 4))
def test_verify_ssa(): def test_verify_ssa():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment