Commit f8f02829 by Tianqi Chen Committed by GitHub

[SCHEDULE] Refactor bound inference logic (#41)

parent 5c07413c
...@@ -22,7 +22,7 @@ namespace schedule { ...@@ -22,7 +22,7 @@ namespace schedule {
* \param sch The root schedule to infer all the bounds. * \param sch The root schedule to infer all the bounds.
* \return the result bound of the iteration Variable * \return the result bound of the iteration Variable
*/ */
Map<IterVar, Range> InferBound(Schedule sch); Map<IterVar, Range> InferBound(const Schedule& sch);
/*! /*!
* \brief Schedule s' dependent operations. * \brief Schedule s' dependent operations.
......
...@@ -432,7 +432,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) ...@@ -432,7 +432,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<And>(Binary<And>) .set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>); .set_dispatch<Or>(Binary<Or>);
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) { const std::unordered_map<const Variable*, IntSet>& dom_map) {
return IntSetEvaluator(dom_map).Eval(e); return IntSetEvaluator(dom_map).Eval(e);
...@@ -444,17 +443,12 @@ IntSet EvalSet(Expr e, ...@@ -444,17 +443,12 @@ IntSet EvalSet(Expr e,
for (auto kv : dom_map) { for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second; dmap[kv.first->var.as<Variable>()] = kv.second;
} }
IntSetEvaluator m(dmap); return EvalSet(e, dmap);
return m.Eval(e);
} }
IntSet EvalSet(Range r, IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) { const std::unordered_map<const Variable*, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap; IntSetEvaluator m(dom_map);
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
IntSetEvaluator m(dmap);
IntSet min_set = m.Eval(r->min); IntSet min_set = m.Eval(r->min);
IntSet ext_set = m.Eval(r->extent).cover_interval(); IntSet ext_set = m.Eval(r->extent).cover_interval();
const Interval& ei = ext_set.as<IntervalSet>()->i; const Interval& ei = ext_set.as<IntervalSet>()->i;
...@@ -463,6 +457,15 @@ IntSet EvalSet(Range r, ...@@ -463,6 +457,15 @@ IntSet EvalSet(Range r,
return Combine<Add>(min_set, ext_set); return Combine<Add>(min_set, ext_set);
} }
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
return EvalSet(r, dmap);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) { .set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
p->stream << "interval-set[" p->stream << "interval-set["
...@@ -470,6 +473,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -470,6 +473,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< op->i.max << ']'; << op->i.max << ']';
}); });
} // namespace arith } // namespace arith
} // namespace tvm } // namespace tvm
...@@ -103,6 +103,9 @@ IntSet EvalSet(Expr e, ...@@ -103,6 +103,9 @@ IntSet EvalSet(Expr e,
*/ */
IntSet EvalSet(Range r, IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map); const Map<IterVar, IntSet>& dom_map);
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! /*!
* \brief Create an union set of all sets * \brief Create an union set of all sets
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment