Commit 3ba5c15b by tqchen

IntSet Evaluation, skeleton finish

parent cea88d00
......@@ -66,23 +66,21 @@ void PassUp(const Schedule& s,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0;--i) {
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
IntSet parent;
const SplitNode* r = rel.as<SplitNode>();
IntSet::PassUp(
r, dom_map,
state.at(r->outer), state.at(r->inner),
&parent);
PassUp(r, dom_map,
state.at(r->outer), state.at(r->inner),
&parent);
state[r->parent] = parent;
} else if (rel.as<FuseNode>()) {
IntSet outer, inner;
const FuseNode* r = rel.as<FuseNode>();
IntSet::PassUp(
r, dom_map,
state.at(r->fused),
&outer, &inner);
PassUp(r, dom_map,
state.at(r->fused),
&outer, &inner);
state[r->outer] = outer;
state[r->inner] = inner;
} else {
......
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
#include "./int_set.h"
namespace tvm {
namespace bound {
using namespace ir;
/*!
* \brief Internal node container of int set.
*/
class IntSetNode : public Node {
public:
/*! \brief The base range scope */
Range base;
/*! \brief additional strided domain */
Array<Range> domain;
/*! \brief The stride of each strided domain */
Array<Expr> stride;
/*!
* \brief The concrete set,
* used when concrete execution is enabled.
*/
std::vector<int32_t> concrete;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("base", &base);
v->Visit("domain", &domain);
v->Visit("stride", &stride);
}
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_NODE_TYPE_INFO(IntSetNode);
};
TVM_REGISTER_NODE_TYPE(IntSetNode);
namespace {
inline bool Match(const Expr& e, int64_t value) {
const ir::IntImm* v = e.as<ir::IntImm>();
return v != nullptr && v->value;
}
// whether a exactly matches b.
inline bool Match(const IntSet& a,
const Range& b) {
if (a->base == b &&
a->domain.size() == 0 &&
a->concrete.size() == 0) {
return true;
} else {
return false;
}
}
// whether a exactly matches b.
inline bool Match(const IntSet& a,
const Expr& b) {
if (a->domain.size() == 0 &&
a->concrete.size() == 0) {
return Match(a->base->extent, 1) && a->base->min.same_as(b);
} else {
return false;
}
}
inline bool IsNumber(const IntSet& s) {
if (s->domain.size() != 0) return false;
if (s->concrete.size() != 0) {
return s->concrete.size() == 1;
}
return Match(s->base->extent, 1);
}
inline Expr AsNumber(const IntSet& s) {
return s->base->min;
}
// set combination rule by operators
template<typename T>
inline IntSet BinaryCombine(IntSet a, IntSet b) {
LOG(WARNING) << "cannot evaluate binary op " << T::_type_key;
return IntSet::make_all_set();
}
template<>
inline IntSet BinaryCombine<Add>(IntSet a, IntSet b) {
auto n = std::make_shared<IntSetNode>(*(a.operator->()));
for (size_t i = 0; i < b->domain.size(); ++i) {
n->domain.push_back(b->domain[i]);
n->stride.push_back(b->stride[i]);
}
if (IsNumber(a)) {
n->base = Range::make_with_min_extent(
a->base->min + b->base->min,
b->base->extent);
} else if (IsNumber(b)) {
n->base = Range::make_with_min_extent(
a->base->min + b->base->min,
a->base->extent);
} else {
n->base = Range::make_with_min_extent(
a->base->min + b->base->min,
a->base->extent + b->base->extent - 1);
}
return IntSet(n);
}
inline Range Negation(Range a) {
if (Match(a->extent, 1)) {
return Range::make_with_min_extent(-a->min, a->extent);
} else {
return Range::make_with_min_extent(-(a->min + a->extent - 1), a->extent);
}
}
inline IntSet Negation(IntSet a) {
CHECK_EQ(a->concrete.size(), 0);
auto n = std::make_shared<IntSetNode>();
n->base = Negation(a->base);
for (size_t i = 0; i < a->domain.size(); ++i) {
n->domain.push_back(Negation(a->domain[i]));
n->stride.push_back(a->stride[i]);
}
return IntSet(a);
}
template<>
inline IntSet BinaryCombine<Sub>(IntSet a, IntSet b) {
return BinaryCombine<Add>(a, Negation(b));
}
inline IntSet BinaryMul(IntSet a, Expr b) {
// copy construct
if (Match(b, 1)) return a;
if (Match(b, -1)) return Negation(a);
auto n = std::make_shared<IntSetNode>();
n->base = Range::make_with_min_extent(0, 1);
n->domain.push_back(a->base);
n->stride.push_back(b);
for (size_t i = 0; i < a->domain.size(); ++i) {
n->domain.push_back(a->domain[i]);
n->stride.push_back(a->stride[i] * b);
}
return IntSet(a);
}
template<>
inline IntSet BinaryCombine<Mul>(IntSet a, IntSet b) {
if (IsNumber(a)) {
return BinaryMul(a, AsNumber(b));
} else if (IsNumber(b)) {
return BinaryMul(b, AsNumber(a));
} else {
return IntSet::make_all_set();
}
}
} // namespace
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntSetNode>([](const IntSetNode *op, IRPrinter *p) {
p->stream << "int-set(base=";
p->print(op->base);
p->stream << ')';
});
IntSet IntSet::make(Range dom) {
auto n = std::make_shared<IntSetNode>();
n->base = dom;
return IntSet(n);
}
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
Match(outer, dom_map.at(s->outer)) &&
Match(inner, dom_map.at(s->inner))) {
*parent = IntSet::make(dom_map.at(s->parent));
return;
}
// copy construct
auto n = std::make_shared<IntSetNode>(*(inner.operator->()));
if (IsNumber(outer)) {
// shift the base offset
n->base = Range::make_with_min_extent(
AsNumber(outer) * s->factor + inner->base->min,
inner->base->extent);
*parent = IntSet(n);
} else {
// default use all domains in the data.
n->domain.push_back(outer->base);
n->stride.push_back(s->factor);
for (size_t i = 0; i < outer->domain.size(); ++i) {
n->domain.push_back(outer->domain[i]);
n->stride.push_back(outer->stride[i] * s->factor);
}
}
}
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));
if (Match(fused, dom_map.at(s->fused))) {
*outer = IntSet::make(dom_map.at(s->outer));
*inner = IntSet::make(dom_map.at(s->inner));
return;
}
if (IsNumber(fused)) {
Expr value = AsNumber(fused);
Expr factor = dom_map.at(s->outer)->extent;
*outer = IntSet::make(Range::make_with_min_extent(value / factor, 1));
*inner = IntSet::make(Range::make_with_min_extent(value % factor, 1));
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::make(dom_map.at(s->outer));
*inner = IntSet::make(dom_map.at(s->inner));
return;
}
}
namespace {
// evaluator to evaluate the int set
class IRSetEvaluator {
public:
inline IntSet Eval(Expr expr) {
static const FType& f = vtable();
if (f.can_dispatch(expr)) {
return f(expr, expr, this);
} else {
LOG(WARNING) << "cannot evaluate set type " << expr->type_key();
return IntSet::make_all_set();
}
}
using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IRSetEvaluator *)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
}
std::unordered_map<const Variable*, IntSet> dom_map;
};
inline IntSet ConstOp(const NodeRef&, const Expr& e, IRSetEvaluator*) {
return IntSet::make(Range::make_with_min_extent(e, 1));
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
.set_dispatch<IntImm>(ConstOp)
.set_dispatch<UIntImm>(ConstOp)
.set_dispatch<FloatImm>(ConstOp);
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
.set_dispatch<Variable>([](const Variable* op, const Expr& e, IRSetEvaluator* m) {
auto it = m->dom_map.find(op);
if (it != m->dom_map.end()) {
return it->second;
} else {
return IntSet::make(Range::make_with_min_extent(e, 1));
}
});
// binary operator
template<typename T>
inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) {
IntSet a = m->Eval(op->a);
IntSet b = m->Eval(op->b);
if (IsNumber(a) && IsNumber(b)) {
if (Match(a, op->a) &&
Match(b, op->b)) {
return IntSet::make(Range::make_with_min_extent(e, 1));
} else {
return IntSet::make(Range::make_with_min_extent(
T::make(AsNumber(a), AsNumber(b)), 1));
}
} else {
return BinaryCombine<T>(a, b);
}
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>);
// use simply bound for logical expressions for now.
inline IntSet Logical(const NodeRef&, const Expr& e, IRSetEvaluator*) {
return IntSet::make(Range::make_with_min_extent(0, 2));
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
.set_dispatch<EQ>(Logical)
.set_dispatch<NE>(Logical)
.set_dispatch<LT>(Logical)
.set_dispatch<LE>(Logical)
.set_dispatch<GT>(Logical)
.set_dispatch<GE>(Logical)
.set_dispatch<And>(Logical)
.set_dispatch<Or>(Logical);
} // namespace
IntSet Eval(Expr e,
const std::unordered_map<IterVar, IntSet>& dom_map) {
IRSetEvaluator m;
for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second;
}
return m.Eval(e);
}
} // namespace bound
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.h
* \brief Abstract class for iteration integer sets.
* \brief Abstraction for all integer set operations.
*/
#ifndef TVM_BOUND_INT_SET_H_
#define TVM_BOUND_INT_SET_H_
......@@ -11,35 +11,92 @@
namespace tvm {
namespace bound {
// internal node container of int set.
class IntSetNode;
/*!
* \brief abstract class of integer set for iteration sets.
* \brief Integer set class, represent a set of integers in one dimension.
*/
class IntSet {
class IntSet : public NodeRef {
public:
// constructor
IntSet();
// whether the set is same as range
bool SameAs(const Range& r) const;
// make integer set by range
static IntSet make(Range r);
// make integer set as a constant value
static IntSet make(Expr value);
// upward inference function
// get the int set of parent given int set of outer and inner
static void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
// upward inference function
// get the int set of outer and inner given int set of fused.
static void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
/*! \brief constructor */
IntSet() {}
// constructor from not deontainer.
explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! \return whether the set is empty */
inline bool is_empty() const {
return !defined();
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IntSetNode* operator->() const;
/*!
* \param dom The domain to be created.
* \return create integer set from existing domain
*/
static IntSet make(Range dom);
/*!
* \return create integer set that represents everything
*/
static IntSet make_all_set();
};
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet Eval(Expr e,
const std::unordered_map<IterVar, IntSet>& dom_map);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \return the set after union
*/
IntSet Union(const Array<IntSet>& sets);
} // namespace bound
} // namespace tvm
......
......@@ -152,7 +152,6 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor);
split(y_parent, p_y_outer, p_y_inner, y_factor);
reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
......
......@@ -10,8 +10,6 @@
namespace tvm {
namespace ir {
namespace {
} // namespace
} // namespace ir
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment