Commit c8ebfbe3 by Ziheng Jiang Committed by Tianqi Chen

[PASS]LoopPartition (#56)

* loop_partition draft

* divide loop variable into constant domain and variable domain & consider multiple partitions

* process doubt interval

* fix and refactor, add relax_map arg in BoundDeduce

* fix testcase and comment

* rebase to zero, convert to SSA

* change the logic of generating loop code & fix issues

* add a testcase for relax map in deducebound && fix issues

* clean code

* const auto&

* add test_multi_if
parent 24bca6af
...@@ -35,6 +35,8 @@ using Halide::Internal::make_const; ...@@ -35,6 +35,8 @@ using Halide::Internal::make_const;
using Halide::Internal::make_zero; using Halide::Internal::make_zero;
using Halide::Internal::as_const_int; using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint; using Halide::Internal::as_const_uint;
using Halide::Internal::const_true;
using Halide::Internal::const_false;
inline Type TVMType2Type(TVMType t) { inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes); return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
...@@ -53,8 +55,8 @@ class Var : public Halide::VarExpr { ...@@ -53,8 +55,8 @@ class Var : public Halide::VarExpr {
public: public:
explicit Var(const std::string& name_hint = "v", explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {} Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}
/*! \brief type indicate the container type */ /*! \brief type indicate the container type */
using ContainerType = Variable; using ContainerType = Variable;
......
...@@ -62,12 +62,12 @@ class IRMutator { ...@@ -62,12 +62,12 @@ class IRMutator {
virtual Stmt Mutate_(const Allocate* op, const Stmt& s); virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e); virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e); virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& e); virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e); virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);
virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e); virtual Expr Mutate_(const Load* op, const Expr& e);
......
...@@ -138,6 +138,13 @@ Stmt InjectVirtualThread(Stmt stmt); ...@@ -138,6 +138,13 @@ Stmt InjectVirtualThread(Stmt stmt);
Stmt LiftAllocate(Stmt stmt); Stmt LiftAllocate(Stmt stmt);
/*! /*!
* \brief partition loops in the stmt
* \param stmt The stmt to do loop partition
* \return Transformed stmt.
*/
Stmt LoopPartition(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc. * \brief Make an user callable API LoweredFunc.
* *
* The main task of this function is to create code to : * The main task of this function is to create code to :
......
...@@ -29,7 +29,9 @@ TVM_REGISTER_API(_arith_EvalModular) ...@@ -29,7 +29,9 @@ TVM_REGISTER_API(_arith_EvalModular)
TVM_REGISTER_API(_arith_DeduceBound) TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]); *ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>(),
args[3].operator Map<Var, IntSet>());
}); });
TVM_REGISTER_API(_IntervalSetGetMin) TVM_REGISTER_API(_IntervalSetGetMin)
......
...@@ -69,6 +69,7 @@ REGISTER_PASS4(MakeAPI); ...@@ -69,6 +69,7 @@ REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice); REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate); REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread); REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -21,7 +21,7 @@ using Halide::Internal::Interval; ...@@ -21,7 +21,7 @@ using Halide::Internal::Interval;
// from a expression. // from a expression.
class VariablePathFinder: public IRVisitor { class VariablePathFinder: public IRVisitor {
public: public:
explicit VariablePathFinder(Var target) : target_(target) {} explicit VariablePathFinder(Expr target) : target_(target) {}
void Visit(const NodeRef& node) final { void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return; if (visited_.count(node.get()) != 0) return;
...@@ -37,13 +37,13 @@ class VariablePathFinder: public IRVisitor { ...@@ -37,13 +37,13 @@ class VariablePathFinder: public IRVisitor {
private: private:
bool found_{false}; bool found_{false};
Var target_; Expr target_;
std::unordered_set<const Node*> visited_; std::unordered_set<const Node*> visited_;
}; };
// get the path to the variable, // get the path to the variable,
// return empty vector to represent failure // return empty vector to represent failure
std::vector<const Node*> GetPath(Var target, Expr expr) { std::vector<const Node*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target); VariablePathFinder v(target);
v.Visit(expr); v.Visit(expr);
return v.path_; return v.path_;
...@@ -56,11 +56,11 @@ class BoundDeducer: public IRVisitor { ...@@ -56,11 +56,11 @@ class BoundDeducer: public IRVisitor {
public: public:
friend class BoundDeduceInputChecker; friend class BoundDeduceInputChecker;
friend class Converter; friend class Converter;
BoundDeducer(Var target, Expr expr, BoundDeducer(Expr target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& dom_map) const std::unordered_map<const Variable*, IntSet>& hint_map,
: target_(target), expr_(expr), dom_map_(dom_map) {} const std::unordered_map<const Variable*, IntSet>& relax_map)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
bool Init();
void Deduce(); void Deduce();
void Visit(const NodeRef& e) final { void Visit(const NodeRef& e) final {
...@@ -137,9 +137,14 @@ class BoundDeducer: public IRVisitor { ...@@ -137,9 +137,14 @@ class BoundDeducer: public IRVisitor {
bool success{true}; bool success{true};
private: private:
Var target_; void Init();
void Transform();
void Relax();
Expr target_;
Expr expr_; Expr expr_;
const std::unordered_map<const Variable*, IntSet>& dom_map_; const std::unordered_map<const Variable*, IntSet>& hint_map_;
const std::unordered_map<const Variable*, IntSet>& relax_map_;
ExprIntSetMap expr_map_; ExprIntSetMap expr_map_;
std::vector<const Node*> path_; std::vector<const Node*> path_;
size_t iter_{0}; size_t iter_{0};
...@@ -163,10 +168,13 @@ class BoundDeduceInputChecker: public IRVisitor { ...@@ -163,10 +168,13 @@ class BoundDeduceInputChecker: public IRVisitor {
size_t target_count{0}; size_t target_count{0};
}; };
bool BoundDeducer::Init() { void BoundDeducer::Init() {
BoundDeduceInputChecker checker; BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false; if (!checker.Check(this)) success = false;
Transform();
}
void BoundDeducer::Transform() {
if (const LT* op = expr_.as<LT>()) { if (const LT* op = expr_.as<LT>()) {
is_greater = false; is_greater = false;
is_equal = false; is_equal = false;
...@@ -190,30 +198,35 @@ bool BoundDeducer::Init() { ...@@ -190,30 +198,35 @@ bool BoundDeducer::Init() {
} else { } else {
success = false; success = false;
} }
return success;
} }
void BoundDeducer::Deduce() { void BoundDeducer::Deduce() {
Init(); Init();
if (!success) return; if (!success) return;
Relax();
// get the path // get the path
path_ = GetPath(target_, expr_); path_ = GetPath(target_, expr_);
// get the sign of every subexpr // get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_); Visit(expr_);
} }
// assuming e >= 0, deduce the bound of variable from it. void BoundDeducer::Relax() {
// return empty set to represent deduce failure. if (is_greater) {
IntSet DeduceBound(Var v, Expr e, expr_ = EvalSet(expr_ , relax_map_).min();
const Map<Var, IntSet>& dom_map) { result = EvalSet(result, relax_map_).max();
std::unordered_map<const Variable*, IntSet> dmap; } else {
for (auto kv : dom_map) { expr_ = EvalSet(expr_ , relax_map_).max();
dmap[kv.first.get()] = kv.second; result = EvalSet(result, relax_map_).min();
} }
BoundDeducer d(v, e, dmap); }
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce(); d.Deduce();
if (!d.success) return IntSet::nothing(); if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf; Expr min = Interval::neg_inf, max = Interval::pos_inf;
...@@ -225,5 +238,21 @@ IntSet DeduceBound(Var v, Expr e, ...@@ -225,5 +238,21 @@ IntSet DeduceBound(Var v, Expr e,
return IntSet::interval(min, max); return IntSet::interval(min, max);
} }
// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Expr v, Expr e,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) {
std::unordered_map<const Variable*, IntSet> hmap;
for (auto kv : hint_map) {
hmap[kv.first.get()] = kv.second;
}
std::unordered_map<const Variable*, IntSet> rmap;
for (auto kv : relax_map) {
rmap[kv.first.get()] = kv.second;
}
return DeduceBound(v, e, hmap, rmap);
}
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
...@@ -162,11 +162,11 @@ inline bool MatchPoint(const IntSet& a, ...@@ -162,11 +162,11 @@ inline bool MatchPoint(const IntSet& a,
return i.is_single_point() && i.min.same_as(b); return i.is_single_point() && i.min.same_as(b);
} }
IntSet Union(const Array<IntSet>& set) { IntSet Union(const Array<IntSet>& sets) {
if (set.size() == 1) return set[0]; if (sets.size() == 1) return sets[0];
Interval x = set[0].cover_interval().as<IntervalSet>()->i; Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < set.size(); ++i) { for (size_t i = 1; i < sets.size(); ++i) {
IntSet s = set[i].cover_interval(); IntSet s = sets[i].cover_interval();
const Interval& y = s.as<IntervalSet>()->i; const Interval& y = s.as<IntervalSet>()->i;
if (can_prove(x.max + 1 >= y.min)) { if (can_prove(x.max + 1 >= y.min)) {
x.max = y.max; x.max = y.max;
...@@ -179,6 +179,15 @@ IntSet Union(const Array<IntSet>& set) { ...@@ -179,6 +179,15 @@ IntSet Union(const Array<IntSet>& set) {
return IntervalSet::make(x); return IntervalSet::make(x);
} }
IntSet Intersect(const Array<IntSet>& sets) {
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < sets.size(); ++i) {
Interval y = sets[i].cover_interval().as<IntervalSet>()->i;
x = Interval::make_intersection(x, y);
}
return IntervalSet::make(x);
}
// type traits // type traits
template<typename OP> template<typename OP>
struct is_logical_op { struct is_logical_op {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <vector>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -157,6 +158,13 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r, ...@@ -157,6 +158,13 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r,
*/ */
IntSet Union(const Array<IntSet>& sets); IntSet Union(const Array<IntSet>& sets);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be intersected
* \return the set after intersected
*/
IntSet Intersect(const Array<IntSet>& sets);
// implementation // implementation
inline const IntSetNode* IntSet::operator->() const { inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get()); return static_cast<const IntSetNode*>(node_.get());
...@@ -169,11 +177,17 @@ inline const IntSetNode* IntSet::operator->() const { ...@@ -169,11 +177,17 @@ inline const IntSetNode* IntSet::operator->() const {
* *
* \param v The target variable to be deduced. * \param v The target variable to be deduced.
* \param cond The conditional expression. * \param cond The conditional expression.
* \param dom_map The domain of each variable. * \param hint_map The domain of variable, used to help deduce.
* \param relax The domain of each variable, used to relax the domain.
* \return An integer set that can cover all the possible values. * \return An integer set that can cover all the possible values.
*/ */
IntSet DeduceBound(Var v, Expr cond, IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& dom_map); const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map);
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
......
...@@ -128,7 +128,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { ...@@ -128,7 +128,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition); Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case); Stmt then_case = this->Mutate(op->then_case);
Stmt else_case; Stmt else_case;
if (else_case.defined()) { if (op->else_case.defined()) {
else_case = this->Mutate(op->else_case); else_case = this->Mutate(op->else_case);
} }
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
......
/*!
* Copyright (c) 2017 by Contributors
* \file loop_partition.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include <unordered_set>
#include "../arithmetic/int_set.h"
#include "../arithmetic/int_set_internal.h"
namespace tvm {
namespace ir {
using arith::IntSet;
// a partition means the expr is equal to true in the interval
struct Partition {
Expr expr;
IntSet interval;
};
bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
bool success = false;
PostOrderVisit(expr, [&vars, &success](const NodeRef& node) {
if (const Variable* v = node.as<Variable>()) {
if (vars.count(v)) {
success = true;
return;
}
}
});
return success;
}
class PartitionFinder : public IRVisitor {
public:
explicit PartitionFinder(VarExpr loop_var,
const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
for (const auto& kv : dom_map) out_vars_.insert(kv.first);
}
void Visit_(const For* op) {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
hint_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
IRVisitor::Visit_(op);
relax_map_.erase(op->loop_var.get());
hint_map_.erase(op->loop_var.get());
}
void Visit_(const IfThenElse* op) {
if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({target_var_.get()}))) {
IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_);
partitions[op->condition.get()] = Partition{op->condition, interval};
} else {
IRVisitor::Visit_(op);
}
}
std::unordered_map<const Node*, Partition> partitions;
private:
VarExpr target_var_;
std::unordered_set<const Variable*> out_vars_;
std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_;
};
class PartitionReplacer : public IRMutator {
public:
explicit PartitionReplacer(const std::unordered_map<const Node*, Partition>& ps)
: ps_(ps) {}
Expr Mutate(Expr e) override {
if (ps_.count(e.get())) {
return Mutate(const_true());
}
return IRMutator::Mutate(e);
}
using IRMutator::Mutate;
private:
const std::unordered_map<const Node*, Partition>& ps_;
};
class LoopPartitioner : public IRMutator {
public:
LoopPartitioner() {}
Stmt Mutate_(const For* op, const Stmt& stmt) {
if (!is_const(op->min) || !is_const(op->extent)) {
Stmt s = DoPartition(op, stmt);
if (s.defined()) return s;
}
dom_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
Stmt res = IRMutator::Mutate_(op, stmt);
dom_map_.erase(op->loop_var.get());
return res;
}
private:
Stmt DoPartition(const For* op, const Stmt& stmt);
std::unordered_map<const Variable*, IntSet> dom_map_;
};
Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) {
PartitionFinder finder(op->loop_var, dom_map_);
finder.Visit(op->body);
const auto& partitions = finder.partitions;
if (partitions.empty()) return Stmt();
Expr min = op->min;
Expr max = op->min + op->extent - 1;
Array<IntSet> sets;
// merge partitions (take their intersect)
for (const auto& kv : partitions) {
sets.push_back(kv.second.interval);
}
IntSet true_itrv = Intersect(sets);
Stmt pre_stmt;
Expr body_begin;
if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
body_begin = true_itrv.min();
if (!can_prove(body_begin == min)) {
if (!can_prove(body_begin - min >= 0)) {
LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0)
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
}
// [min, body_begin)
Stmt body = Substitute(op->body,
{{Var{op->loop_var}, op->loop_var + min}});
pre_stmt = For::make(op->loop_var, 0,
body_begin - min, op->for_type, op->device_api, body);
}
} else {
body_begin = min;
}
Stmt post_stmt;
Expr post_doubt_begin;
if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
post_doubt_begin = true_itrv.max() + 1;
if (!can_prove(true_itrv.max() == max)) {
if (!can_prove(max - post_doubt_begin >= 0)) {
LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0)
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max);
}
// [post_doubt_begin, max]
Stmt body = Substitute(op->body,
{{Var{op->loop_var}, op->loop_var + post_doubt_begin}});
post_stmt = For::make(op->loop_var, 0,
max - post_doubt_begin + 1, op->for_type, op->device_api, body);
}
} else {
post_doubt_begin = max + 1;
}
// [body_begin, post_doubt_begin)
Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body);
Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}});
Stmt simplified_stmt = For::make(op->loop_var, 0,
post_doubt_begin - body_begin, op->for_type, op->device_api, body);
Stmt s = simplified_stmt;
if (pre_stmt.defined()) {
s = Block::make(pre_stmt, s);
}
if (post_stmt.defined()) {
s = Block::make(s, post_stmt);
}
return Simplify(ConvertSSA(s));
}
Stmt LoopPartition(Stmt stmt) {
stmt = LoopPartitioner().Mutate(stmt);
return stmt;
}
} // namespace ir
} // namespace tvm
...@@ -16,20 +16,25 @@ def test_deduce(): ...@@ -16,20 +16,25 @@ def test_deduce():
d_s = tvm.arith.intset_interval(-3, -1) d_s = tvm.arith.intset_interval(-3, -1)
e0 = (-b)*a+c-d e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = (d-c)/(-b)+(-1) ans0 = (d-c)/(-b)+(-1)
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
e1 = (a*4+b < c) e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = (c-b)/4+(-2) ans1 = (c-b)/4+(-2)
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
e2 = (tvm.max(5, a * 4) < 0) e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}) res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max()) == "neg_inf" assert str(res2.max()) == "neg_inf"
assert str(res2.min()) == "pos_inf" assert str(res2.min()) == "pos_inf"
e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
def test_check(): def test_check():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
...@@ -41,15 +46,15 @@ def test_check(): ...@@ -41,15 +46,15 @@ def test_check():
d_s = tvm.arith.intset_interval(-3, -1) d_s = tvm.arith.intset_interval(-3, -1)
# no compare operator # no compare operator
res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}) res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
assert res1.is_nothing() assert res1.is_nothing()
# multiple compare operators # multiple compare operators
res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}) res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}, {})
assert res1.is_nothing() assert res1.is_nothing()
# multiple target variable # multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}) res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
assert res1.is_nothing() assert res1.is_nothing()
if __name__ == "__main__": if __name__ == "__main__":
......
import tvm
def test_basic():
n = tvm.Var('n')
A = tvm.placeholder((n, ), name='A')
B = tvm.placeholder((n, ), name='B')
T = tvm.compute((n, ), lambda i: A[i]+B[i])
s = tvm.Schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt)
assert('if' not in str(stmt.body.body.body.first))
print(stmt)
def test_multi_loop():
i = tvm.Var('i')
j = tvm.Var('j')
k = tvm.Var('k')
m = tvm.Var('m')
n = tvm.Var('n')
stmt = tvm.make.For(
i, 0, 4, 0, 0,
tvm.make.For(
j, 0, n, 0, 0,
tvm.make.For(
k, 0, m, 0, 0,
tvm.make.IfThenElse(
(i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)))))
stmt = tvm.ir_pass.LoopPartition(stmt)
assert('if' not in str(stmt.body.first))
print(stmt)
def test_multi_if():
i = tvm.Var('i')
j = tvm.Var('j')
k = tvm.Var('k')
m = tvm.Var('m')
n = tvm.Var('n')
stmt = tvm.make.For(
i, 0, 4, 0, 0,
tvm.make.For(
j, 0, n, 0, 0,
tvm.make.For(
k, 0, m, 0, 0,
tvm.make.Block(
tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)),
tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))
))))
stmt = tvm.ir_pass.LoopPartition(stmt)
assert('if' not in str(stmt.body.first))
print(stmt)
if __name__ == "__main__":
test_basic()
test_multi_loop()
test_multi_if()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment