Commit 6bc0ae12 by Tianqi Chen Committed by ziheng

[ARITH] Refactor intset eval with functor (#295)

parent 10bc2fdf
......@@ -94,6 +94,12 @@ class IntSet : public NodeRef {
*/
static IntSet single_point(Expr point);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(Expr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range
* \return constructed set.
......
......@@ -16,6 +16,11 @@ TVM_REGISTER_API("arith.intset_single_point")
*ret = IntSet::single_point(args[0]);
});
TVM_REGISTER_API("arith.intset_vector")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::vector(args[0]);
});
TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
......
......@@ -6,6 +6,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
#include <arithmetic/Interval.h>
#include <unordered_map>
#include "./compute_expr.h"
......@@ -423,80 +424,129 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return CombineSets<OP>(a, b);
}
// Evaluator to evalute the epxression.
class IntSetEvaluator {
class IntSetEvaluator :
public ExprFunctor<IntSet(const Expr&, const Expr&)> {
public:
explicit IntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
: dom_map(dom_map) {}
inline virtual IntSet Eval(Expr expr) {
static const FType& f = vtable();
if (f.can_dispatch(expr)) {
return f(expr, expr, this);
} else {
LOG(WARNING) << "cannot evaluate set type " << expr->type_key();
return IntSet::nothing();
}
explicit IntSetEvaluator(
const std::unordered_map<const Variable*, IntSet>& dom_map,
bool eval_vec = false)
: dom_map_(dom_map), eval_vec_(eval_vec) {}
// Evaluate.
IntSet Eval(const Expr& e) {
return this->VisitExpr(e, e);
}
IntSet VisitExpr_(const IntImm* op, const Expr& e) final {
return IntSet::single_point(e);
}
using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IntSetEvaluator *)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
IntSet VisitExpr_(const UIntImm* op, const Expr& e) final {
return IntSet::single_point(e);
}
const std::unordered_map<const Variable*, IntSet>& dom_map;
};
inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator* m) {
return IntSet::single_point(e);
}
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<IntImm>(ConstOp)
.set_dispatch<UIntImm>(ConstOp)
.set_dispatch<FloatImm>(ConstOp);
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IntSetEvaluator* m) {
auto it = m->dom_map.find(op);
if (it != m->dom_map.end()) {
IntSet VisitExpr_(const Variable* op, const Expr& e) final {
auto it = dom_map_.find(op);
if (it != dom_map_.end()) {
return it->second;
} else {
return IntSet::single_point(e);
}
});
}
IntSet VisitExpr_(const Add* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Sub* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Mul* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Div* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Mod* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Min* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Max* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const EQ* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const NE* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const LT* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const LE* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const GT* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const GE* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const And* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Or* op, const Expr& e) final {
return Binary(op, e);
}
IntSet VisitExpr_(const Ramp* op, const Expr& e) final {
CHECK(eval_vec_);
IntSet base = Eval(op->base);
int vstride;
if (GetConstInt(op->stride, &vstride)) {
Type t = op->base.type();
if (vstride > 0) {
return Combine<Add>(
base,
IntSet::interval(make_zero(t),
make_const(t, vstride * op->lanes -1)));
} else {
return Combine<Add>(
base,
IntSet::interval(make_const(t, vstride * op->lanes + 1),
make_zero(t)));
}
}
LOG(WARNING) << "cannot evaluate set on expression " << e;
return IntSet::everything();
}
IntSet VisitExpr_(const Broadcast* op, const Expr& e) final {
CHECK(eval_vec_);
return Eval(op->value);
}
IntSet VisitExprDefault_(const Node* op, const Expr& e) final {
LOG(WARNING) << "cannot evaluate set type " << e->type_key();
return IntSet::everything();
}
// binary operator
template<typename T>
inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) {
IntSet a = m->Eval(op->a);
IntSet b = m->Eval(op->b);
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntSet::single_point(e);
private:
template<typename T>
inline IntSet Binary(const T* op, const Expr& e) {
IntSet a = this->Eval(op->a);
IntSet b = this->Eval(op->b);
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
return IntSet::single_point(e);
}
return Combine<T>(a, b);
}
return Combine<T>(a, b);
}
TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);
const std::unordered_map<const Variable*, IntSet>& dom_map_;
bool eval_vec_{false};
};
IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
return IntSetEvaluator(dom_map).Eval(e);
return IntSetEvaluator(dom_map, false).Eval(e);
}
IntSet IntSet::vector(Expr x) {
std::unordered_map<const Variable*, IntSet> dmap;
return IntSetEvaluator(dmap, true).Eval(x);
}
IntSet EvalSet(Expr e,
......@@ -521,12 +571,13 @@ IntSet EvalSet(Range r,
class SubExprIntSetEvaluator : public IntSetEvaluator {
public:
explicit SubExprIntSetEvaluator(const std::unordered_map<const Variable*, IntSet>& dom_map)
explicit SubExprIntSetEvaluator(
const std::unordered_map<const Variable*, IntSet>& dom_map)
: IntSetEvaluator(dom_map) {}
inline IntSet Eval(Expr expr) override {
IntSet ret = IntSetEvaluator::Eval(expr);
expr_map[expr] = ret;
IntSet VisitExpr(const Expr& n, const Expr& e) final {
IntSet ret = IntSetEvaluator::VisitExpr(n, e);
expr_map[n] = ret;
return ret;
}
......
......@@ -5,6 +5,14 @@ def test_basic():
assert s.min().value == 2
assert s.max().value == 3
def test_vector():
base = 10
stride = 3
lanes = 2
s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
assert s.min().value == base
assert s.max().value == base + stride * lanes - 1
def test_deduce():
a = tvm.var('a')
b = tvm.var('b')
......@@ -59,5 +67,6 @@ def test_check():
if __name__ == "__main__":
test_basic()
test_vector()
test_deduce()
test_check()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment