Commit 39314014 authored by Tianqi Chen, committed by Haichen Shen

[LANG] Change Schedule->Stage, Use Schedule for global schedule (#8)

* [LANG] Change Schedule->Stage, Use Schedule for global schedule

* add numpy as dep

* add numpy installation, temporary disable osx
parent e953e2e6
...@@ -4,13 +4,11 @@ language: cpp ...@@ -4,13 +4,11 @@ language: cpp
os: os:
- linux - linux
- osx # - osx
env: env:
# code analysis # code analysis
- TASK=lint - TASK=all_test
- TASK=cpp_test
- TASK=python_test
branches: branches:
only: only:
...@@ -35,6 +33,7 @@ addons: ...@@ -35,6 +33,7 @@ addons:
- g++-4.8 - g++-4.8
- python-numpy - python-numpy
- python-nose - python-nose
- python3-numpy
- python3-dev - python3-dev
- python3-nose - python3-nose
- graphviz - graphviz
......
Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495 Subproject commit 98e8df564f8543b337ec0528dbcb06a30f91e694
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
namespace tvm { namespace tvm {
// Node container for Stage
class StageNode;
// Node container for Schedule // Node container for Schedule
class ScheduleNode; class ScheduleNode;
// Node container for IterVarRelation // Node container for IterVarRelation
...@@ -25,46 +27,48 @@ enum AttachType : int { ...@@ -25,46 +27,48 @@ enum AttachType : int {
kScope = 3 kScope = 3
}; };
/*! \brief schedule container */ /*! \brief Stage, contains scheduling for a stage of computation. */
class Schedule : public NodeRef { class Stage : public NodeRef {
public: public:
Schedule() {} Stage() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief create a new schedule for op. * \brief create a new stage for op.
* \param op The operator in the schedule * \param op The operator in the stage
* \param scope The scope of the schedule
*/ */
Schedule(Operation op, std::string scope); explicit Stage(Operation op);
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
*/ */
inline const ScheduleNode* operator->() const; inline const StageNode* operator->() const;
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
*/ */
inline ScheduleNode* operator->(); inline StageNode* operator->();
/*!
* \brief set the memory scope of the stage
* \param scope The memory scope.
*/
Stage& set_scope(std::string scope); // NOLINT(*)
/*! /*!
* \brief specify the schedule to be computed at the parent schedule's scope. * \brief specify the stage to be computed at the parent stage's scope.
* \param parent The parent schedule. * \param parent The parent stage.
* \param scope The iteration point to carry the schedule. * \param scope The iteration point to carry the schedule.
* \return reference to self. * \return reference to self.
*/ */
Schedule& compute_at(Schedule parent, IterVar scope); // NOLINT(*) Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*! /*!
* \brief Compute the function inline, attach it at parent. * \brief Compute the function inline.
* \param parent The parent schedule to be attached to.
* \return reference to self. * \return reference to self.
*/ */
Schedule& compute_inline(Schedule parent); // NOLINT(*) Stage& compute_inline(); // NOLINT(*)
/*! /*!
* \brief Compute the function at root, attach it to its parent. * \brief Compute the function at root.
* \param parent The parent schedule to be attached to.
* \return reference to self. * \return reference to self.
*/ */
Schedule& compute_root(Schedule parent); // NOLINT(*) Stage& compute_root(); // NOLINT(*)
/*! /*!
* \brief Split the parent by factor, generate * \brief Split the parent by factor, generate
* \param parent The parent iteration domain. * \param parent The parent iteration domain.
...@@ -73,7 +77,7 @@ class Schedule : public NodeRef { ...@@ -73,7 +77,7 @@ class Schedule : public NodeRef {
* \param factor The split factor of the loop. * \param factor The split factor of the loop.
* \return reference to self. * \return reference to self.
*/ */
Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*) Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
/*! /*!
* \brief Split the iteration with a given outer domain, * \brief Split the iteration with a given outer domain,
* the outer domain must have a thread-tag. * the outer domain must have a thread-tag.
...@@ -85,7 +89,7 @@ class Schedule : public NodeRef { ...@@ -85,7 +89,7 @@ class Schedule : public NodeRef {
* factor must be provided such that factor * outer.extent >= parent.extent. * factor must be provided such that factor * outer.extent >= parent.extent.
* \return reference to self. * \return reference to self.
*/ */
Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*) Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
/*! /*!
* \brief Fuse the inner outer domain to the target * \brief Fuse the inner outer domain to the target
* \param inner The inner domain to be fused * \param inner The inner domain to be fused
...@@ -93,16 +97,66 @@ class Schedule : public NodeRef { ...@@ -93,16 +97,66 @@ class Schedule : public NodeRef {
* \param p_target The result target domain. * \param p_target The result target domain.
* \return reference to self. * \return reference to self.
*/ */
Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*) Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
/*! /*!
* \brief Reorder the iteration * \brief Reorder the iteration
* \param order The order of iteration variable. * \param order The order of iteration variable.
* \return reference to self. * \return reference to self.
*/ */
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*) Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer, /*!
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, * \brief Perform tiling on two dimensions
Expr x_factor, Expr y_factor); // NOLINT(*) * The final loop order from outermost to innermost is
* [x_outer, y_outer, x_inner, y_inner]
*
* \param x_parent The original x dimension
* \param y_parent The original y dimension
* \param p_x_outer Outer axis of x dimension
* \param p_y_outer Outer axis of y dimension
* \param p_x_inner Inner axis of x dimension
* \param p_y_inner Inner axis of y dimension
* \param x_factor The stride factor on x axis
* \param y_factor The stride factor on y axis
* \return reference to self.
*/
Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
};
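A minimal Python sketch of how these Stage primitives compose, mirroring the tests later in this commit (tensor names and factors are illustrative only):

import tvm

m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j] + 3, name='T')

s = tvm.Schedule(T.op)                       # global schedule over T and its dependencies
xo, xi = s[T].split(T.op.axis[0], factor=8)  # split one axis into (outer, inner)
yo, yi = s[T].split(T.op.axis[1], factor=4)
s[T].reorder(xo, yo, xi, yi)                 # the loop order tile() would produce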
/*!
* \brief Global schedule container
* It holds the stages for the given operations and for all the operations they depend on.
* The schedule of a single Operation is called a Stage.
*/
class Schedule : public NodeRef {
public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct a schedule for an array of ops (and their dependencies).
* \param ops The ops to be scheduled.
*/
explicit Schedule(Array<Operation> ops);
/*!
* \brief Get the stage corresponding to the op
* \param op The operation.
*/
Stage operator[](const Operation& op);
/*!
* \brief Shorthand for getting the stage of the tensor's operation.
* \param tensor The tensor
* \return The stage corresponding to the tensor's op
*/
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
}; };
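Continuing the sketch above: the two operator[] overloads let a stage be looked up by its Operation or by any of its output Tensors, with the Tensor form forwarding to tensor->op:

stage_by_op = s[T.op]   # lookup by Operation
stage_by_tensor = s[T]  # lookup by Tensor; resolves to T.op internally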
/*! /*!
...@@ -135,11 +189,11 @@ class IterVarRelation : public NodeRef { ...@@ -135,11 +189,11 @@ class IterVarRelation : public NodeRef {
* *
* The relations connect the IterVars in the graph. * The relations connect the IterVars in the graph.
*/ */
class ScheduleNode : public Node { class StageNode : public Node {
public: public:
/*! \brief The operation to be scheduled */ /*! \brief The operation to be scheduled */
Operation op; Operation op;
/*! \brief The thread scope level of the schedule */ /*! \brief The thread scope level of the stage */
std::string scope; std::string scope;
/*! \brief All the nodes in the iter var */ /*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars; Array<IterVar> all_iter_vars;
...@@ -152,12 +206,10 @@ class ScheduleNode : public Node { ...@@ -152,12 +206,10 @@ class ScheduleNode : public Node {
Array<IterVarRelation> relations; Array<IterVarRelation> relations;
/*! \brief The attachment type of the schedule */ /*! \brief The attachment type of the stage */
AttachType attach_type{kNone}; AttachType attach_type{kNone};
/*! /*! \brief The attach point of this stage. */
* \brief The attach point of this schedule. IterVar attach_ivar;
*/ /*! \brief The stage this node attaches to */
IterVar attach_parent; Stage attach_stage;
/*! \brief the schedules that this schedule depend on */
Array<Schedule> children;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope); v->Visit("scope", &scope);
...@@ -166,8 +218,31 @@ class ScheduleNode : public Node { ...@@ -166,8 +218,31 @@ class ScheduleNode : public Node {
v->Visit("leaf_iter_vars", &leaf_iter_vars); v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations); v->Visit("relations", &relations);
v->Visit("attach_type", &attach_type); v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent); v->Visit("attach_ivar", &attach_ivar);
v->Visit("children", &children); v->Visit("attach_stage", &attach_stage);
}
static constexpr const char* _type_key = "Stage";
TVM_DECLARE_NODE_TYPE_INFO(StageNode);
};
/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
/*! \brief The root operations */
Array<Operation> roots;
/*!
* \brief list of all stages for non-placeholder ops
* The stages are ordered in PostDFS order of their ops.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
} }
static constexpr const char* _type_key = "Schedule"; static constexpr const char* _type_key = "Schedule";
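A note on the ordering invariant: stages is kept in post-DFS order of the ops, so every producer stage appears before its consumers; the passes below rely on this by walking the list in reverse. Assuming the visited attributes are exposed to Python the way leaf_iter_vars is in the tests, the list can be inspected directly:

for stage in s.stages:  # post-DFS: producers first, consumers last
    print(stage.op)     # assumes op is exposed like the other node attributes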
...@@ -234,12 +309,16 @@ class FuseNode : public IterVarRelationNode { ...@@ -234,12 +309,16 @@ class FuseNode : public IterVarRelationNode {
}; };
// implementations // implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
}
inline StageNode* Stage::operator->() {
return static_cast<StageNode*>(node_.get());
}
inline const ScheduleNode* Schedule::operator->() const { inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get()); return static_cast<const ScheduleNode*>(node_.get());
} }
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const { inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get()); return static_cast<const IterVarRelationNode*>(node_.get());
......
...@@ -174,8 +174,17 @@ def max(expr, rdom): ...@@ -174,8 +174,17 @@ def max(expr, rdom):
return x return x
def Schedule(tensor, scope="global"): def Schedule(ops):
return _function_internal._Schedule(tensor, scope) """Create a schedule for a list of ops
Parameters
----------
ops : list of Operations
The ops to be scheduled.
"""
if not isinstance(ops, (list, _collections.Array)):
ops = [ops]
return _function_internal._Schedule(ops)
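Both call forms are therefore accepted, since a bare op is wrapped into a list (a usage sketch reusing T from the header example above):

s1 = tvm.Schedule(T.op)    # single op, wrapped into [T.op]
s2 = tvm.Schedule([T.op])  # explicit list of root ops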
_init_function_module("tvm") _init_function_module("tvm")
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import _function_internal from . import _function_internal
from . import tensor as _tensor
@register_node @register_node
class Split(NodeBase): class Split(NodeBase):
...@@ -11,10 +12,22 @@ class Split(NodeBase): ...@@ -11,10 +12,22 @@ class Split(NodeBase):
class Fuse(NodeBase): class Fuse(NodeBase):
pass pass
@register_node @register_node
class Schedule(NodeBase): class Schedule(NodeBase):
def __getitem__(self, k):
if isinstance(k, _tensor.Tensor):
k = k.op
if not isinstance(k, _tensor.Operation):
raise ValueError("Expect schedule key to be Tensor or Operation")
if k not in self.stage_map:
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
return self.stage_map[k]
@register_node
class Stage(NodeBase):
def split(self, parent, factor=None, outer=None): def split(self, parent, factor=None, outer=None):
"""Split the schedule either by factor providing outer scope, or both """Split the stage either by factor providing outer scope, or both
Parameters Parameters
---------- ----------
...@@ -40,11 +53,11 @@ class Schedule(NodeBase): ...@@ -40,11 +53,11 @@ class Schedule(NodeBase):
raise ValueError("split by outer must have special thread_tag") raise ValueError("split by outer must have special thread_tag")
if outer.dom is None: if outer.dom is None:
raise ValueError("split by outer must have specified domain") raise ValueError("split by outer must have specified domain")
inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor) inner = _function_internal._StageSplitByOuter(self, parent, outer, factor)
else: else:
if factor is None: if factor is None:
raise ValueError("either outer or factor need to be provided") raise ValueError("either outer or factor need to be provided")
outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor) outer, inner = _function_internal._StageSplitByFactor(self, parent, factor)
return outer, inner return outer, inner
def fuse(self, inner, outer): def fuse(self, inner, outer):
...@@ -63,40 +76,50 @@ class Schedule(NodeBase): ...@@ -63,40 +76,50 @@ class Schedule(NodeBase):
inner : IterVar inner : IterVar
The fused variable of iteration. The fused variable of iteration.
""" """
return _function_internal._ScheduleFuse(self, inner, outer) return _function_internal._StageFuse(self, inner, outer)
def set_scope(self, scope):
"""Set the thread scope of this stage
Parameters
----------
scope : str
The thread scope of this stage
"""
return _function_internal._StageSetScope(self, scope)
def compute_at(self, parent, scope): def compute_at(self, parent, scope):
"""Attach the schedule at parent's scope """Attach the stage at parent's scope
Parameters Parameters
---------- ----------
parent : Schedule parent : Stage
The parent schedule The parent stage
scope : IterVar scope : IterVar
The loop scope to be attached to. The loop scope to be attached to.
""" """
_function_internal._ScheduleComputeAt(self, parent, scope) _function_internal._StageComputeAt(self, parent, scope)
def compute_inline(self, parent): def compute_inline(self):
"""Attach the schedule at parent, and mark it as inline """Mark stage as inline
Parameters Parameters
---------- ----------
parent : Schedule parent : Stage
The parent schedule The parent stage
""" """
_function_internal._ScheduleComputeInline(self, parent) _function_internal._StageComputeInline(self)
def compute_root(self, parent): def compute_root(self):
"""Attach the schedule at parent, and mark it as root """Attach the stage at parent, and mark it as root
Parameters Parameters
---------- ----------
parent : Schedule parent : Stage
The parent schedule The parent stage
""" """
_function_internal._ScheduleComputeInline(self, parent) _function_internal._StageComputeRoot(self)
def reorder(self, *args): def reorder(self, *args):
"""reorder the arguments in the specified order. """reorder the arguments in the specified order.
...@@ -106,9 +129,9 @@ class Schedule(NodeBase): ...@@ -106,9 +129,9 @@ class Schedule(NodeBase):
args : list of IterVar args : list of IterVar
The variables in the desired iteration order. The variables in the desired iteration order.
""" """
_function_internal._ScheduleReorder(self, args) _function_internal._StageReorder(self, args)
def tile(self, x_parent, y_parent, x_factor, y_factor): def tile(self, x_parent, y_parent, x_factor, y_factor):
x_outer, y_outer, x_inner, y_inner = _function_internal._ScheduleTile( x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor) self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner return x_outer, y_outer, x_inner, y_inner
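As the header declares and the tests below exercise, tile is just split + split + reorder, returning the four new axes:

xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
# resulting leaf order: (xo, yo, xi, yi)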
...@@ -176,7 +176,7 @@ TVM_REGISTER_API(_ComputeOp) ...@@ -176,7 +176,7 @@ TVM_REGISTER_API(_ComputeOp)
TVM_REGISTER_API(_OpGetOutput) TVM_REGISTER_API(_OpGetOutput)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = args.at(0).operator Operation().output( *ret = args.at(0).operator Operation().output(
args.at(1).operator size_t()); args.at(1).operator int64_t());
}); });
...@@ -185,64 +185,69 @@ TVM_REGISTER_API(_IterVar) ...@@ -185,64 +185,69 @@ TVM_REGISTER_API(_IterVar)
*ret = IterVar(args.at(0), args.at(1), args.at(2)); *ret = IterVar(args.at(0), args.at(1), args.at(2));
}); });
TVM_REGISTER_API(_Schedule) TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0), args.at(1)); *ret = Schedule(args.at(0).operator Array<Operation>());
});
TVM_REGISTER_API(_StageSetScope)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Stage()
.set_scope(args.at(1));
}); });
TVM_REGISTER_API(_ScheduleSplitByFactor) TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
IterVar outer, inner; IterVar outer, inner;
args.at(0).operator Schedule() args.at(0).operator Stage()
.split(args.at(1), &outer, &inner, args.at(2)); .split(args.at(1), &outer, &inner, args.at(2));
*ret = Array<IterVar>({outer, inner}); *ret = Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_API(_ScheduleSplitByOuter) TVM_REGISTER_API(_StageSplitByOuter)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
IterVar inner; IterVar inner;
args.at(0).operator Schedule() args.at(0).operator Stage()
.split(args.at(1), args.at(2), &inner, args.at(3)); .split(args.at(1), args.at(2), &inner, args.at(3));
*ret = inner; *ret = inner;
}); });
TVM_REGISTER_API(_ScheduleFuse) TVM_REGISTER_API(_StageFuse)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
IterVar fused; IterVar fused;
args.at(0).operator Schedule() args.at(0).operator Stage()
.split(args.at(1), args.at(2), &fused); .fuse(args.at(1), args.at(2), &fused);
*ret = fused; *ret = fused;
}); });
TVM_REGISTER_API(_ScheduleComputeAt) TVM_REGISTER_API(_StageComputeAt)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule() args.at(0).operator Stage()
.compute_at(args.at(1), args.at(2)); .compute_at(args.at(1), args.at(2));
}); });
TVM_REGISTER_API(_ScheduleComputeInline) TVM_REGISTER_API(_StageComputeInline)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule() args.at(0).operator Stage()
.compute_inline(args.at(1)); .compute_inline();
}); });
TVM_REGISTER_API(_ScheduleComputeRoot) TVM_REGISTER_API(_StageComputeRoot)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule() args.at(0).operator Stage()
.compute_root(args.at(1)); .compute_root();
}); });
TVM_REGISTER_API(_ScheduleReorder) TVM_REGISTER_API(_StageReorder)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Schedule() args.at(0).operator Stage()
.reorder(args.at(1)); .reorder(args.at(1));
}); });
TVM_REGISTER_API(_ScheduleTile) TVM_REGISTER_API(_StageTile)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
IterVar x_outer, y_outer, x_inner, y_inner; IterVar x_outer, y_outer, x_inner, y_inner;
args.at(0).operator Schedule() args.at(0).operator Stage()
.tile(args.at(1), args.at(2), &x_outer, &y_outer, .tile(args.at(1), args.at(2), &x_outer, &y_outer,
&x_inner, &y_inner, args.at(3), args.at(4)); &x_inner, &y_inner, args.at(3), args.at(4));
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
......
...@@ -22,7 +22,7 @@ namespace { ...@@ -22,7 +22,7 @@ namespace {
* \param p_state The message passing state * \param p_state The message passing state
* IterVar->The assignment. * IterVar->The assignment.
*/ */
void PassUpOffset(const Schedule& s, void PassUpOffset(const Stage& s,
const Map<IterVar, Range>& dom_map, const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) { std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state; auto& state = *p_state;
...@@ -130,7 +130,7 @@ Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) { ...@@ -130,7 +130,7 @@ Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
* The flattened Stmts are ordered from outermost to innermost. * The flattened Stmts are ordered from outermost to innermost.
*/ */
std::vector<std::vector<Stmt> > MakeLoopNest( std::vector<std::vector<Stmt> > MakeLoopNest(
const Schedule& sch, const Stage& sch,
const Map<IterVar, Range>& dom_map) { const Map<IterVar, Range>& dom_map) {
// optional, use let to define some CSE in dom_map. // optional, use let to define some CSE in dom_map.
auto leaf_iter_vars = sch->leaf_iter_vars; auto leaf_iter_vars = sch->leaf_iter_vars;
...@@ -244,7 +244,7 @@ Stmt MakeRealize(const ComputeOpNode* op, ...@@ -244,7 +244,7 @@ Stmt MakeRealize(const ComputeOpNode* op,
bounds, make_const(Bool(1), true), body); bounds, make_const(Bool(1), true), body);
} }
Stmt MakePipeline(const Schedule& sch, Stmt MakePipeline(const Stage& sch,
const Map<IterVar, Range>& dom_map, const Map<IterVar, Range>& dom_map,
Stmt consumer) { Stmt consumer) {
std::vector<Tensor> tensors; std::vector<Tensor> tensors;
...@@ -280,7 +280,7 @@ Stmt MakePipeline(const Schedule& sch, ...@@ -280,7 +280,7 @@ Stmt MakePipeline(const Schedule& sch,
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
class InjectRealize : public IRMutator { class InjectRealize : public IRMutator {
public: public:
InjectRealize(Schedule schedule, Map<IterVar, Range> dom_map) InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
: schedule(schedule), dom_map(dom_map) {} : schedule(schedule), dom_map(dom_map) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
...@@ -289,7 +289,7 @@ class InjectRealize : public IRMutator { ...@@ -289,7 +289,7 @@ class InjectRealize : public IRMutator {
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr && if (op != nullptr &&
op->type_key == "scope") { op->type_key == "scope") {
if (op->node == schedule->attach_parent) { if (op->node == schedule->attach_ivar) {
CHECK(!found_attach); CHECK(!found_attach);
found_attach = true; found_attach = true;
stmt = AttrStmt::make( stmt = AttrStmt::make(
...@@ -301,41 +301,13 @@ class InjectRealize : public IRMutator { ...@@ -301,41 +301,13 @@ class InjectRealize : public IRMutator {
return stmt; return stmt;
} }
// the operations to be carried // the operations to be carried
Schedule schedule; Stage schedule;
// domain map // domain map
Map<IterVar, Range> dom_map; Map<IterVar, Range> dom_map;
// whether attach point is found // whether attach point is found
bool found_attach{false}; bool found_attach{false};
}; };
void GetOpToScheduleMap(
Schedule s,
std::unordered_map<Operation, Schedule>* ret) {
CHECK(!ret->count(s->op))
<< "Duplicated schedule for op";
(*ret)[s->op] = s;
for (Schedule c : s->children) {
GetOpToScheduleMap(c, ret);
}
}
// order schedule by DFS calling order of ops
std::vector<Schedule> OrderSchedule(Schedule s) {
auto g = schedule::CreateReadGraph(s->op);
auto post_order = schedule::PostDFSOrder(s->op, g);
std::unordered_map<Operation, Schedule> op2sch;
GetOpToScheduleMap(s, &op2sch);
std::vector<Schedule> sorder;
// reverse iteration.
for (size_t i = post_order.size(); i != 0; --i) {
sorder.push_back(op2sch.at(post_order[i - 1]));
}
return sorder;
}
Stmt InjectInline(const Operation op, Stmt body) { Stmt InjectInline(const Operation op, Stmt body) {
CHECK(body.defined()); CHECK(body.defined());
const ComputeOpNode* compute = op.as<ComputeOpNode>(); const ComputeOpNode* compute = op.as<ComputeOpNode>();
...@@ -351,11 +323,11 @@ Stmt InjectInline(const Operation op, Stmt body) { ...@@ -351,11 +323,11 @@ Stmt InjectInline(const Operation op, Stmt body) {
} // namespace } // namespace
Stmt ScheduleOps( Stmt ScheduleOps(
Schedule s, Map<IterVar, Range> dom_map) { Schedule sch, Map<IterVar, Range> dom_map) {
std::vector<Schedule> svec = OrderSchedule(s);
Stmt body = Stmt(); Stmt body = Stmt();
// reverse the post DFS order.
for (Schedule s : svec) { for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
if (s->attach_type == kInline) { if (s->attach_type == kInline) {
body = InjectInline(s->op, body); body = InjectInline(s->op, body);
} else if (s->attach_type == kRoot || s->attach_type == kNone) { } else if (s->attach_type == kRoot || s->attach_type == kNone) {
......
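The C-style reverse loop in ScheduleOps walks sch->stages from back to front, visiting consumers before the producers that attach to them. A rough Python rendering of the same traversal, using the schedule s from the earlier sketch:

for stage in reversed(list(s.stages)):  # reverse post-DFS: consumers first
    print(stage.op, stage.attach_type)  # assumes attach_type is exposed to Python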
...@@ -17,10 +17,10 @@ inline Expr DivCeil(Expr a, Expr b) { ...@@ -17,10 +17,10 @@ inline Expr DivCeil(Expr a, Expr b) {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
// Downward message passing algorithm on schedule s, // Downward message passing algorithm on the stage s,
// pass the range state down from the root to the leaves // pass the range state down from the root to the leaves
// after this pass, every IterVar in the schedule hyper graph will have a range(domain) // after this pass, every IterVar in the stage hypergraph will have a range (domain)
void PassDown(const Schedule& s, void PassDown(const Stage& s,
std::unordered_map<IterVar, Range>* p_state) { std::unordered_map<IterVar, Range>* p_state) {
auto& state = *p_state; auto& state = *p_state;
// forward iteration on relations // forward iteration on relations
...@@ -63,7 +63,7 @@ void PassDown(const Schedule& s, ...@@ -63,7 +63,7 @@ void PassDown(const Schedule& s,
// pass the integer set on each leaf loop up to the root // pass the integer set on each leaf loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar. // dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction. // dom_map can be used to get cached result in reverse construction.
void PassUp(const ScheduleNode* s, void PassUp(const Stage& s,
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) { std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state; auto& state = *p_state;
...@@ -180,13 +180,11 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) { ...@@ -180,13 +180,11 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag); return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
} }
void InferBound( void InferBound(const Stage& stage,
const ScheduleNode* parent, std::unordered_map<IterVar, Range>* rmap) {
const Schedule& sch, if (stage->attach_type == kInline) return;
std::unordered_map<IterVar, Range>* rmap) { if (stage->attach_type == kRoot || stage->attach_type == kNone) {
if (sch->attach_type == kInline) return; auto root_iter_vars = stage->op->root_iter_vars();
if (sch->attach_type == kRoot || sch->attach_type == kNone) {
auto root_iter_vars = sch->op->root_iter_vars();
for (auto iv : root_iter_vars) { for (auto iv : root_iter_vars) {
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
CHECK(!rmap->count(iv)); CHECK(!rmap->count(iv));
...@@ -194,22 +192,23 @@ void InferBound( ...@@ -194,22 +192,23 @@ void InferBound(
} }
} }
// get range of all child iter vars. // get range of all child iter vars.
PassDown(sch, rmap); PassDown(stage, rmap);
if (sch->attach_type == kScope) { if (stage->attach_type == kScope) {
CHECK(parent != nullptr); Stage parent = stage->attach_stage;
auto g = CreateReadGraph(parent->op); CHECK(parent.defined());
auto post_order = PostDFSOrder(parent->op, g); auto g = CreateReadGraph({parent->op});
auto post_order = PostDFSOrder({parent->op}, g);
std::unordered_map<IterVar, IntSet> up_state; std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true; bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) { for (auto iv : parent->leaf_iter_vars) {
if (fix_value && !ScopeRelax(iv, sch->scope)) { if (fix_value && !ScopeRelax(iv, stage->scope)) {
up_state[iv] = IntSet::make_point(iv->var); up_state[iv] = IntSet::make_point(iv->var);
} else { } else {
up_state[iv] = IntSet::make_range(rmap->at(iv)); up_state[iv] = IntSet::make_range(rmap->at(iv));
} }
if (sch->attach_parent == iv) { if (stage->attach_ivar == iv) {
fix_value = false; fix_value = false;
} }
} }
...@@ -221,24 +220,22 @@ void InferBound( ...@@ -221,24 +220,22 @@ void InferBound(
bp_state[iv] = {up_state.at(iv)}; bp_state[iv] = {up_state.at(iv)};
} }
auto result = BoundProp(post_order, &bp_state); auto result = BoundProp(post_order, &bp_state);
for (auto iv : sch->op->root_iter_vars()) { for (auto iv : stage->op->root_iter_vars()) {
CHECK(result.count(iv)); CHECK(result.count(iv));
CHECK(!rmap->count(iv)); CHECK(!rmap->count(iv));
(*rmap)[iv] = result.at(iv).GetCoverRange(); (*rmap)[iv] = result.at(iv).GetCoverRange();
} }
} }
// also call infer bound on children
for (Schedule child : sch->children) {
InferBound(sch.operator->(), child, rmap);
}
} }
Map<IterVar, Range> InferBound(Schedule sch) { Map<IterVar, Range> InferBound(Schedule sch) {
std::unordered_map<IterVar, Range> ret; std::unordered_map<IterVar, Range> ret;
CHECK(sch->attach_type != kInline && sch->attach_type != kScope) // reverse post-DFS order, from the outermost stage to the innermost
<< "the Schedule is not a root Schedule"; for (size_t i = sch->stages.size(); i != 0; --i) {
InferBound(nullptr, sch, &ret); Stage stage = sch->stages[i - 1];
InferBound(stage, &ret);
}
return Map<IterVar, Range>(ret.begin(), ret.end()); return Map<IterVar, Range>(ret.begin(), ret.end());
} }
......
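With the per-stage recursion removed, InferBound now takes the whole Schedule and drives the reverse-stage loop itself. From Python it returns a Map from IterVar to Range, as the bound-inference tests below check; a sketch reusing their setup:

bounds = tvm.schedule.InferBound(s)
# after s[A1].compute_at(s[A2], xo) with a split factor of 8, the tests
# assert bounds[A1.op.axis[0]].extent.value == 8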
...@@ -14,10 +14,15 @@ namespace schedule { ...@@ -14,10 +14,15 @@ namespace schedule {
// construct a read graph that gives readers of each operation // construct a read graph that gives readers of each operation
// that the root depend on // that the root depend on
ReadGraph CreateReadGraph(const Operation& root) { ReadGraph CreateReadGraph(const Array<Operation>& roots) {
ReadGraph rmap; ReadGraph rmap;
std::vector<Operation> stack{root}; std::vector<Operation> stack;
std::unordered_set<const Node*> visited{root.get()}; std::unordered_set<const Node*> visited;
// initialize the roots
for (Operation op : roots) {
stack.push_back(op);
visited.insert(op.get());
}
while (!stack.empty()) { while (!stack.empty()) {
Operation op = stack.back(); Operation op = stack.back();
...@@ -51,20 +56,22 @@ void PostDFSOrder(const Operation& op, ...@@ -51,20 +56,22 @@ void PostDFSOrder(const Operation& op,
const ReadGraph& g, const ReadGraph& g,
std::unordered_set<Operation>* visited, std::unordered_set<Operation>* visited,
Array<Operation>* post_order) { Array<Operation>* post_order) {
if (op.as<PlaceholderOpNode>() || visited->count(op)) return;
visited->insert(op); visited->insert(op);
for (const auto& t : g.at(op)) { for (const auto& t : g.at(op)) {
if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) { PostDFSOrder(t->op, g, visited, post_order);
PostDFSOrder(t->op, g, visited, post_order);
}
} }
post_order->push_back(op); post_order->push_back(op);
} }
Array<Operation> PostDFSOrder( Array<Operation> PostDFSOrder(
const Operation& root, const ReadGraph& g) { const Array<Operation>& roots,
const ReadGraph& g) {
std::unordered_set<Operation> visited; std::unordered_set<Operation> visited;
Array<Operation> post_order; Array<Operation> post_order;
PostDFSOrder(root, g, &visited, &post_order); for (Operation op : roots) {
PostDFSOrder(op, g, &visited, &post_order);
}
return post_order; return post_order;
} }
......
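For clarity, a self-contained Python sketch of the post-DFS ordering implemented above; read_graph maps each op to the tensors it reads, and is_placeholder stands in for the PlaceholderOpNode check (both parameter names are illustrative):

def post_dfs_order(roots, read_graph, is_placeholder):
    visited, order = set(), []
    def visit(op):
        # placeholders are skipped, matching the early return in the C++ version
        if is_placeholder(op) or op in visited:
            return
        visited.add(op)
        for tensor in read_graph[op]:
            visit(tensor.op)  # emit all producers first
        order.append(op)      # then this op: post-order
    for op in roots:
        visit(op)
    return order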
...@@ -24,14 +24,14 @@ using ReadGraph = Map<Operation, Array<Tensor> >; ...@@ -24,14 +24,14 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
* Tensors that it directly depends on. * Tensors that it directly depends on.
* *
* The result map contains the Operations needed to finish the root Operations. * The result map contains the Operations needed to finish the root Operations.
* \param root The root operation. * \param roots The root operations.
* \return The result map. * \return The result map.
*/ */
ReadGraph CreateReadGraph(const Operation& root); ReadGraph CreateReadGraph(const Array<Operation>& roots);
/*! /*!
* \brief Get a post-DFS ordering of operations in the graph. * \brief Get a post-DFS ordering of operations in the graph.
* \param root The root of the graph. * \param roots The roots of the graph.
* \param g The read graph. * \param g The read graph.
* \return vector order of Operations in PostDFS order. * \return vector order of Operations in PostDFS order.
* *
...@@ -39,7 +39,7 @@ ReadGraph CreateReadGraph(const Operation& root); ...@@ -39,7 +39,7 @@ ReadGraph CreateReadGraph(const Operation& root);
* and can be used when topological order is needed. * and can be used when topological order is needed.
*/ */
Array<Operation> PostDFSOrder( Array<Operation> PostDFSOrder(
const Operation& root, const ReadGraph& g); const Array<Operation>& roots, const ReadGraph& g);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file schedule.cc * \file schedule.cc
*/ */
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include "./graph.h"
namespace tvm { namespace tvm {
...@@ -31,7 +32,7 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) ...@@ -31,7 +32,7 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
return 0; return 0;
} }
void Split(ScheduleNode* self, IterVar parent, void Split(StageNode* self, IterVar parent,
IterVar outer, IterVar inner, Expr factor) { IterVar outer, IterVar inner, Expr factor) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
...@@ -49,19 +50,30 @@ void Split(ScheduleNode* self, IterVar parent, ...@@ -49,19 +50,30 @@ void Split(ScheduleNode* self, IterVar parent,
} // namespace } // namespace
Schedule::Schedule(Operation op, std::string scope) { TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
auto n = std::make_shared<ScheduleNode>(); .set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
p->stream << "stage("
<< op->op
<< ")";
});
Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>();
n->op = op; n->op = op;
n->scope = scope;
n->all_iter_vars = op->root_iter_vars(); n->all_iter_vars = op->root_iter_vars();
n->leaf_iter_vars = op->root_iter_vars(); n->leaf_iter_vars = op->root_iter_vars();
node_ = n; node_ = n;
} }
Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*) Stage& Stage::set_scope(std::string scope) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone); (*this)->scope = scope;
return *this;
}
Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*)
(*this)->attach_type = kScope; (*this)->attach_type = kScope;
(*this)->attach_parent = scope; (*this)->attach_ivar = scope;
(*this)->attach_stage = parent;
bool found = false; bool found = false;
for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) { for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
if (scope == parent->leaf_iter_vars[i]) { if (scope == parent->leaf_iter_vars[i]) {
...@@ -70,25 +82,20 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*) ...@@ -70,25 +82,20 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*)
} }
CHECK(found) CHECK(found)
<< "Cannot compute at a iteration variable that is not part of parent leaf vars"; << "Cannot compute at a iteration variable that is not part of parent leaf vars";
parent->children.push_back(*this);
return *this; return *this;
} }
Schedule& Schedule::compute_inline(Schedule parent) { // NOLINT(*) Stage& Stage::compute_inline() { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
(*this)->attach_type = kInline; (*this)->attach_type = kInline;
parent->children.push_back(*this);
return *this; return *this;
} }
Schedule& Schedule::compute_root(Schedule parent) { // NOLINT(*) Stage& Stage::compute_root() { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
(*this)->attach_type = kRoot; (*this)->attach_type = kRoot;
parent->children.push_back(*this);
return *this; return *this;
} }
Schedule& Schedule::split( Stage& Stage::split(
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*) IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
// placeholder for the split results. // placeholder for the split results.
IterVar outer(Range(), parent->var->name_hint + ".outer"); IterVar outer(Range(), parent->var->name_hint + ".outer");
...@@ -99,7 +106,7 @@ Schedule& Schedule::split( ...@@ -99,7 +106,7 @@ Schedule& Schedule::split(
return *this; return *this;
} }
Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*) Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
// placeholder for the split results. // placeholder for the split results.
IterVar inner(Range(), parent->var->name_hint + ".inner"); IterVar inner(Range(), parent->var->name_hint + ".inner");
*p_inner = inner; *p_inner = inner;
...@@ -108,9 +115,9 @@ Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr ...@@ -108,9 +115,9 @@ Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr
return *this; return *this;
} }
Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*) Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused"); IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
ScheduleNode* self = operator->(); StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
...@@ -128,8 +135,8 @@ Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // ...@@ -128,8 +135,8 @@ Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) { //
return *this; return *this;
} }
Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*) Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*)
ScheduleNode* self = operator->(); StageNode* self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
std::vector<size_t> pos; std::vector<size_t> pos;
...@@ -148,16 +155,34 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*) ...@@ -148,16 +155,34 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
return *this; return *this;
} }
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*) Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor); split(x_parent, p_x_outer, p_x_inner, x_factor);
split(y_parent, p_y_outer, p_y_inner, y_factor); split(y_parent, p_y_outer, p_y_inner, y_factor);
reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
return *this; return *this;
} }
Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
n->roots = ops;
auto g = schedule::CreateReadGraph(n->roots);
Array<Operation> post_order = schedule::PostDFSOrder(n->roots, g);
for (Operation op : post_order) {
Stage stage(op);
n->stages.push_back(stage);
n->stage_map.Set(op, stage);
}
node_ = std::move(n);
}
Stage Schedule::operator[](const Operation& op) {
return (*this)->stage_map.at(op);
}
IterVarRelation SplitNode::make( IterVarRelation SplitNode::make(
IterVar parent, IterVar outer, IterVar parent, IterVar outer,
IterVar inner, Expr factor) { IterVar inner, Expr factor) {
...@@ -178,7 +203,7 @@ IterVarRelation FuseNode::make( ...@@ -178,7 +203,7 @@ IterVarRelation FuseNode::make(
return IterVarRelation(n); return IterVarRelation(n);
} }
TVM_REGISTER_NODE_TYPE(ScheduleNode); TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(SplitNode); TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode); TVM_REGISTER_NODE_TYPE(FuseNode);
......
...@@ -7,41 +7,37 @@ def test_schedule_create(): ...@@ -7,41 +7,37 @@ def test_schedule_create():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
AA = tvm.compute((m, l), lambda i, j: A[i, j]) AA = tvm.compute((m, l), lambda i, j: A[i, j])
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.compute((m, n, l), lambda i, j, k: AA(i, k) * B(j, k))
s = tvm.Schedule(T.op)
sch_T = tvm.Schedule(T.op, scope="shared") s[AA].set_scope("shared")
sch_A = tvm.Schedule(AA.op, scope="global") xo, xi = s[T].split(T.op.axis[0], factor=10)
xi1, xi2 = s[T].split(xi, factor=2)
xo, xi = sch_T.split(T.op.axis[0], factor=10) s[AA].compute_at(s[T], xi1)
xi1, xi2 = sch_T.split(xi, factor=2) xo, xi = s[AA].split(AA.op.axis[0], factor=10)
s[T].reorder(xi2, xi1)
sch_A.compute_at(sch_T, xi1) assert T.op.axis[1] in s[T].leaf_iter_vars
xo, xi = sch_A.split(AA.op.axis[0], factor=10)
sch_T.reorder(xi2, xi1)
assert T.op.axis[1] in sch_T.leaf_iter_vars
def test_reorder(): def test_reorder():
m = tvm.Var('m') m = tvm.Var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
T = tvm.compute(m, lambda i: A[i+1]) T = tvm.compute(m, lambda i: A[i+1])
sch_T = tvm.Schedule(T.op, scope="shared") s = tvm.Schedule(T.op)
xo, xi = sch_T.split(T.op.axis[0], factor=10) xo, xi = s[T].split(T.op.axis[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2) xi1, xi2 = s[T].split(xi, factor=2)
order = (xi2, xi1, xo) order = (xi2, xi1, xo)
assert tuple(sch_T.leaf_iter_vars) != order assert tuple(s[T].leaf_iter_vars) != order
sch_T.reorder(*order) s[T].reorder(*order)
assert tuple(sch_T.leaf_iter_vars) == order assert tuple(s[T].leaf_iter_vars) == order
def test_split(): def test_split():
m = tvm.Var('m') m = tvm.Var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i: A[i]) T = tvm.compute((m,), lambda i: A[i])
sT = tvm.Schedule(T.op) s = tvm.Schedule(T.op)
xo, xi = sT.split(T.op.axis[0], factor=10) xo, xi = s[T].split(T.op.axis[0], factor=10)
assert tuple(sT.leaf_iter_vars) == (xo, xi) assert tuple(s[T].leaf_iter_vars) == (xo, xi)
def test_tile(): def test_tile():
...@@ -50,9 +46,9 @@ def test_tile(): ...@@ -50,9 +46,9 @@ def test_tile():
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j]) T = tvm.compute((m, n), lambda i, j: A[i, j])
sch_T = tvm.Schedule(T.op, scope="shared") s = tvm.Schedule(T.op)
xo, yo, xi, yi = sch_T.tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi) assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi)
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_create() test_schedule_create()
......
...@@ -6,10 +6,12 @@ def test_schedule0(): ...@@ -6,10 +6,12 @@ def test_schedule0():
l = tvm.Var('l') l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(sA1) s = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds) stmt = tvm.ir_pass.ScheduleOps(s, bounds)
print(stmt) print(stmt)
def test_schedule1(): def test_schedule1():
...@@ -17,11 +19,12 @@ def test_schedule1(): ...@@ -17,11 +19,12 @@ def test_schedule1():
l = tvm.Var('l') l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
sA1 = tvm.Schedule(A1.op)
xo, xi = sA1.split(A1.op.axis[0], 8) s = tvm.Schedule(A1.op)
bounds = tvm.schedule.InferBound(sA1) xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA1, bounds) stmt = tvm.ir_pass.ScheduleOps(s, bounds)
print(stmt) print(stmt)
def test_schedule2(): def test_schedule2():
...@@ -30,13 +33,13 @@ def test_schedule2(): ...@@ -30,13 +33,13 @@ def test_schedule2():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
xo, xi = sA2.split(A2.op.axis[0], 8) xo, xi = s[A2].split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo) s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(sA2) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.ir_pass.ScheduleOps(sA2, bounds) stmt = tvm.ir_pass.ScheduleOps(s, bounds)
print(stmt) print(stmt)
......
...@@ -6,11 +6,11 @@ def test_bound1(): ...@@ -6,11 +6,11 @@ def test_bound1():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op) s = tvm.Schedule([A2.op])
xo, xi = sA2.split(A2.op.axis[0], 8) xo, xi = s[A2].split(s[A2].op.axis[0], 8)
sA1.compute_at(sA2, xo) s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(sA2) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8) assert(bounds[A1.op.axis[0]].extent.value == 8)
...@@ -20,11 +20,10 @@ def test_bound2(): ...@@ -20,11 +20,10 @@ def test_bound2():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op) s = tvm.Schedule(A2.op)
sA2 = tvm.Schedule(A2.op) xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
xo, yo, xi, yi = sA2.tile(A2.op.axis[0], A2.op.axis[1], 8, 8) s[A1].compute_at(s[A2], yo)
sA1.compute_at(sA2, yo) bounds = tvm.schedule.InferBound(s)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8) assert(bounds[A1.op.axis[0]].extent.value == 8)
assert(bounds[A1.op.axis[1]].extent.value == 8) assert(bounds[A1.op.axis[1]].extent.value == 8)
...@@ -35,16 +34,18 @@ def test_bound3(): ...@@ -35,16 +34,18 @@ def test_bound3():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op, scope="shared")
sA2 = tvm.Schedule(A2.op) s = tvm.Schedule(A2.op)
s[A1].set_scope("shared")
thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x") thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
xo, xi = sA2.split(A2.op.axis[0], 32) xo, xi = s[A2].split(A2.op.axis[0], 32)
xi0, xi1 = sA2.split(xi, outer=thread_x) xi0, xi1 = s[A2].split(xi, outer=thread_x)
yo, yi = sA2.split(A2.op.axis[1], 16) yo, yi = s[A2].split(A2.op.axis[1], 16)
sA2.reorder(xo, xi0, yo, xi1, yi) s[A2].reorder(xo, xi0, yo, xi1, yi)
sA1.compute_at(sA2, yo) s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(sA2) bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16) assert(bounds[A1.op.axis[1]].extent.value==16)
...@@ -56,10 +57,12 @@ def test_create_read_graph(): ...@@ -56,10 +57,12 @@ def test_create_read_graph():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j]) A1 = tvm.compute((m, l), lambda i, j: A[i, j])
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3) A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
g = tvm.schedule.CreateReadGraph(A2.op)
g = tvm.schedule.CreateReadGraph([A2.op])
assert g[A2.op][0] == A1 assert g[A2.op][0] == A1
assert g[A1.op][0] == A assert g[A1.op][0] == A
post_order = tvm.schedule.PostDFSOrder(A2.op, g) post_order = tvm.schedule.PostDFSOrder([A2.op], g)
assert(post_order[0] == A1.op) assert(post_order[0] == A1.op)
assert(post_order[1] == A2.op) assert(post_order[1] == A2.op)
......
#!/bin/bash #!/bin/bash
if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
if [ ${TASK} == "lint" ]; then if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
make lint || exit -1 make lint || exit -1
echo "Check documentations of c++ code..." echo "Check documentations of c++ code..."
make doc 2>log.txt make doc 2>log.txt
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt (cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt
echo "---------Error Log----------" echo "---------Error Log----------"
cat logclean.txt cat logclean.txt
echo "----------------------------" echo "----------------------------"
(cat logclean.txt|grep warning) && exit -1 (cat logclean.txt|grep warning) && exit -1
(cat logclean.txt|grep error) && exit -1 (cat logclean.txt|grep error) && exit -1
exit 0 fi
fi fi
...@@ -22,19 +22,16 @@ if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then ...@@ -22,19 +22,16 @@ if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
fi fi
fi fi
if [ ${TASK} == "cpp_test" ]; then if [ ${TASK} == "cpp_test" ] || [ ${TASK} == "all_test" ]; then
make -f dmlc-core/scripts/packages.mk gtest make -f dmlc-core/scripts/packages.mk gtest
make test || exit -1 make test || exit -1
for test in tests/cpp/*_test; do for test in tests/cpp/*_test; do
./$test || exit -1 ./$test || exit -1
done done
exit 0
fi fi
# run two test one for cython, one for ctypes if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
if [ ${TASK} == "python_test" ]; then make all || exit -1
make clean
make -j all || exit -1
if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
python -m nose tests/python/ || exit -1 python -m nose tests/python/ || exit -1
python3 -m nose tests/python/ || exit -1 python3 -m nose tests/python/ || exit -1
...@@ -42,5 +39,4 @@ if [ ${TASK} == "python_test" ]; then ...@@ -42,5 +39,4 @@ if [ ${TASK} == "python_test" ]; then
nosetests tests/python/ || exit -1 nosetests tests/python/ || exit -1
nosetests3 tests/python/ || exit -1 nosetests3 tests/python/ || exit -1
fi fi
exit 0
fi fi
#!/bin/bash #!/bin/bash
if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
brew update if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
brew install python3 brew update
if [ ${TASK} == "python_test" ]; then brew install python3
python -m pip install --user nose python -m pip install --user nose numpy
python3 -m pip install --user nose python3 -m pip install --user nose numpy
fi fi
fi fi
if [ ${TASK} == "lint" ]; then if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6' if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6'
fi
fi fi