Commit b21aee7d by alex-weaver, committed by Tianqi Chen

Fixed namespacing issues in schedules (#873)

* Fixed namespacing issues in schedules

* Fixed compile error
parent 7d2654c2
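
The change is mechanical: free helper calls such as contains(...) and Fuse(...) are now explicitly qualified so they resolve to topi's internal helpers under the detail namespace rather than clashing with names pulled in by `using namespace tvm`. As a rough sketch of the helpers being qualified (signatures here are plausible assumptions for illustration, not copied from the repository; the real definitions live in topi's detail headers):

// Sketch only: plausible shapes of the helpers this diff qualifies with
// detail::. Signatures are assumptions, not verbatim from the repo.
namespace topi {
namespace detail {

// True if `array` contains `item` (used to test membership in s->outputs).
template <typename T>
inline bool contains(const tvm::Array<T>& array, const T& item) {
  for (auto& x : array) {
    if (x == item) return true;
  }
  return false;
}

// Fuse a list of iteration axes of `stage` pairwise into a single IterVar.
inline tvm::IterVar Fuse(tvm::Stage stage, const tvm::Array<tvm::IterVar>& args) {
  tvm::IterVar result = args[0];
  for (size_t i = 1; i < args.size(); ++i) {
    tvm::IterVar fused;
    stage.fuse(result, args[i], &fused);  // fold the next axis into the result
    result = fused;
  }
  return result;
}

}  // namespace detail
}  // namespace topi

Qualifying the call sites removes the ambiguity that arises once `using namespace tvm` is in effect, since unqualified lookup would otherwise consider every visible namespace.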
@@ -86,7 +86,7 @@ inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs)
   auto dense_f = s.rfactor(dense, kf)[0];
   Tensor out;
-  if (contains(s->outputs, dense->op)) {
+  if (detail::contains(s->outputs, dense->op)) {
     out = dense;
   } else {
     out = outs[0]->op.output(0);
@@ -107,7 +107,7 @@ inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs)
   traverse = [&](const Operation& op) {
     // Inline all one-to-one-mapping operators except the last stage (output)
     if (is_broadcast(op->tag)) {
-      if (!contains(s->outputs, op)) {
+      if (!detail::contains(s->outputs, op)) {
         s[op].compute_inline();
       }
       for (auto tensor : op->InputTensors()) {
...
@@ -27,7 +27,7 @@ namespace cuda {
  */
 inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
   auto x = op.output(0);
-  auto fused = Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
+  auto fused = detail::Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
   auto num_thread = target.max_num_threads;
   IterVar bx, tx;
   sch[x].split(fused, num_thread, &bx, &tx);
...
@@ -24,7 +24,7 @@ namespace cuda {
  */
 inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
   auto x = op.output(0);
-  auto fused = Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+  auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
   auto num_thread = target.max_num_threads;
   IterVar bx, tx;
   s[x].split(fused, num_thread, &bx, &tx);
...
@@ -37,19 +37,19 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
   auto num_thread = target.max_num_threads;
   Tensor out;
   Tensor OL;
-  if (contains(s->outputs, pool->op)) {
+  if (detail::contains(s->outputs, pool->op)) {
     out = pool;
     OL = s.cache_write(pool, "local");
   } else {
     out = outs[0]->op.output(0);
     s[pool].set_scope("local");
   }
-  auto fused = Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
+  auto fused = detail::Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
   IterVar bx, tx;
   s[out].split(fused, num_thread, &bx, &tx);
   s[out].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
   s[out].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
-  if (contains(s->outputs, pool->op)) {
+  if (detail::contains(s->outputs, pool->op)) {
     s[OL].compute_at(s[out], tx);
   } else {
     s[pool].compute_at(s[out], tx);
@@ -60,7 +60,7 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
   traverse = [&](const Operation& op) {
     // Inline all one-to-one-mapping operators except the last stage (output)
     if (is_broadcast(op->tag)) {
-      if (!contains(s->outputs, op)) {
+      if (!detail::contains(s->outputs, op)) {
         s[op].compute_inline();
       }
       for (auto tensor : op->InputTensors()) {
@@ -105,7 +105,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>&
   auto thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
   Tensor out;
   Tensor OL;
-  if (contains(s->outputs, pool->op)) {
+  if (detail::contains(s->outputs, pool->op)) {
     out = pool;
     OL = s.cache_write(pool, "local");
   } else {
@@ -126,7 +126,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>&
   s[out].bind(by, block_y);
   s[out].bind(bx, block_x);
-  if (contains(s->outputs, pool->op)) {
+  if (detail::contains(s->outputs, pool->op)) {
     s[OL].compute_at(s[out], tx);
   } else {
     s[pool].compute_at(s[out], tx);
@@ -137,7 +137,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>&
   traverse = [&](const Operation& op) {
     // Inline all one-to-one-mapping operators except the last stage (output)
     if (is_broadcast(op->tag)) {
-      if (!contains(s->outputs, op)) {
+      if (!detail::contains(s->outputs, op)) {
         s[op].compute_inline();
       }
       for (auto tensor : op->InputTensors()) {
...
@@ -65,7 +65,7 @@ Schedule ScheduleReduce(const Target& target,
     thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
   }
-  auto fused_reduce = Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);
+  auto fused_reduce = detail::Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);
   IterVar ko, ki;
   out_stage.split(fused_reduce, num_thread, &ko, &ki);
@@ -87,7 +87,7 @@ Schedule ScheduleReduce(const Target& target,
   auto stage_real = sch[real_output];
   if (!all_reduce) {
     // Fuse and split the axis
-    auto fused_outer = Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
+    auto fused_outer = detail::Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
     IterVar bx, outer_in;
     stage_real.split(fused_outer, num_thread, &bx, &outer_in);
...
@@ -16,27 +16,42 @@ using namespace tvm;
 namespace generic {
 /*!
  * \brief Create a generic default schedule for the given output tensors.
  *
  * \param target The target to generate a schedule for.
  * \param outs The output tensors.
- * \param auto_inline Whether to apply the auto inline step.
  *
  * \return A schedule for the given ops.
  */
-inline Schedule default_schedule(const Target& target, Array<Tensor> outs, bool auto_inline) {
+inline Schedule default_schedule(const Target& target, Array<Tensor> outs) {
+  Array<Operation> out_ops;
+  for (auto t : outs) {
+    out_ops.push_back(t->op);
+  }
+  auto s = create_schedule(out_ops);
+  return s;
+}
+
+/*!
+ * \brief Create a generic default schedule for the given output tensors, and apply
+ * auto inline
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule default_schedule_auto_inline(const Target& target, Array<Tensor> outs) {
   Array<Operation> out_ops;
   for (auto t : outs) {
     out_ops.push_back(t->op);
   }
   auto s = create_schedule(out_ops);
-  if (auto_inline) {
-    auto x = outs[0];
-    tvm::schedule::AutoInlineInjective(s);
-    auto axis = s[x]->op.as<ComputeOpNode>()->axis;
-    if (axis.size() > 0) {
-      Fuse(s[x], axis);
-    }
+  auto x = outs[0];
+  tvm::schedule::AutoInlineInjective(s);
+  auto axis = s[x]->op.as<ComputeOpNode>()->axis;
+  if (axis.size() > 0) {
+    detail::Fuse(s[x], axis);
   }
   return s;
 }
...
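
With the generic default schedule split in two, callers now pick the behavior by name instead of passing a boolean flag. A minimal usage sketch (target and out_tensor are illustrative placeholders, not from this diff):

// Sketch: the two entry points after the split; setup values are placeholders.
Array<Tensor> outs = { out_tensor };
Schedule s_plain  = topi::generic::default_schedule(target, outs);              // no inlining
Schedule s_inline = topi::generic::default_schedule_auto_inline(target, outs);  // inline injective ops, fuse output axes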
@@ -32,7 +32,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
   auto s = create_schedule(out_ops);
   tvm::schedule::AutoInlineInjective(s);
   auto x = outs[0];
-  Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+  detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
   return s;
 }
...
@@ -68,7 +68,7 @@ inline Schedule schedule_binary_dense(const Target &target, const Array<Tensor>& outs) {
   s[C].parallel(s[C]->op.as<ComputeOpNode>()->axis[0]);
   Tensor out;
-  if (contains(s->outputs, C->op)) {
+  if (detail::contains(s->outputs, C->op)) {
     out = C;
   } else {
     out = outs[0]->op.output(0);
@@ -83,7 +83,7 @@ inline Schedule schedule_binary_dense(const Target &target, const Array<Tensor>& outs) {
   traverse = [&](const Operation& op) {
     // Inline all one-to-one-mapping operators except the last stage (output)
     if (is_broadcast(op->tag)) {
-      if (!contains(s->outputs, op)) {
+      if (!detail::contains(s->outputs, op)) {
         s[op].compute_inline();
       }
       for (auto tensor : op->InputTensors()) {
...
@@ -16,7 +16,7 @@ using namespace tvm;
 namespace x86 {
 /*!
- * \brief Create a default x86 schedule for the given ops.
+ * \brief Helper to create a default x86 schedule for the given ops.
  *
  * \param target The target to generate a schedule for.
  * \param outs The output tensors.
@@ -24,7 +24,7 @@ namespace x86 {
  *
  * \return A schedule for the given ops.
  */
-inline Schedule default_schedule(const Target &target,
+inline Schedule MakeDefaultSchedule(const Target &target,
                                  const Array<Tensor>& outs,
                                  bool auto_inline) {
   Array<Operation> out_ops;
@@ -38,7 +38,7 @@ inline Schedule default_schedule(const Target &target,
   if (auto_inline) {
     tvm::schedule::AutoInlineInjective(s);
     if (axis.size() > 0) {
-      Fuse(s[x], axis);
+      detail::Fuse(s[x], axis);
     }
     return s;
   }
@@ -46,7 +46,7 @@ inline Schedule default_schedule(const Target &target,
   if (axis.size() == 4) {
     auto n = axis[0];
     auto c = axis[1];
-    auto fused = Fuse(s[x], { n, c });  // for nhwc layout, fuse n and h
+    auto fused = detail::Fuse(s[x], { n, c });  // for nhwc layout, fuse n and h
     s[x].parallel(fused);
   } else {
     s[x].parallel(axis[0]);
@@ -55,6 +55,30 @@ inline Schedule default_schedule(const Target &target,
   return s;
 }
+
+/*!
+ * \brief Create a default x86 schedule for the given ops.
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule default_schedule(const Target &target, const Array<Tensor>& outs) {
+  return MakeDefaultSchedule(target, outs, false);
+}
+
+/*!
+ * \brief Create a default x86 schedule for the given ops, with auto inline
+ *
+ * \param target The target to generate a schedule for.
+ * \param outs The output tensors.
+ *
+ * \return A schedule for the given ops.
+ */
+inline Schedule default_schedule_auto_inline(const Target &target, const Array<Tensor>& outs) {
+  return MakeDefaultSchedule(target, outs, true);
+}
 }  // namespace x86
 }  // namespace topi
 #endif  // TOPI_X86_DEFAULT_H_
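
The x86 variant keeps the bool-taking body as an internal helper (now MakeDefaultSchedule) and exposes two thin wrappers, so its public API mirrors the generic one. A sketch of the equivalences (target and outs are placeholders):

// Sketch: the wrappers forward to the shared helper.
Schedule plain   = topi::x86::default_schedule(target, outs);              // == MakeDefaultSchedule(target, outs, false)
Schedule inlined = topi::x86::default_schedule_auto_inline(target, outs);  // == MakeDefaultSchedule(target, outs, true)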
@@ -36,7 +36,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
   if (axis.size() == 4) {
     auto n = axis[0];
     auto c = axis[1];
-    auto fused = Fuse(s[x], { n, c });  // for nhwc layout, fuse n and h
+    auto fused = detail::Fuse(s[x], { n, c });  // for nhwc layout, fuse n and h
     s[x].parallel(fused);
   } else {
     s[x].parallel(axis[0]);
...
@@ -343,7 +343,11 @@ TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
 /* Generic schedules */
 TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = topi::generic::default_schedule(args[0], args[1], args[2]);
+  if (args[2]) {
+    *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
+  } else {
+    *rv = topi::generic::default_schedule(args[0], args[1]);
+  }
   });
 TVM_REGISTER_GLOBAL("topi.generic.schedule_extern")
@@ -369,7 +373,11 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense")
 TVM_REGISTER_GLOBAL("topi.x86.default_schedule")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = topi::x86::default_schedule(args[0], args[1], args[2]);
+  if (args[2]) {
+    *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
+  } else {
+    *rv = topi::x86::default_schedule(args[0], args[1]);
+  }
   });
 TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
...
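
Because the FFI registrations dispatch on the boolean third argument, the packed-function interface is unchanged: existing callers can keep passing (target, outs, auto_inline). A hedged sketch of invoking the registered function from C++ (registry lookup is standard TVM API; target and outs are placeholders):

#include <tvm/runtime/registry.h>

// Look up the packed function registered above and call it with the old
// three-argument convention.
const tvm::runtime::PackedFunc* f =
    tvm::runtime::Registry::Get("topi.generic.default_schedule");
if (f != nullptr) {
  tvm::Schedule s = (*f)(target, outs, /*auto_inline=*/true);
}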