Commit cea88d00 by tqchen

Skeleton of bound inference passing rule

parent f650216b
......@@ -171,4 +171,13 @@ inline IterVar::operator Expr() const {
}
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::IterVar> {
std::size_t operator()(const ::tvm::IterVar& k) const {
return k.hash();
}
};
}
#endif // TVM_EXPR_H_
......@@ -2,5 +2,5 @@
- c_api C API related functions
- lang The definition of DSL related data structure
- schedule The Schedule->Stmt generation logic
- codegen Backend code generation related
\ No newline at end of file
- pass The optimization pass on the IR structure
- bound Bound inference logics.
/*!
* Copyright (c) 2016 by Contributors
* \file bound.cc
* \brief The bound inference logic.
*/
#include <tvm/ir.h>
#include "./int_set.h"
#include "./bound.h"
namespace tvm {
namespace bound {
// result = ceil((a / b)), both a and b are positive integer
inline Expr DivCeil(Expr a, Expr b) {
return (a + b - 1) / b;
}
// Downward message passing algorithm on schedule s,
// pass the range state down from the root to the leaves
// after this pass, every IterVar in the schedule hyper graph will have a range(domain)
void PassDown(const Schedule& s,
std::unordered_map<IterVar, Range>* p_state) {
auto& state = *p_state;
// forwar iteration on relations
for (size_t i = 0; i < s->relations.size(); ++i) {
IterVarRelation rel = s->relations[i];
if (rel.as<SplitNode>()) {
const SplitNode* r = rel.as<SplitNode>();
CHECK(state.count(r->parent));
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
state[r->inner] = Range::make_with_min_extent(0, r->factor);
if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom;
} else {
CHECK(!state.count(r->outer));
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
}
} else {
CHECK(r->outer->dom.defined());
state[r->outer] = r->outer->dom;
state[r->inner] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->outer->dom->extent));
}
} else if (rel.as<FuseNode>()) {
const FuseNode* r = rel.as<FuseNode>();
CHECK(state.count(r->outer));
CHECK(state.count(r->inner));
const Range& range_outer = state.at(r->outer);
const Range& range_inner = state.at(r->inner);
state[r->fused] = Range::make_with_min_extent(
0, range_outer->extent * range_inner->extent);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
// upward message passing algorithm
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
void PassUp(const Schedule& s,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0;--i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
IntSet parent;
const SplitNode* r = rel.as<SplitNode>();
IntSet::PassUp(
r, dom_map,
state.at(r->outer), state.at(r->inner),
&parent);
state[r->parent] = parent;
} else if (rel.as<FuseNode>()) {
IntSet outer, inner;
const FuseNode* r = rel.as<FuseNode>();
IntSet::PassUp(
r, dom_map,
state.at(r->fused),
&outer, &inner);
state[r->outer] = outer;
state[r->inner] = inner;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
} // namespace bound
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file bound.h
* \brief The bound inference logics on the schedule.
*/
#ifndef TVM_BOUND_BOUND_H_
#define TVM_BOUND_BOUND_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
namespace tvm {
namespace bound {
/*!
* \brief Infer the bound of all iteration variables relates to the schedule.
*
* \param sch The root schedule to infer all the bounds.
* \return the result bound of the iteration Variable
*/
std::unordered_map<IterVar, Range> InferBound(Schedule sch);
} // namespace bound
} // namespace tvm
#endif // TVM_BOUND_BOUND_H_
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.h
* \brief Abstract class for iteration integer sets.
*/
#ifndef TVM_BOUND_INT_SET_H_
#define TVM_BOUND_INT_SET_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
namespace tvm {
namespace bound {
/*!
* \brief abstract class of integer set for iteration sets.
*/
class IntSet {
public:
// constructor
IntSet();
// whether the set is same as range
bool SameAs(const Range& r) const;
// make integer set by range
static IntSet make(Range r);
// make integer set as a constant value
static IntSet make(Expr value);
// upward inference function
// get the int set of parent given int set of outer and inner
static void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
// upward inference function
// get the int set of outer and inner given int set of fused.
static void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
};
} // namespace bound
} // namespace tvm
#endif // TVM_BOUND_INT_SET_H_
......@@ -148,8 +148,9 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
return *this;
}
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor);
......
......@@ -11,71 +11,6 @@ namespace tvm {
namespace ir {
namespace {
/*!
* \brief make nest loops given list of stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body The inner-most body of the loop
*/
Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
while (!nest.empty()) {
Stmt s = std::move(nest.back()); nest.pop_back();
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
}
return body;
}
Stmt MakePipeline(const Schedule& sch, Stmt body) {
return body;
}
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
public:
explicit InjectRealize(Schedule sch)
: sch_(sch) {}
Stmt Mutate(Stmt stmt) final {
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr) {
attr_scope_.Push({op->node, op->type_key}, op->value);
stmt = IRMutator::Mutate(stmt);
attr_scope_.Pop({op->node, op->type_key});
} else {
stmt = IRMutator::Mutate(stmt);
}
if (op != nullptr &&
op->type_key == "split" &&
op->node == sch_->attach_parent) {
return AttrStmt::make(
op->node, op->type_key, op->value,
MakePipeline(sch_, op->body));
} else {
return stmt;
}
}
private:
// the operations to be carried
Schedule sch_;
Scope<AttrKey, Expr> attr_scope_;
};
} // namespace
} // namespace ir
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment