Commit 84aeaf48 by ziheng Committed by GitHub

Change Schedule Array constructor to static make method (#170)

* Change Schedule Array constructor to static make method

* Add CreateSchedule

* Add doc

* Change CreateSchedule to create_schedule at cpp side
parent 3bf72469
......@@ -198,11 +198,6 @@ class Schedule : public NodeRef {
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
*/
explicit Schedule(Array<Operation> ops);
/*!
* \brief Get a copy of current schedule.
* \return The copied schedule.
*/
......@@ -439,10 +434,26 @@ class ScheduleNode : public Node {
/*! \brief Invalidate temp cache. */
void InvalidateCache();
/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
static Schedule make(Array<Operation> ops);
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};
/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
inline Schedule create_schedule(Array<Operation> ops) {
return ScheduleNode::make(ops);
}
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
public:
......
......@@ -73,7 +73,7 @@ def create_schedule(ops):
"""
if not isinstance(ops, (list, _collections.Array)):
ops = [ops]
return _api_internal._Schedule(ops)
return _api_internal._CreateSchedule(ops)
@register_node
......
......@@ -225,9 +225,9 @@ TVM_REGISTER_API("_IterVar")
args[3]);
});
TVM_REGISTER_API("_Schedule")
TVM_REGISTER_API("_CreateSchedule")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Schedule(args[0].operator Array<Operation>());
*ret = create_schedule(args[0].operator Array<Operation>());
});
TVM_REGISTER_API("_StageSetScope")
......
......@@ -322,45 +322,6 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*)
return *this;
}
Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
node_ = n;
n->outputs = ops;
auto g = schedule::CreateReadGraph(n->outputs);
Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
// output set.
std::unordered_set<Operation> output_set;
for (Operation x : ops) {
output_set.insert(x);
}
for (Operation op : post_order) {
Stage stage(op);
stage->is_output = output_set.count(op) != 0;
n->stages.push_back(stage);
n->stage_map.Set(op, stage);
// mark scan updates.
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<Tensor> inputs;
for (Tensor t : scan->state_placeholder) {
inputs.push_back(t);
}
for (Tensor t : scan->inputs) {
inputs.push_back(t);
}
// Create the scan group.
Stage scan_group = create_group(scan->update, inputs, false);
scan_group->attach_type = kScanUpdate;
scan_group->attach_stage = stage;
for (size_t i = 0; i < scan->update.size(); ++i) {
Stage s = n->stage_map[scan->update[i]->op];
CHECK(scan_group.same_as(s->group));
}
}
}
}
Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n =
std::make_shared<StageNode>(*s.operator->());
......@@ -580,6 +541,46 @@ void ScheduleNode::InitCache() {
CHECK_EQ(op2stage_cache_.size(), stages.size());
}
Schedule ScheduleNode::make(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>();
Schedule sch(n);
n->outputs = ops;
auto g = schedule::CreateReadGraph(n->outputs);
Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
// output set.
std::unordered_set<Operation> output_set;
for (Operation x : ops) {
output_set.insert(x);
}
for (Operation op : post_order) {
Stage stage(op);
stage->is_output = output_set.count(op) != 0;
n->stages.push_back(stage);
n->stage_map.Set(op, stage);
// mark scan updates.
if (op.as<ScanOpNode>()) {
const ScanOpNode* scan = op.as<ScanOpNode>();
Array<Tensor> inputs;
for (Tensor t : scan->state_placeholder) {
inputs.push_back(t);
}
for (Tensor t : scan->inputs) {
inputs.push_back(t);
}
// Create the scan group.
Stage scan_group = sch.create_group(scan->update, inputs, false);
scan_group->attach_type = kScanUpdate;
scan_group->attach_stage = stage;
for (size_t i = 0; i < scan->update.size(); ++i) {
Stage s = n->stage_map[scan->update[i]->op];
CHECK(scan_group.same_as(s->group));
}
}
}
return sch;
}
IterVarRelation SplitNode::make(IterVar parent,
IterVar outer,
IterVar inner,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment