Commit 6bc0ae12 by Tianqi Chen Committed by ziheng

[ARITH] Refactor intset eval with functor (#295)

parent 10bc2fdf
...@@ -94,6 +94,12 @@ class IntSet : public NodeRef { ...@@ -94,6 +94,12 @@ class IntSet : public NodeRef {
*/ */
static IntSet single_point(Expr point); static IntSet single_point(Expr point);
/*! /*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(Expr vec);
/*!
* \brief Construct a set representing a range. * \brief Construct a set representing a range.
* \param r The range * \param r The range
* \return constructed set. * \return constructed set.
......
...@@ -16,6 +16,11 @@ TVM_REGISTER_API("arith.intset_single_point") ...@@ -16,6 +16,11 @@ TVM_REGISTER_API("arith.intset_single_point")
*ret = IntSet::single_point(args[0]); *ret = IntSet::single_point(args[0]);
}); });
TVM_REGISTER_API("arith.intset_vector")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::vector(args[0]);
});
TVM_REGISTER_API("arith.intset_interval") TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]); *ret = IntSet::interval(args[0], args[1]);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <arithmetic/Interval.h> #include <arithmetic/Interval.h>
#include <unordered_map> #include <unordered_map>
#include "./compute_expr.h" #include "./compute_expr.h"
...@@ -423,80 +424,129 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) { ...@@ -423,80 +424,129 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return CombineSets<OP>(a, b); return CombineSets<OP>(a, b);
} }
// Evaluator to evalute the epxression. class IntSetEvaluator :
class IntSetEvaluator { public ExprFunctor<IntSet(const Expr&, const Expr&)> {
public: public:
explicit IntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map) explicit IntSetEvaluator(
: dom_map(dom_map) {} const std::unordered_map<const Variable*, IntSet>& dom_map,
bool eval_vec = false)
inline virtual IntSet Eval(Expr expr) { : dom_map_(dom_map), eval_vec_(eval_vec) {}
static const FType& f = vtable(); // Evaluate.
if (f.can_dispatch(expr)) { IntSet Eval(const Expr& e) {
return f(expr, expr, this); return this->VisitExpr(e, e);
} else { }
LOG(WARNING) << "cannot evaluate set type " << expr->type_key(); IntSet VisitExpr_(const IntImm* op, const Expr& e) final {
return IntSet::nothing(); return IntSet::single_point(e);
}
}
using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IntSetEvaluator *)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
} }
IntSet VisitExpr_(const UIntImm* op, const Expr& e) final {
const std::unordered_map<const Variable*, IntSet>& dom_map;
};
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator* m) {
return IntSet::single_point(e); return IntSet::single_point(e);
} }
IntSet VisitExpr_(const Variable* op, const Expr& e) final {
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) auto it = dom_map_.find(op);
.set_dispatch<IntImm>(ConstOp) if (it != dom_map_.end()) {
.set_dispatch<UIntImm>(ConstOp)
.set_dispatch<FloatImm>(ConstOp);
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IntSetEvaluator* m) {
auto it = m->dom_map.find(op);
if (it != m->dom_map.end()) {
return it->second; return it->second;
} else { } else {
return IntSet::single_point(e); return IntSet::single_point(e);
} }
}); }
IntSet VisitExpr_(const Add* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Sub* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Mul* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Div* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Mod* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Min* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Max* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const EQ* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const NE* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const LT* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const LE* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const GT* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const GE* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const And* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Or* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Ramp* op, const Expr& e) final {
CHECK(eval_vec_);
IntSet base = Eval(op->base);
int vstride;
if (GetConstInt(op->stride, &vstride)) {
Type t = op->base.type();
if (vstride > 0) {
return Combine<Add>(
base,
IntSet::interval(make_zero(t),
make_const(t, vstride * op->lanes -1)));
} else {
return Combine<Add>(
base,
IntSet::interval(make_const(t, vstride * op->lanes + 1),
make_zero(t)));
}
}
LOG(WARNING) << "cannot evaluate set on expression " << e;
return IntSet::everything();
}
IntSet VisitExpr_(const Broadcast* op, const Expr& e) final {
CHECK(eval_vec_);
return Eval(op->value);
}
IntSet VisitExprDefault_(const Node* op, const Expr& e) final {
LOG(WARNING) << "cannot evaluate set type " << e->type_key();
return IntSet::everything();
}
// binary operator private:
template<typename T> template<typename T>
inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) { inline IntSet Binary(const T* op, const Expr& e) {
IntSet a = m->Eval(op->a); IntSet a = this->Eval(op->a);
IntSet b = m->Eval(op->b); IntSet b = this->Eval(op->b);
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntSet::single_point(e); return IntSet::single_point(e);
} }
return Combine<T>(a, b); return Combine<T>(a, b);
} }
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) const std::unordered_map<const Variable*, IntSet>& dom_map_;
.set_dispatch<Add>(Binary<Add>) bool eval_vec_{false};
.set_dispatch<Sub>(Binary<Sub>) };
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) { const std::unordered_map<const Variable*, IntSet>& dom_map) {
return IntSetEvaluator(dom_map).Eval(e); return IntSetEvaluator(dom_map, false).Eval(e);
}
IntSet IntSet::vector(Expr x) {
std::unordered_map<const Variable*, IntSet> dmap;
return IntSetEvaluator(dmap, true).Eval(x);
} }
IntSet EvalSet(Expr e, IntSet EvalSet(Expr e,
...@@ -521,12 +571,13 @@ IntSet EvalSet(Range r, ...@@ -521,12 +571,13 @@ IntSet EvalSet(Range r,
class SubExprIntSetEvaluator : public IntSetEvaluator { class SubExprIntSetEvaluator : public IntSetEvaluator {
public: public:
explicit SubExprIntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map) explicit SubExprIntSetEvaluator(
const std::unordered_map<const Variable*, IntSet>& dom_map)
: IntSetEvaluator(dom_map) {} : IntSetEvaluator(dom_map) {}
inline IntSet Eval(Expr expr) override { IntSet VisitExpr(const Expr& n, const Expr& e) final {
IntSet ret = IntSetEvaluator::Eval(expr); IntSet ret = IntSetEvaluator::VisitExpr(n, e);
expr_map[expr] = ret; expr_map[n] = ret;
return ret; return ret;
} }
......
...@@ -5,6 +5,14 @@ def test_basic(): ...@@ -5,6 +5,14 @@ def test_basic():
assert s.min().value == 2 assert s.min().value == 2
assert s.max().value == 3 assert s.max().value == 3
def test_vector():
base = 10
stride = 3
lanes = 2
s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
assert s.min().value == base
assert s.max().value == base + stride * lanes - 1
def test_deduce(): def test_deduce():
a = tvm.var('a') a = tvm.var('a')
b = tvm.var('b') b = tvm.var('b')
...@@ -59,5 +67,6 @@ def test_check(): ...@@ -59,5 +67,6 @@ def test_check():
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
test_vector()
test_deduce() test_deduce()
test_check() test_check()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment