Commit 0f693212 by tqchen

Pass first basic case of bound inference

parent c5395a1f
Subproject commit adaea9e85bc0a213d4eb63edfa4762f2147c73ec
Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683
......@@ -6,8 +6,9 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include "../schedule/bound.h"
#include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h"
namespace tvm {
namespace schedule {
......@@ -20,8 +21,16 @@ using RetValue = APIVariantValue;
*ret = PassName(args.at(0)); \
}) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0), args.at(1)); \
}) \
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
} // namespace schedule
} // namespace tvm
......@@ -9,6 +9,11 @@
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
p->stream << "op(" << op << ")";
});
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
......
......@@ -7,6 +7,7 @@
#include <tvm/ir_visitor.h>
#include "./int_set.h"
#include "./bound.h"
#include "./graph.h"
namespace tvm {
namespace schedule {
......@@ -62,7 +63,7 @@ void PassDown(const Schedule& s,
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
void PassUp(const Schedule& s,
void PassUp(const ScheduleNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
......@@ -89,62 +90,145 @@ void PassUp(const Schedule& s,
}
}
void PassBound(
/*!
* \brief Pass the bound of tensor read
* to the corresponding bound of the IterVar of operation
* \param tensor The tensor to be passed.
* \param dim_bounds The read index set on each dimension.
* \param The result IterVar bound .
*/
void PassToOperation(
const Tensor& tensor,
const std::vector<IntSet>& arg_bounds,
const std::vector<IntSet>& dim_bounds,
std::unordered_map<IterVar, std::vector<IntSet> >* result) {
if (tensor->op.as<ComputeOpNode>()) {
auto root_iter_vars = tensor->op->root_iter_vars();
CHECK_EQ(tensor.ndim(), root_iter_vars.size());
for (size_t i = 0; i < tensor.ndim(); ++i) {
(*result)[root_iter_vars[i]].push_back(arg_bounds[i]);
(*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
}
} else {
LOG(FATAL) << "unknown operation mode";
}
}
void PassBound(
Operation op,
std::unordered_map<IterVar, IntSet>* ebound) {
if (op.as<ComputeOpNode>()) {
auto fvisit = [ebound](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
std::vector<IntSet> arg_bounds;
for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(Eval(call->args[i], *ebound));
}
/*!
* \brief Recursively propagate bound
* \param post_order The propagation order.
* \param dom_map The domain map to be propagated
* \return The result bound
*/
std::unordered_map<IterVar, IntSet>
BoundProp(const Array<Operation>& post_order,
std::unordered_map<IterVar, std::vector<IntSet> > *p_state) {
std::unordered_map<IterVar, IntSet> result;
for (size_t i = post_order.size(); i != 0; --i) {
Operation op = post_order[i - 1];
if (op.as<ComputeOpNode>()) {
for (auto iv : op->root_iter_vars()) {
CHECK(p_state->count(iv))
<< "Bound of root operator must exists";
CHECK(!result.count(iv));
result[iv] = Union(p_state->at(iv));
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else {
LOG(FATAL) << "unknown operation mode";
auto fvisit = [p_state, &result](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
if (t->op.defined()) {
std::vector<IntSet> arg_bounds;
for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(EvalSet(call->args[i], result));
}
PassToOperation(t, arg_bounds, p_state);
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else {
LOG(FATAL) << "unknown operation mode";
}
}
return result;
}
void InferBound(const Schedule& sch,
std::unordered_map<IterVar, Range>* rmap) {
CHECK_NE(sch->attach_type, kNone);
// check if scope
bool ScopeRelax(const IterVar& iv, const std::string& scope) {
if (iv->thread_tag.length() == 0) return false;
if (scope.length() == 0) return false;
static std::unordered_map<std::string, int> scope_rank{
{"global", 0},
{"shared", 1},
{"local", 2}
};
return scope_rank.at(scope) <= scope_rank.at(iv->thread_tag);
}
void InferBound(
const ScheduleNode* parent,
const Schedule& sch,
std::unordered_map<IterVar, Range>* rmap) {
if (sch->attach_type == kInline) return;
if (sch->attach_type == kRoot) {
if (sch->attach_type == kRoot || sch->attach_type == kNone) {
auto root_iter_vars = sch->op->root_iter_vars();
for (size_t i = 0; i < root_iter_vars.size(); ++i) {
auto v = root_iter_vars[i];
CHECK(v->dom.defined());
CHECK(!rmap->count(v));
(*rmap)[v] = v->dom;
for (auto iv : root_iter_vars) {
CHECK(iv->dom.defined());
CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom;
}
}
// get range of all child iter vars.
PassDown(sch, rmap);
// pass iteration variable to children
if (sch->attach_type == kScope) {
CHECK(parent != nullptr);
auto g = CreateReadGraph(parent->op);
auto post_order = PostDFSOrder(parent->op, g);
std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) {
if (fix_value && !ScopeRelax(iv, sch->scope)) {
up_state[iv] = IntSet::make_point(iv->var);
} else {
up_state[iv] = IntSet::make_range(rmap->at(iv));
}
if (sch->attach_parent == iv) {
fix_value = false;
}
}
// get the bound of the root IterVars given the current condition
PassUp(parent, *rmap, &up_state);
std::unordered_map<IterVar, std::vector<IntSet> > bp_state;
for (auto iv : parent->op->root_iter_vars()) {
CHECK(up_state.count(iv));
bp_state[iv] = {up_state.at(iv)};
}
auto result = BoundProp(post_order, &bp_state);
for (auto iv : sch->op->root_iter_vars()) {
CHECK(result.count(iv));
CHECK(!rmap->count(iv));
(*rmap)[iv] = result.at(iv).GetCoverRange();
}
}
// also call infer bound on children
for (Schedule child : sch->children) {
InferBound(sch.operator->(), child, rmap);
}
}
Map<IterVar, Range> InferBound(Schedule sch) {
return {};
std::unordered_map<IterVar, Range> ret;
CHECK(sch->attach_type != kInline && sch->attach_type != kScope)
<< "the Schedule is not a root Schedule";
InferBound(nullptr, sch, &ret);
return Map<IterVar, Range>(ret.begin(), ret.end());
}
} // namespace schedule
......
......@@ -14,26 +14,29 @@ namespace schedule {
// construct a read graph that gives readers of each operation
// that the root depend on
ReadGraph CreateReadGraph(Operation root) {
std::unordered_map<Operation, std::vector<Tensor> > rmap;
rmap[root] = {};
ReadGraph CreateReadGraph(const Operation& root) {
ReadGraph rmap;
std::vector<Operation> stack{root};
std::unordered_set<const Node*> visited{root.get()};
while (!stack.empty()) {
Operation r = stack.back();
Operation op = stack.back();
stack.pop_back();
auto& vec = rmap.at(r);
if (r.as<ComputeOpNode>()) {
auto fvisit = [&vec, &rmap, &stack](const NodeRef& n) {
Array<Tensor> deps;
if (op.as<ComputeOpNode>()) {
auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
vec.push_back(t);
if (t->op.defined() && rmap.count(t->op) == 0) {
rmap[t->op] = {}; stack.push_back(t->op);
deps.push_back(t);
if (t->op.defined() && visited.count(t->op.get()) == 0) {
visited.insert(t->op.get());
stack.push_back(t->op);
}
}
};
ir::PostOrderVisit(r.as<ComputeOpNode>()->body, fvisit);
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
rmap.Set(op, deps);
} else {
LOG(FATAL) << "unknown operation mode";
}
......@@ -43,9 +46,9 @@ ReadGraph CreateReadGraph(Operation root) {
void PostDFSOrder(const Operation& op,
const ReadGraph& g,
std::unordered_set<Operation>* visited,
std::vector<Operation>* post_order) {
const ReadGraph& g,
std::unordered_set<Operation>* visited,
Array<Operation>* post_order) {
visited->insert(op);
for (const auto& t : g.at(op)) {
if (t->op.defined() && !visited->count(t->op)) {
......@@ -55,10 +58,10 @@ void PostDFSOrder(const Operation& op,
post_order->push_back(op);
}
std::vector<Operation> PostDFSOrder(
Array<Operation> PostDFSOrder(
const Operation& root, const ReadGraph& g) {
std::unordered_set<Operation> visited;
std::vector<Operation> post_order;
Array<Operation> post_order;
PostDFSOrder(root, g, &visited, &post_order);
return post_order;
}
......
......@@ -17,7 +17,7 @@ namespace schedule {
/*!
* \brief data structure of Operation->Tensors it reads
*/
using ReadGraph = std::unordered_map<Operation, std::vector<Tensor> >;
using ReadGraph = Map<Operation, Array<Tensor> >;
/*!
* \brief Get read graph of each operation to all the
......@@ -38,7 +38,7 @@ ReadGraph CreateReadGraph(const Operation& root);
* \note PostDFSOrder is a special case of Topoligical order,
* and can be used when topoligical order is needed.
*/
std::vector<Operation> PostDFSOrder(
Array<Operation> PostDFSOrder(
const Operation& root, const ReadGraph& g);
} // namespace schedule
......
......@@ -176,17 +176,37 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ')';
});
IntSet IntSet::make(Range dom) {
IntSet IntSet::make_range(Range dom) {
auto n = std::make_shared<IntSetNode>();
n->base = dom;
return IntSet(n);
}
Range IntSet::GetCoverRange() const {
const IntSetNode* s = operator->();
CHECK(s != nullptr) << "empty set";
if (s->domain.size() == 0 && s->concrete.size() == 0) {
return s->base;
}
LOG(FATAL) << "not yet implemented";
return Range();
}
IntSet IntSet::make_point(Expr point) {
return IntSet::make_range(Range::make_with_min_extent(point, 1));
}
IntSet IntSet::make_all_set() {
LOG(FATAL) << "TODO";
return IntSet();
}
IntSet Union(const Array<IntSet>& set) {
if (set.size() == 1) return set[0];
LOG(FATAL) << "TODO";
return IntSet();
}
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
......@@ -197,7 +217,7 @@ void PassUp(const SplitNode* s,
dom_map.count(s->parent) &&
Match(outer, dom_map.at(s->outer)) &&
Match(inner, dom_map.at(s->inner))) {
*parent = IntSet::make(dom_map.at(s->parent));
*parent = IntSet::make_range(dom_map.at(s->parent));
return;
}
// copy construct
......@@ -230,21 +250,21 @@ void PassUp(const FuseNode* s,
CHECK(dom_map.count(s->fused));
if (Match(fused, dom_map.at(s->fused))) {
*outer = IntSet::make(dom_map.at(s->outer));
*inner = IntSet::make(dom_map.at(s->inner));
*outer = IntSet::make_range(dom_map.at(s->outer));
*inner = IntSet::make_range(dom_map.at(s->inner));
return;
}
if (IsNumber(fused)) {
Expr value = AsNumber(fused);
Expr factor = dom_map.at(s->outer)->extent;
*outer = IntSet::make(Range::make_with_min_extent(value / factor, 1));
*inner = IntSet::make(Range::make_with_min_extent(value % factor, 1));
*outer = IntSet::make_point(value / factor);
*inner = IntSet::make_point(value % factor);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::make(dom_map.at(s->outer));
*inner = IntSet::make(dom_map.at(s->inner));
*outer = IntSet::make_range(dom_map.at(s->outer));
*inner = IntSet::make_range(dom_map.at(s->inner));
return;
}
}
......@@ -272,7 +292,7 @@ class IRSetEvaluator {
};
inline IntSet ConstOp(const NodeRef&, const Expr& e, IRSetEvaluator*) {
return IntSet::make(Range::make_with_min_extent(e, 1));
return IntSet::make_point(e);
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
......@@ -286,7 +306,7 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
if (it != m->dom_map.end()) {
return it->second;
} else {
return IntSet::make(Range::make_with_min_extent(e, 1));
return IntSet::make_point(e);
}
});
......@@ -298,10 +318,9 @@ inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) {
if (IsNumber(a) && IsNumber(b)) {
if (Match(a, op->a) &&
Match(b, op->b)) {
return IntSet::make(Range::make_with_min_extent(e, 1));
return IntSet::make_point(e);
} else {
return IntSet::make(Range::make_with_min_extent(
T::make(AsNumber(a), AsNumber(b)), 1));
return IntSet::make_point(T::make(AsNumber(a), AsNumber(b)));
}
} else {
return BinaryCombine<T>(a, b);
......@@ -319,7 +338,7 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
// use simply bound for logical expressions for now.
inline IntSet Logical(const NodeRef&, const Expr& e, IRSetEvaluator*) {
return IntSet::make(Range::make_with_min_extent(0, 2));
return IntSet::make_range(Range::make_with_min_extent(0, 2));
}
TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
......@@ -334,8 +353,8 @@ TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable)
} // namespace
IntSet Eval(Expr e,
const std::unordered_map<IterVar, IntSet>& dom_map) {
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map) {
IRSetEvaluator m;
for (auto kv : dom_map) {
m.dom_map[kv.first->var.as<Variable>()] = kv.second;
......
......@@ -29,6 +29,10 @@ class IntSet : public NodeRef {
return !defined();
}
/*!
* \return a range that covers the IntSet
*/
Range GetCoverRange() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
......@@ -37,7 +41,12 @@ class IntSet : public NodeRef {
* \param dom The domain to be created.
* \return create integer set from existing domain
*/
static IntSet make(Range dom);
static IntSet make_range(Range dom);
/*!
* \param point
* \return create integer set that only contains one point
*/
static IntSet make_point(Expr point);
/*!
* \return create integer set that represents everything
*/
......@@ -52,8 +61,8 @@ class IntSet : public NodeRef {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet Eval(Expr e,
const std::unordered_map<IterVar, IntSet>& dom_map);
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Conditional upward message passing.
*
......
......@@ -4,16 +4,32 @@ def test_bound_inference():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j])
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA1.split(A1.op.dim_var[0], factor=8)
sA2.compute_at(sA1, xi)
bounds = tvm.schedule.InferBound(sA1)
xo, xi = sA2.split(A2.op.dim_var[0], 8)
sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
print(bounds)
print(bounds[A1.op.dim_var[0]])
print(bounds[A1.op.dim_var[1]])
def test_create_read_graph():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j])
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
g = tvm.schedule.CreateReadGraph(A2.op)
assert g[A2.op][0] == A1
assert g[A1.op][0] == A
post_order = tvm.schedule.PostDFSOrder(A2.op, g)
assert(post_order[0] == A1.op)
assert(post_order[1] == A2.op)
if __name__ == "__main__":
test_bound_inference()
test_create_read_graph()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment