Commit 84aeaf48 by ziheng Committed by GitHub

Change Schedule Array constructor to static make method (#170)

* Change Schedule Array constructor to static make method

* Add CreateSchedule

* Add doc

* Change CreateSchedule to create_schedule at cpp side
parent 3bf72469
...@@ -198,11 +198,6 @@ class Schedule : public NodeRef { ...@@ -198,11 +198,6 @@ class Schedule : public NodeRef {
Schedule() {} Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! /*!
* \brief construct schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
*/
explicit Schedule(Array<Operation> ops);
/*!
* \brief Get a copy of current schedule. * \brief Get a copy of current schedule.
* \return The copied schedule. * \return The copied schedule.
*/ */
...@@ -439,10 +434,26 @@ class ScheduleNode : public Node { ...@@ -439,10 +434,26 @@ class ScheduleNode : public Node {
/*! \brief Invalidate temp cache. */ /*! \brief Invalidate temp cache. */
void InvalidateCache(); void InvalidateCache();
/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
static Schedule make(Array<Operation> ops);
static constexpr const char* _type_key = "Schedule"; static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node); TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
}; };
/*!
 * \brief Construct a Schedule covering the given output ops
 *  and every operation they transitively depend on.
 * \param ops The output operations to be scheduled.
 * \return The newly created Schedule.
 */
inline Schedule create_schedule(Array<Operation> ops) {
  // Thin convenience wrapper over the node-level factory.
  Schedule sch = ScheduleNode::make(ops);
  return sch;
}
/*! \brief node container for IterVar attr */ /*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node { class IterVarAttrNode : public Node {
public: public:
......
...@@ -73,7 +73,7 @@ def create_schedule(ops): ...@@ -73,7 +73,7 @@ def create_schedule(ops):
""" """
if not isinstance(ops, (list, _collections.Array)): if not isinstance(ops, (list, _collections.Array)):
ops = [ops] ops = [ops]
return _api_internal._Schedule(ops) return _api_internal._CreateSchedule(ops)
@register_node @register_node
......
...@@ -225,9 +225,9 @@ TVM_REGISTER_API("_IterVar") ...@@ -225,9 +225,9 @@ TVM_REGISTER_API("_IterVar")
args[3]); args[3]);
}); });
TVM_REGISTER_API("_Schedule") TVM_REGISTER_API("_CreateSchedule")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Schedule(args[0].operator Array<Operation>()); *ret = create_schedule(args[0].operator Array<Operation>());
}); });
TVM_REGISTER_API("_StageSetScope") TVM_REGISTER_API("_StageSetScope")
......
...@@ -322,45 +322,6 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*) ...@@ -322,45 +322,6 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*)
return *this; return *this;
} }
Schedule::Schedule(Array<Operation> ops) {
  // Allocate the backing node and install it before any member access:
  // the scan-group code below relies on a usable Schedule object.
  auto n = std::make_shared<ScheduleNode>();
  node_ = n;
  n->outputs = ops;
  // Build the read-dependency graph and visit every reachable op in
  // post-DFS order, so producers are staged before their consumers.
  auto g = schedule::CreateReadGraph(n->outputs);
  Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
  // Output set: lets each stage be flagged as an output in O(1).
  std::unordered_set<Operation> output_set;
  for (Operation x : ops) {
    output_set.insert(x);
  }
  // Create one Stage per op and register it in the stage map.
  for (Operation op : post_order) {
    Stage stage(op);
    stage->is_output = output_set.count(op) != 0;
    n->stages.push_back(stage);
    n->stage_map.Set(op, stage);
    // Mark scan updates: scan ops get a dedicated group for their
    // update stages, attached at the scan update point.
    if (op.as<ScanOpNode>()) {
      const ScanOpNode* scan = op.as<ScanOpNode>();
      // Inputs to the group are the scan's state placeholders plus
      // its external inputs.
      Array<Tensor> inputs;
      for (Tensor t : scan->state_placeholder) {
        inputs.push_back(t);
      }
      for (Tensor t : scan->inputs) {
        inputs.push_back(t);
      }
      // Create the scan group.
      // NOTE(review): this invokes the member create_group(...) while the
      // Schedule is still inside its own constructor; this constructor is
      // superseded by the static ScheduleNode::make factory in this change.
      Stage scan_group = create_group(scan->update, inputs, false);
      scan_group->attach_type = kScanUpdate;
      scan_group->attach_stage = stage;
      // Sanity check: every update stage must have landed in the new group.
      for (size_t i = 0; i < scan->update.size(); ++i) {
        Stage s = n->stage_map[scan->update[i]->op];
        CHECK(scan_group.same_as(s->group));
      }
    }
  }
}
Stage CopyStage(const Stage& s) { Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n = std::shared_ptr<StageNode> n =
std::make_shared<StageNode>(*s.operator->()); std::make_shared<StageNode>(*s.operator->());
...@@ -580,6 +541,46 @@ void ScheduleNode::InitCache() { ...@@ -580,6 +541,46 @@ void ScheduleNode::InitCache() {
CHECK_EQ(op2stage_cache_.size(), stages.size()); CHECK_EQ(op2stage_cache_.size(), stages.size());
} }
Schedule ScheduleNode::make(Array<Operation> ops) {
  // Allocate the backing node and wrap it in a Schedule handle up front,
  // so member methods (create_group) can be called during construction.
  auto node = std::make_shared<ScheduleNode>();
  Schedule sch(node);
  node->outputs = ops;
  // Visit every reachable op in post-DFS dependency order, so producers
  // are staged before their consumers.
  auto read_graph = schedule::CreateReadGraph(node->outputs);
  Array<Operation> post_order =
      schedule::PostDFSOrder(node->outputs, read_graph);
  // Remember which ops are outputs so their stages can be flagged in O(1).
  std::unordered_set<Operation> output_set;
  for (const auto& out : ops) {
    output_set.insert(out);
  }
  // Create one Stage per op and register it in the stage map.
  for (const auto& op : post_order) {
    Stage stage(op);
    stage->is_output = (output_set.count(op) != 0);
    node->stages.push_back(stage);
    node->stage_map.Set(op, stage);
    // Scan ops additionally get a dedicated group for their update stages.
    const ScanOpNode* scan = op.as<ScanOpNode>();
    if (scan != nullptr) {
      // Group inputs: the scan's state placeholders plus its external inputs.
      Array<Tensor> scan_inputs;
      for (const auto& t : scan->state_placeholder) {
        scan_inputs.push_back(t);
      }
      for (const auto& t : scan->inputs) {
        scan_inputs.push_back(t);
      }
      // Create the scan group and attach it at the scan update point.
      Stage scan_group = sch.create_group(scan->update, scan_inputs, false);
      scan_group->attach_type = kScanUpdate;
      scan_group->attach_stage = stage;
      // Sanity check: every update stage must have landed in the new group.
      for (const auto& upd : scan->update) {
        Stage s = node->stage_map[upd->op];
        CHECK(scan_group.same_as(s->group));
      }
    }
  }
  return sch;
}
IterVarRelation SplitNode::make(IterVar parent, IterVarRelation SplitNode::make(IterVar parent,
IterVar outer, IterVar outer,
IterVar inner, IterVar inner,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment