Commit 59bb0dd4 by tqchen

temp checkin of schedule

parent 357ad592
Subproject commit e96ee0f2fb5239021c0facd5398a9a96644bc411 Subproject commit 29fd3defa3dbf810e52dbc2ecd3933604989dcc8
...@@ -107,18 +107,21 @@ using Halide::select; ...@@ -107,18 +107,21 @@ using Halide::select;
/*! /*!
* \brief sum of of source expression over rdom * \brief sum of of source expression over rdom
* \param source The source expression. * \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
*/ */
Expr sum(Expr source, Array<IterVar> rdom); Expr sum(Expr source, Array<IterVar> rdom);
/*! /*!
* \brief max of of source expression over rdom * \brief max of of source expression over rdom
* \param source The source expression. * \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
*/ */
Expr max(Expr source, Array<IterVar> rdom); Expr max(Expr source, Array<IterVar> rdom);
/*! /*!
* \brief max of of source expression over rdom * \brief max of of source expression over rdom
* \param source The source expression. * \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
*/ */
Expr min(Expr source, Array<IterVar> rdom); Expr min(Expr source, Array<IterVar> rdom);
......
...@@ -19,9 +19,10 @@ class IterVarRelationNode; ...@@ -19,9 +19,10 @@ class IterVarRelationNode;
/*! \brief the attachment type */ /*! \brief the attachment type */
enum AttachType : int { enum AttachType : int {
kRoot = 0, kNone = 0,
kInline = 1, kRoot = 1,
kSplit = 2 kInline = 2,
kScope = 3
}; };
/*! \brief schedule container */ /*! \brief schedule container */
...@@ -29,12 +30,70 @@ class Schedule : public NodeRef { ...@@ -29,12 +30,70 @@ class Schedule : public NodeRef {
public: public:
Schedule() {} Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief create a new schedule for op.
* \param op The operator in the schedule
* \param scope The scope of the schedule
*/
Schedule(Operation op, std::string scope); Schedule(Operation op, std::string scope);
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
*/ */
inline const ScheduleNode* operator->() const; inline const ScheduleNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
/*!
* \brief specify the schedule to be computed at the parent schedule's scope.
* \param parent The parent schedule.
* \param scope The iteration point to carry the schedule.
*/
Schedule& compute_at(Schedule parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline, attach it at parent.
* \param parent The parent schedule to be attached to.
*/
Schedule& compute_inline(Schedule parent); // NOLINT(*)
/*!
* \brief Compute the function at root, attach it to its parent.
* \param parent The parent schedule to be attached to.
*/
Schedule& compute_root(Schedule parent); // NOLINT(*)
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
* \param p_outer The result outer domain
* \param p_inner The result inner domain.
* \param factor The split factor of the loop.
* \param outer The generated
*/
Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
/*!
* \brief Split the iteration with a given outer domain,
* the outer domain must have a thread-tag.
*
* \param parent The parent domain.
* \param outer The outer domain to be spliited, must have a thread_tag.
* \param p_inner The result inner domain.
* \param factor Optional, the factor of the split,
* factor must be provided such that factor * outer.extent >= parent.extent.
*/
Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param inner The inner domain to be fused
* \param outer The outer domain to be fused.
* \param p_target The result target domain.
*/
Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
*/
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
}; };
/*! /*!
...@@ -83,7 +142,7 @@ class ScheduleNode : public Node { ...@@ -83,7 +142,7 @@ class ScheduleNode : public Node {
/*! \brief The relation bwteen of IterVars */ /*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations; Array<IterVarRelation> relations;
/*! \brief The attachment type of the schedule */ /*! \brief The attachment type of the schedule */
AttachType attach_type; AttachType attach_type{kNone};
/*! /*!
* \brief The attach point of this schedule. * \brief The attach point of this schedule.
*/ */
...@@ -169,6 +228,9 @@ class FuseNode : public IterVarRelationNode { ...@@ -169,6 +228,9 @@ class FuseNode : public IterVarRelationNode {
inline const ScheduleNode* Schedule::operator->() const { inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get()); return static_cast<const ScheduleNode*>(node_.get());
} }
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const { inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get()); return static_cast<const IterVarRelationNode*>(node_.get());
......
...@@ -6,13 +6,68 @@ ...@@ -6,13 +6,68 @@
namespace tvm { namespace tvm {
namespace {
// find first occurance location in leaf
size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
const Node* n = v.get();
for (size_t i = 0; i < array_node->data.size(); ++i) {
if (array_node->data[i].get() == n) return i;
}
return array_node->data.size();
}
size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* const IterVar& v) {
size_t pos = Find(leaf_iter_vars, parent);
}
}
Schedule::Schedule(Operation op, std::string scope) { Schedule::Schedule(Operation op, std::string scope) {
auto n = std::make_shared<ScheduleNode>(); auto n = std::make_shared<ScheduleNode>();
n->op = op; n->op = op;
n->scope = scope; n->scope = scope;
n->all_iter_vars = op->root_iter_vars();
n->leaf_iter_vars = op->root_iter_vars();
node_ = n; node_ = n;
} }
Schedule& Schedule::compute_at(Schedule parent, IterVar scope) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
(*this)->attach_type = kScope;
(*this)->attach_parent = scope;
parent->children.push_back(*this);
return *this;
}
Schedule& Schedule::compute_inline(Schedule parent) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
(*this)->attach_type = kInline;
parent->children.push_back(*this);
return *this;
}
Schedule& Schedule::compute_root(Schedule parent) { // NOLINT(*)
CHECK_EQ((*this)->attach_type, kNone);
(*this)->attach_type = kRoot;
parent->children.push_back(*this);
return *this;
}
Schedule& Schedule::split(
IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
ScheduleNode* self = operator->();
ArrayNode* leaf_iter_vars = self->leaf_iter_vars.CopyOnWrite();
CHECK(pos != leaf_iter_vars->data.size())
<< "Cannot find IterVar " << parent << " in the active leaf vars"
<< " this means "
return *this;
}
IterVarRelation SplitNode::make( IterVarRelation SplitNode::make(
IterVar parent, IterVar outer, IterVar parent, IterVar outer,
IterVar inner, Expr factor) { IterVar inner, Expr factor) {
......
...@@ -8,7 +8,18 @@ def test_schedule_create(): ...@@ -8,7 +8,18 @@ def test_schedule_create():
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
sch = tvm.Schedule(T.op, scope="shared") Tsch = tvm.Schedule(T.op, scope="shared")
Asch = tvm.Schedule(A.op)
T.op.
xo, xi = sch.split(sch.dim_var[0], factor)
Asch.compute_at(Tsch, xi)
xf = sch.fuse(xo, xi)
tk1 = tvm.Split(T.op.dim_var[0], 10) tk1 = tvm.Split(T.op.dim_var[0], 10)
assert isinstance(sch, tvm.schedule.Schedule) assert isinstance(sch, tvm.schedule.Schedule)
assert isinstance(tk1, tvm.schedule.DimSplit) assert isinstance(tk1, tvm.schedule.DimSplit)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment