Commit 39314014 by Tianqi Chen Committed by Haichen Shen

[LANG] Change Schedule->Stage, Use Schedule for global schedule (#8)

* [LANG] Change Schedule->Stage, Use Schedule for global schedule

* add numpy as dep

* add numpy installation, temporary disable osx
parent e953e2e6
......@@ -4,13 +4,11 @@ language: cpp
os:
- linux
- osx
# - osx
env:
# code analysis
- TASK=lint
- TASK=cpp_test
- TASK=python_test
- TASK=all_test
branches:
only:
......@@ -35,6 +33,7 @@ addons:
- g++-4.8
- python-numpy
- python-nose
- python3-numpy
- python3-dev
- python3-nose
- graphviz
......
Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
Subproject commit 98e8df564f8543b337ec0528dbcb06a30f91e694
......@@ -12,6 +12,8 @@
namespace tvm {
// Node container for Stage
class StageNode;
// Node container for Schedule
class ScheduleNode;
// Node container for IterVarRelation
......@@ -25,46 +27,48 @@ enum AttachType : int {
kScope = 3
};
/*! \brief schedule container */
class Schedule : public NodeRef {
/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
Stage() {}
explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief create a new schedule for op.
* \param op The operator in the schedule
* \param scope The scope of the schedule
*/
Schedule(Operation op, std::string scope);
explicit Stage(Operation op);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
inline const StageNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
inline StageNode* operator->();
/*!
* \brief set the memory scope of the stage
* \param scope The memory scope.
*/
Stage& set_scope(std::string scope); // NOLINT(*)
/*!
* \brief specify the schedule to be computed at the parent schedule's scope.
* \param parent The parent schedule.
* \param scope The iteration point to carry the schedule.
* \return reference to self.
*/
Schedule& compute_at(Schedule parent, IterVar scope); // NOLINT(*)
Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline, attach it at parent.
* \param parent The parent schedule to be attached to.
* \return reference to self.
*/
Schedule& compute_inline(Schedule parent); // NOLINT(*)
Stage& compute_inline(); // NOLINT(*)
/*!
* \brief Compute the function at root, attach it to its parent.
* \param parent The parent schedule to be attached to.
* \return reference to self.
*/
Schedule& compute_root(Schedule parent); // NOLINT(*)
Stage& compute_root(); // NOLINT(*)
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
......@@ -73,7 +77,7 @@ class Schedule : public NodeRef {
* \param factor The split factor of the loop.
* \return reference to self.
*/
Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
/*!
* \brief Split the iteration with a given outer domain,
* the outer domain must have a thread-tag.
......@@ -85,7 +89,7 @@ class Schedule : public NodeRef {
* factor must be provided such that factor * outer.extent >= parent.extent.
* \return reference to self.
*/
Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param inner The inner domain to be fused
......@@ -93,16 +97,66 @@ class Schedule : public NodeRef {
* \param p_target The result target domain.
* \return reference to self.
*/
Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
*/
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor); // NOLINT(*)
Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
/*!
* \brief Perform tiling on two dimensions
* The final loop order from outmost to inner most are
* [x_outer, y_outer, x_inner, y_inner]
*
* \param x_parent The original x dimension
* \param y_parent The original y dimension
* \param p_x_outer Outer axis of x dimension
* \param p_y_outer Outer axis of y dimension
* \param p_x_inner Inner axis of x dimension
* \param p_y_inner Inner axis of y dimension
* \param x_factor The stride factor on x axis
* \param y_factor The stride factor on y axis
* \return reference to self.
*/
Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
};
/*!
* \brief Global schedule container
* For operations and all the operations they depend on.
* The schedule per Operation is named as stage.
*/
class Schedule : public NodeRef {
public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
*/
explicit Schedule(Array<Operation> ops);
/*!
* \brief Get the stage corresponds to the op
* \param op The operation.
*/
Stage operator[](const Operation& op);
/*!
* \brief Short hand for getting the stage of tensor's operation.
* \param tensor The tensor
* \return The stage corresponding to the tensor's op
*/
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
};
/*!
......@@ -135,11 +189,11 @@ class IterVarRelation : public NodeRef {
*
* The relations connects the IterVars in the graph.
*/
class ScheduleNode : public Node {
class StageNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the schedule */
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars;
......@@ -152,12 +206,10 @@ class ScheduleNode : public Node {
Array<IterVarRelation> relations;
/*! \brief The attachment type of the schedule */
AttachType attach_type{kNone};
/*!
* \brief The attach point of this schedule.
*/
IterVar attach_parent;
/*! \brief the schedules that this schedule depend on */
Array<Schedule> children;
/*! \brief The attach point of this schedule. */
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
......@@ -166,8 +218,31 @@ class ScheduleNode : public Node {
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent);
v->Visit("children", &children);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
}
static constexpr const char* _type_key = "Stage";
TVM_DECLARE_NODE_TYPE_INFO(StageNode);
};
/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
/*! \brief The root operations */
Array<Operation> roots;
/*!
* \brief list of all stages for non-placeholder ops
* The stage are ordered in PostDFS order of their op.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
}
static constexpr const char* _type_key = "Schedule";
......@@ -234,12 +309,16 @@ class FuseNode : public IterVarRelationNode {
};
// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
}
inline StageNode* Stage::operator->() {
return static_cast<StageNode*>(node_.get());
}
inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
......
......@@ -174,8 +174,17 @@ def max(expr, rdom):
return x
def Schedule(tensor, scope="global"):
return _function_internal._Schedule(tensor, scope)
def Schedule(ops):
"""Create a schedule for list of ops
Parameters
----------
ops : list of Operations
The source expression.
"""
if not isinstance(ops, (list, _collections.Array)):
ops = [ops]
return _function_internal._Schedule(ops)
_init_function_module("tvm")
......@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import tensor as _tensor
@register_node
class Split(NodeBase):
......@@ -11,10 +12,22 @@ class Split(NodeBase):
class Fuse(NodeBase):
pass
@register_node
class Schedule(NodeBase):
def __getitem__(self, k):
if isinstance(k, _tensor.Tensor):
k = k.op
if not isinstance(k, _tensor.Operation):
raise ValueError("Expect schedule key to be Tensor or Operation")
if not k in self.stage_map:
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
return self.stage_map[k]
@register_node
class Stage(NodeBase):
def split(self, parent, factor=None, outer=None):
"""Split the schedule either by factor providing outer scope, or both
"""Split the stage either by factor providing outer scope, or both
Parameters
----------
......@@ -40,11 +53,11 @@ class Schedule(NodeBase):
raise ValueError("split by outer must have special thread_tag")
if outer.dom is None:
raise ValueError("split by outer must have specified domain")
inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor)
inner = _function_internal._StageSplitByOuter(self, parent, outer, factor)
else:
if factor is None:
raise ValueError("either outer or factor need to be provided")
outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor)
outer, inner = _function_internal._StageSplitByFactor(self, parent, factor)
return outer, inner
def fuse(self, inner, outer):
......@@ -63,40 +76,50 @@ class Schedule(NodeBase):
inner : IterVar
The fused variable of iteration.
"""
return _function_internal._ScheduleFuse(self, inner, outer)
return _function_internal._StageFuse(self, inner, outer)
def set_scope(self, scope):
"""Set the thread scope of this stage
Parameters
----------
scope : str
The thread scope of this stage
"""
return _function_internal._StageSetScope(self, scope)
def compute_at(self, parent, scope):
"""Attach the schedule at parent's scope
"""Attach the stage at parent's scope
Parameters
----------
parent : Schedule
The parent schedule
parent : Stage
The parent stage
scope : IterVar
The loop scope t be attached to.
"""
_function_internal._ScheduleComputeAt(self, parent, scope)
_function_internal._StageComputeAt(self, parent, scope)
def compute_inline(self, parent):
"""Attach the schedule at parent, and mark it as inline
def compute_inline(self):
"""Mark stage as inline
Parameters
----------
parent : Schedule
The parent schedule
parent : Stage
The parent stage
"""
_function_internal._ScheduleComputeInline(self, parent)
_function_internal._StageComputeInline(self)
def compute_root(self, parent):
"""Attach the schedule at parent, and mark it as root
def compute_root(self):
"""Attach the stage at parent, and mark it as root
Parameters
----------
parent : Schedule
The parent schedule
parent : Stage
The parent stage
"""
_function_internal._ScheduleComputeInline(self, parent)
_function_internal._StageComputeInline(self)
def reorder(self, *args):
"""reorder the arguments in the specified order.
......@@ -106,9 +129,9 @@ class Schedule(NodeBase):
args : list of IterVar
The order to be ordered
"""
_function_internal._ScheduleReorder(self, args)
_function_internal._StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor):
x_outer, y_outer, x_inner, y_inner = _function_internal._ScheduleTile(
x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner
......@@ -176,7 +176,7 @@ TVM_REGISTER_API(_ComputeOp)
TVM_REGISTER_API(_OpGetOutput)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = args.at(0).operator Operation().output(
args.at(1).operator size_t());
args.at(1).operator int64_t());
});
......@@ -185,64 +185,69 @@ TVM_REGISTER_API(_IterVar)
*ret = IterVar(args.at(0), args.at(1), args.at(2));
});
TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0), args.at(1));
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0).operator Array<Operation>());
});
TVM_REGISTER_API(_StageSetScope)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Stage()
.set_scope(args.at(1));
});
TVM_REGISTER_API(_ScheduleSplitByFactor)
TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar outer, inner;
args.at(0).operator Schedule()
args.at(0).operator Stage()
.split(args.at(1), &outer, &inner, args.at(2));
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_ScheduleSplitByOuter)
TVM_REGISTER_API(_StageSplitByOuter)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar inner;
args.at(0).operator Schedule()
args.at(0).operator Stage()
.split(args.at(1), args.at(2), &inner, args.at(3));
*ret = inner;
});
TVM_REGISTER_API(_ScheduleFuse)
TVM_REGISTER_API(_StageFuse)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar fused;
args.at(0).operator Schedule()
args.at(0).operator Stage()
.split(args.at(1), args.at(2), &fused);
*ret = fused;
});
TVM_REGISTER_API(_ScheduleComputeAt)
TVM_REGISTER_API(_StageComputeAt)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
args.at(0).operator Stage()
.compute_at(args.at(1), args.at(2));
});
TVM_REGISTER_API(_ScheduleComputeInline)
TVM_REGISTER_API(_StageComputeInline)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
.compute_inline(args.at(1));
args.at(0).operator Stage()
.compute_inline();
});
TVM_REGISTER_API(_ScheduleComputeRoot)
TVM_REGISTER_API(_StageComputeRoot)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
.compute_root(args.at(1));
args.at(0).operator Stage()
.compute_root();
});
TVM_REGISTER_API(_ScheduleReorder)
TVM_REGISTER_API(_StageReorder)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule()
args.at(0).operator Stage()
.reorder(args.at(1));
});
TVM_REGISTER_API(_ScheduleTile)
TVM_REGISTER_API(_StageTile)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args.at(0).operator Schedule()
args.at(0).operator Stage()
.tile(args.at(1), args.at(2), &x_outer, &y_outer,
&x_inner, &y_inner, args.at(3), args.at(4));
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
......
......@@ -22,7 +22,7 @@ namespace {
* \param p_state The message passing state
* IterVar->The assignment.
*/
void PassUpOffset(const Schedule& s,
void PassUpOffset(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state;
......@@ -130,7 +130,7 @@ Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
* The flattened Stmt are ordered from outmost to inner most order.
*/
std::vector<std::vector<Stmt> > MakeLoopNest(
const Schedule& sch,
const Stage& sch,
const Map<IterVar, Range>& dom_map) {
// optional, use let to define some CSE in dom_map.
auto leaf_iter_vars = sch->leaf_iter_vars;
......@@ -244,7 +244,7 @@ Stmt MakeRealize(const ComputeOpNode* op,
bounds, make_const(Bool(1), true), body);
}
Stmt MakePipeline(const Schedule& sch,
Stmt MakePipeline(const Stage& sch,
const Map<IterVar, Range>& dom_map,
Stmt consumer) {
std::vector<Tensor> tensors;
......@@ -280,7 +280,7 @@ Stmt MakePipeline(const Schedule& sch,
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
public:
InjectRealize(Schedule schedule, Map<IterVar, Range> dom_map)
InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
: schedule(schedule), dom_map(dom_map) {}
Stmt Mutate(Stmt stmt) final {
......@@ -289,7 +289,7 @@ class InjectRealize : public IRMutator {
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr &&
op->type_key == "scope") {
if (op->node == schedule->attach_parent) {
if (op->node == schedule->attach_ivar) {
CHECK(!found_attach);
found_attach = true;
stmt = AttrStmt::make(
......@@ -301,41 +301,13 @@ class InjectRealize : public IRMutator {
return stmt;
}
// the operations to be carried
Schedule schedule;
Stage schedule;
// domain map
Map<IterVar, Range> dom_map;
// whether attach point is found
bool found_attach{false};
};
void GetOpToScheduleMap(
Schedule s,
std::unordered_map<Operation, Schedule>* ret) {
CHECK(!ret->count(s->op))
<< "Duplicated schedule for op";
(*ret)[s->op] = s;
for (Schedule c : s->children) {
GetOpToScheduleMap(c, ret);
}
}
// order schedule by DFS calling order of ops
std::vector<Schedule> OrderSchedule(Schedule s) {
auto g = schedule::CreateReadGraph(s->op);
auto post_order = schedule::PostDFSOrder(s->op, g);
std::unordered_map<Operation, Schedule> op2sch;
GetOpToScheduleMap(s, &op2sch);
std::vector<Schedule> sorder;
// reverse iteration.
for (size_t i = post_order.size(); i != 0; --i) {
sorder.push_back(op2sch.at(post_order[i - 1]));
}
return sorder;
}
Stmt InjectInline(const Operation op, Stmt body) {
CHECK(body.defined());
const ComputeOpNode* compute = op.as<ComputeOpNode>();
......@@ -351,11 +323,11 @@ Stmt InjectInline(const Operation op, Stmt body) {
} // namespace
Stmt ScheduleOps(
Schedule s, Map<IterVar, Range> dom_map) {
std::vector<Schedule> svec = OrderSchedule(s);
Schedule sch, Map<IterVar, Range> dom_map) {
Stmt body = Stmt();
for (Schedule s : svec) {
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
if (s->attach_type == kInline) {
body = InjectInline(s->op, body);
} else if (s->attach_type == kRoot || s-> attach_type == kNone) {
......
......@@ -17,10 +17,10 @@ inline Expr DivCeil(Expr a, Expr b) {
return (a + b - 1) / b;
}
// Downward message passing algorithm on schedule s,
// Downward message passing algorithm on stage schedule s,
// pass the range state down from the root to the leaves
// after this pass, every IterVar in the schedule hyper graph will have a range(domain)
void PassDown(const Schedule& s,
// after this pass, every IterVar in the stage hyper graph will have a range(domain)
void PassDown(const Stage& s,
std::unordered_map<IterVar, Range>* p_state) {
auto& state = *p_state;
// forwar iteration on relations
......@@ -63,7 +63,7 @@ void PassDown(const Schedule& s,
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
void PassUp(const ScheduleNode* s,
void PassUp(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
......@@ -180,13 +180,11 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
}
void InferBound(
const ScheduleNode* parent,
const Schedule& sch,
std::unordered_map<IterVar, Range>* rmap) {
if (sch->attach_type == kInline) return;
if (sch->attach_type == kRoot || sch->attach_type == kNone) {
auto root_iter_vars = sch->op->root_iter_vars();
void InferBound(const Stage& stage,
std::unordered_map<IterVar, Range>* rmap) {
if (stage->attach_type == kInline) return;
if (stage->attach_type == kRoot || stage->attach_type == kNone) {
auto root_iter_vars = stage->op->root_iter_vars();
for (auto iv : root_iter_vars) {
CHECK(iv->dom.defined());
CHECK(!rmap->count(iv));
......@@ -194,22 +192,23 @@ void InferBound(
}
}
// get range of all child iter vars.
PassDown(sch, rmap);
PassDown(stage, rmap);
if (sch->attach_type == kScope) {
CHECK(parent != nullptr);
auto g = CreateReadGraph(parent->op);
auto post_order = PostDFSOrder(parent->op, g);
if (stage->attach_type == kScope) {
Stage parent = stage->attach_stage;
CHECK(parent.defined());
auto g = CreateReadGraph({parent->op});
auto post_order = PostDFSOrder({parent->op}, g);
std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) {
if (fix_value && !ScopeRelax(iv, sch->scope)) {
if (fix_value && !ScopeRelax(iv, stage->scope)) {
up_state[iv] = IntSet::make_point(iv->var);
} else {
up_state[iv] = IntSet::make_range(rmap->at(iv));
}
if (sch->attach_parent == iv) {
if (stage->attach_ivar == iv) {
fix_value = false;
}
}
......@@ -221,24 +220,22 @@ void InferBound(
bp_state[iv] = {up_state.at(iv)};
}
auto result = BoundProp(post_order, &bp_state);
for (auto iv : sch->op->root_iter_vars()) {
for (auto iv : stage->op->root_iter_vars()) {
CHECK(result.count(iv));
CHECK(!rmap->count(iv));
(*rmap)[iv] = result.at(iv).GetCoverRange();
}
}
// also call infer bound on children
for (Schedule child : sch->children) {
InferBound(sch.operator->(), child, rmap);
}
}
Map<IterVar, Range> InferBound(Schedule sch) {
std::unordered_map<IterVar, Range> ret;
CHECK(sch->attach_type != kInline && sch->attach_type != kScope)
<< "the Schedule is not a root Schedule";
InferBound(nullptr, sch, &ret);
// reverse post DFS order, from out most stage to the innermost
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
InferBound(stage, &ret);
}
return Map<IterVar, Range>(ret.begin(), ret.end());
}
......
......@@ -14,10 +14,15 @@ namespace schedule {
// construct a read graph that gives readers of each operation
// that the root depend on
ReadGraph CreateReadGraph(const Operation& root) {
ReadGraph CreateReadGraph(const Array<Operation>& roots) {
ReadGraph rmap;
std::vector<Operation> stack{root};
std::unordered_set<const Node*> visited{root.get()};
std::vector<Operation> stack;
std::unordered_set<const Node*> visited;
// initialize the roots
for (Operation op : roots) {
stack.push_back(op);
visited.insert(op.get());
}
while (!stack.empty()) {
Operation op = stack.back();
......@@ -51,20 +56,22 @@ void PostDFSOrder(const Operation& op,
const ReadGraph& g,
std::unordered_set<Operation>* visited,
Array<Operation>* post_order) {
if (op.as<PlaceholderOpNode>() || visited->count(op)) return;
visited->insert(op);
for (const auto& t : g.at(op)) {
if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) {
PostDFSOrder(t->op, g, visited, post_order);
}
PostDFSOrder(t->op, g, visited, post_order);
}
post_order->push_back(op);
}
Array<Operation> PostDFSOrder(
const Operation& root, const ReadGraph& g) {
const Array<Operation>& roots,
const ReadGraph& g) {
std::unordered_set<Operation> visited;
Array<Operation> post_order;
PostDFSOrder(root, g, &visited, &post_order);
for (Operation op : roots) {
PostDFSOrder(op, g, &visited, &post_order);
}
return post_order;
}
......
......@@ -24,14 +24,14 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
* Tensors that it directly depends on.
*
* The result map contains Operations needed to finish root Operation.
* \param root The root operation.
* \param roots The root operation.
* \return The result map.
*/
ReadGraph CreateReadGraph(const Operation& root);
ReadGraph CreateReadGraph(const Array<Operation>& roots);
/*!
* \brief Get a post DFS ordered of operations in the graph.
* \param root The root of the graph.
* \param roots The root of the graph.
* \param g The read graph.
* \return vector order of Operations in PostDFS order.
*
......@@ -39,7 +39,7 @@ ReadGraph CreateReadGraph(const Operation& root);
* and can be used when topoligical order is needed.
*/
Array<Operation> PostDFSOrder(
const Operation& root, const ReadGraph& g);
const Array<Operation>& roots, const ReadGraph& g);
} // namespace schedule
} // namespace tvm
......
......@@ -3,6 +3,7 @@
* \file schedule.cc
*/
#include <tvm/schedule.h>
#include "./graph.h"
namespace tvm {
......@@ -31,7 +32,7 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
return 0;
}
void Split(ScheduleNode* self, IterVar parent,
void Split(StageNode* self, IterVar parent,
IterVar outer, IterVar inner, Expr factor) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
......@@ -49,19 +50,30 @@ void Split(ScheduleNode* self, IterVar parent,
} // namespace
Schedule::Schedule(Operation op, std::string scope) {
auto n = std::make_shared<ScheduleNode>();
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
p->stream << "stage("
<< op->op
<< ")";
});
Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>();
n->op = op;
n->scope = scope;
n->all_iter_vars = op->root_iter_vars();
n->leaf_iter_vars = op->root_iter_vars();
node_ = n;
}
Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
(*this)->scope = scope;
return *this;
}
Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
(*this)->attach_type = kScope;
(*this)->attach_parent = scope;
(*this)->attach_ivar = scope;
(*this)->attach_stage = parent;
bool found = false;
for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
if (scope == parent->leaf_iter_vars[i]) {
......@@ -70,25 +82,20 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*)
}
CHECK(found)
<< "Cannot compute at a iteration variable that is not part of parent leaf vars";
parent->children.push_back(*this);
return *this;
}
Schedule& Schedule::compute_inline(Schedule parent) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
Stage& Stage::compute_inline() { // NOLINT(*)
(*this)->attach_type = kInline;
parent->children.push_back(*this);
return *this;
}
Schedule& Schedule::compute_root(Schedule parent) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
Stage& Stage::compute_root() { // NOLINT(*)
(*this)->attach_type = kRoot;
parent->children.push_back(*this);
return *this;
}
Schedule& Schedule::split(
Stage& Stage::split(
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
// place holder for the splitted results.
IterVar outer(Range(), parent->var->name_hint + ".outer");
......@@ -99,7 +106,7 @@ Schedule& Schedule::split(
return *this;
}
Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
// place holder for the splitted results.
IterVar inner(Range(), parent->var->name_hint + ".inner");
*p_inner = inner;
......@@ -108,9 +115,9 @@ Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr
return *this;
}
Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
ScheduleNode* self = operator->();
StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
......@@ -128,8 +135,8 @@ Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) { //
return *this;
}
Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
ScheduleNode* self = operator->();
Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::vector<size_t> pos;
......@@ -148,16 +155,34 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
return *this;
}
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)
Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor);
split(y_parent, p_y_outer, p_y_inner, y_factor);
reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
return *this;
}
Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
n->roots = ops;
auto g = schedule::CreateReadGraph(n->roots);
Array<Operation> post_order = schedule::PostDFSOrder(n->roots, g);
for (Operation op : post_order) {
Stage stage(op);
n->stages.push_back(stage);
n->stage_map.Set(op, stage);
}
node_ = std::move(n);
}
Stage Schedule::operator[](const Operation& op) {
return (*this)->stage_map.at(op);
}
IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor) {
......@@ -178,7 +203,7 @@ IterVarRelation FuseNode::make(
return IterVarRelation(n);
}
TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
......
......@@ -7,41 +7,37 @@ def test_schedule_create():
A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B')
AA = tvm.compute((m, l), lambda i, j: A[i, j])
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
sch_T = tvm.Schedule(T.op, scope="shared")
sch_A = tvm.Schedule(AA.op, scope="global")
xo, xi = sch_T.split(T.op.axis[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2)
sch_A.compute_at(sch_T, xi1)
xo, xi = sch_A.split(AA.op.axis[0], factor=10)
sch_T.reorder(xi2, xi1)
assert T.op.axis[1] in sch_T.leaf_iter_vars
T = tvm.compute((m, n, l), lambda i, j, k: AA(i, k) * B(j, k))
s = tvm.Schedule(T.op)
s[AA].set_scope("shared")
xo, xi = s[T].split(T.op.axis[0], factor=10)
xi1, xi2 = s[T].split(xi, factor=2)
s[AA].compute_at(s[T], xi1)
xo, xi = s[AA].split(AA.op.axis[0], factor=10)
s[T].reorder(xi2, xi1)
assert T.op.axis[1] in s[T].leaf_iter_vars
def test_reorder():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute(m, lambda i: A[i+1])
sch_T = tvm.Schedule(T.op, scope="shared")
xo, xi = sch_T.split(T.op.axis[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2)
s = tvm.Schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=10)
xi1, xi2 = s[T].split(xi, factor=2)
order = (xi2, xi1, xo)
assert tuple(sch_T.leaf_iter_vars) != order
sch_T.reorder(*order)
assert tuple(sch_T.leaf_iter_vars) == order
assert tuple(s[T].leaf_iter_vars) != order
s[T].reorder(*order)
assert tuple(s[T].leaf_iter_vars) == order
def test_split():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i])
sT = tvm.Schedule(T.op)
xo, xi = sT.split(T.op.axis[0], factor=10)
assert tuple(sT.leaf_iter_vars) == (xo, xi)
s = tvm.Schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=10)
assert tuple(s[T].leaf_iter_vars) == (xo, xi)
def test_tile():
......@@ -50,9 +46,9 @@ def test_tile():
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
sch_T = tvm.Schedule(T.op, scope="shared")
xo, yo, xi, yi = sch_T.tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)
s = tvm.Schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi)
if __name__ == "__main__":
test_schedule_create()
......
......@@ -6,10 +6,12 @@ def test_schedule0():
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(sA1)
s = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
print(stmt)
def test_schedule1():
......@@ -17,11 +19,12 @@ def test_schedule1():
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
xo, xi = sA1.split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(sA1)
s = tvm.Schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
print(stmt)
def test_schedule2():
......@@ -30,13 +33,13 @@ def test_schedule2():
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA2.split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2)
s = tvm.Schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA2, bounds)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
print(stmt)
......
......@@ -6,11 +6,11 @@ def test_bound1():
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA2.split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2)
s = tvm.Schedule([A2.op])
xo, xi = s[A2].split(s[A2].op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8)
......@@ -20,11 +20,10 @@ def test_bound2():
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, yo, xi, yi = sA2.tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
sA1.compute_at(sA2, yo)
bounds = tvm.schedule.InferBound(sA2)
s = tvm.Schedule(A2.op)
xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8)
assert(bounds[A1.op.axis[1]].extent.value == 8)
......@@ -35,16 +34,18 @@ def test_bound3():
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op, scope="shared")
sA2 = tvm.Schedule(A2.op)
s = tvm.Schedule(A2.op)
s[A1].set_scope("shared")
thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
xo, xi = sA2.split(A2.op.axis[0], 32)
xi0, xi1 = sA2.split(xi, outer=thread_x)
yo, yi = sA2.split(A2.op.axis[1], 16)
sA2.reorder(xo, xi0, yo, xi1, yi)
sA1.compute_at(sA2, yo)
xo, xi = s[A2].split(A2.op.axis[0], 32)
xi0, xi1 = s[A2].split(xi, outer=thread_x)
yo, yi = s[A2].split(A2.op.axis[1], 16)
s[A2].reorder(xo, xi0, yo, xi1, yi)
s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(sA2)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)
......@@ -56,10 +57,12 @@ def test_create_read_graph():
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j])
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
g = tvm.schedule.CreateReadGraph(A2.op)
g = tvm.schedule.CreateReadGraph([A2.op])
assert g[A2.op][0] == A1
assert g[A1.op][0] == A
post_order = tvm.schedule.PostDFSOrder(A2.op, g)
post_order = tvm.schedule.PostDFSOrder([A2.op], g)
assert(post_order[0] == A1.op)
assert(post_order[1] == A2.op)
......
#!/bin/bash
if [ ${TASK} == "lint" ]; then
make lint || exit -1
echo "Check documentations of c++ code..."
make doc 2>log.txt
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt
echo "---------Error Log----------"
cat logclean.txt
echo "----------------------------"
(cat logclean.txt|grep warning) && exit -1
(cat logclean.txt|grep error) && exit -1
exit 0
if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
make lint || exit -1
echo "Check documentations of c++ code..."
make doc 2>log.txt
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt
echo "---------Error Log----------"
cat logclean.txt
echo "----------------------------"
(cat logclean.txt|grep warning) && exit -1
(cat logclean.txt|grep error) && exit -1
fi
fi
......@@ -22,19 +22,16 @@ if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
fi
fi
if [ ${TASK} == "cpp_test" ]; then
if [ ${TASK} == "cpp_test" ] || [ ${TASK} == "all_test" ]; then
make -f dmlc-core/scripts/packages.mk gtest
make test || exit -1
for test in tests/cpp/*_test; do
./$test || exit -1
done
exit 0
fi
# run two test one for cython, one for ctypes
if [ ${TASK} == "python_test" ]; then
make clean
make -j all || exit -1
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
make all || exit -1
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
python -m nose tests/python/ || exit -1
python3 -m nose tests/python/ || exit -1
......@@ -42,5 +39,4 @@ if [ ${TASK} == "python_test" ]; then
nosetests tests/python/ || exit -1
nosetests3 tests/python/ || exit -1
fi
exit 0
fi
#!/bin/bash
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
brew update
brew install python3
if [ ${TASK} == "python_test" ]; then
python -m pip install --user nose
python3 -m pip install --user nose
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
brew update
brew install python3
python -m pip install --user nose numpy
python3 -m pip install --user nose numpy
fi
fi
if [ ${TASK} == "lint" ]; then
pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6'
if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6'
fi
fi
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment