Commit b21aee7d by alex-weaver Committed by Tianqi Chen

Fixed namespacing issues in schedules (#873)

* Fixed namespacing issues in schedules

* Fixed compile error
parent 7d2654c2
......@@ -86,7 +86,7 @@ inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs)
auto dense_f = s.rfactor(dense, kf)[0];
Tensor out;
if (contains(s->outputs, dense->op)) {
if (detail::contains(s->outputs, dense->op)) {
out = dense;
} else {
out = outs[0]->op.output(0);
......@@ -107,7 +107,7 @@ inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs)
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
......
......@@ -27,7 +27,7 @@ namespace cuda {
*/
inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
auto x = op.output(0);
auto fused = Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
auto fused = detail::Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
auto num_thread = target.max_num_threads;
IterVar bx, tx;
sch[x].split(fused, num_thread, &bx, &tx);
......
......@@ -24,7 +24,7 @@ namespace cuda {
*/
inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
auto x = op.output(0);
auto fused = Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
auto num_thread = target.max_num_threads;
IterVar bx, tx;
s[x].split(fused, num_thread, &bx, &tx);
......
......@@ -37,19 +37,19 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
auto num_thread = target.max_num_threads;
Tensor out;
Tensor OL;
if (contains(s->outputs, pool->op)) {
if (detail::contains(s->outputs, pool->op)) {
out = pool;
OL = s.cache_write(pool, "local");
} else {
out = outs[0]->op.output(0);
s[pool].set_scope("local");
}
auto fused = Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
auto fused = detail::Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
IterVar bx, tx;
s[out].split(fused, num_thread, &bx, &tx);
s[out].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
s[out].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
if (contains(s->outputs, pool->op)) {
if (detail::contains(s->outputs, pool->op)) {
s[OL].compute_at(s[out], tx);
} else {
s[pool].compute_at(s[out], tx);
......@@ -60,7 +60,7 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
......@@ -105,7 +105,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>&
auto thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
Tensor out;
Tensor OL;
if (contains(s->outputs, pool->op)) {
if (detail::contains(s->outputs, pool->op)) {
out = pool;
OL = s.cache_write(pool, "local");
} else {
......@@ -126,7 +126,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>&
s[out].bind(by, block_y);
s[out].bind(bx, block_x);
if (contains(s->outputs, pool->op)) {
if (detail::contains(s->outputs, pool->op)) {
s[OL].compute_at(s[out], tx);
} else {
s[pool].compute_at(s[out], tx);
......@@ -137,7 +137,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>&
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
......
......@@ -65,7 +65,7 @@ Schedule ScheduleReduce(const Target& target,
thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
}
auto fused_reduce = Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);
auto fused_reduce = detail::Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);
IterVar ko, ki;
out_stage.split(fused_reduce, num_thread, &ko, &ki);
......@@ -87,7 +87,7 @@ Schedule ScheduleReduce(const Target& target,
auto stage_real = sch[real_output];
if (!all_reduce) {
// Fuse and split the axis
auto fused_outer = Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
auto fused_outer = detail::Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
IterVar bx, outer_in;
stage_real.split(fused_outer, num_thread, &bx, &outer_in);
......
......@@ -16,27 +16,42 @@ using namespace tvm;
namespace generic {
/*!
* \brief Create a generic default schedule for the given output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
* \param auto_inline Whether to apply the auto inline step.
*
* \return A schedule for the given ops.
*/
inline Schedule default_schedule(const Target& target, Array<Tensor> outs, bool auto_inline) {
* \brief Create a generic default schedule for the given output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule default_schedule(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
return s;
}
/*!
* \brief Create a generic default schedule for the given output tensors, and apply
* auto inline
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule default_schedule_auto_inline(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
if (auto_inline) {
auto x = outs[0];
tvm::schedule::AutoInlineInjective(s);
auto axis = s[x]->op.as<ComputeOpNode>()->axis;
if (axis.size() > 0) {
Fuse(s[x], axis);
}
detail::Fuse(s[x], axis);
}
return s;
}
......
......@@ -32,7 +32,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
auto x = outs[0];
Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
return s;
}
......
......@@ -68,7 +68,7 @@ inline Schedule schedule_binary_dense(const Target &target, const Array<Tensor>&
s[C].parallel(s[C]->op.as<ComputeOpNode>()->axis[0]);
Tensor out;
if (contains(s->outputs, C->op)) {
if (detail::contains(s->outputs, C->op)) {
out = C;
} else {
out = outs[0]->op.output(0);
......@@ -83,7 +83,7 @@ inline Schedule schedule_binary_dense(const Target &target, const Array<Tensor>&
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
if (!detail::contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
......
......@@ -16,7 +16,7 @@ using namespace tvm;
namespace x86 {
/*!
* \brief Create a default x86 schedule for the given ops.
* \brief Helper to create a default x86 schedule for the given ops.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
......@@ -24,7 +24,7 @@ namespace x86 {
*
* \return A schedule for the given ops.
*/
inline Schedule default_schedule(const Target &target,
inline Schedule MakeDefaultSchedule(const Target &target,
const Array<Tensor>& outs,
bool auto_inline) {
Array<Operation> out_ops;
......@@ -38,7 +38,7 @@ inline Schedule default_schedule(const Target &target,
if (auto_inline) {
tvm::schedule::AutoInlineInjective(s);
if (axis.size() > 0) {
Fuse(s[x], axis);
detail::Fuse(s[x], axis);
}
return s;
}
......@@ -46,7 +46,7 @@ inline Schedule default_schedule(const Target &target,
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
auto fused = Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h
auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h
s[x].parallel(fused);
} else {
s[x].parallel(axis[0]);
......@@ -55,6 +55,30 @@ inline Schedule default_schedule(const Target &target,
return s;
}
/*!
* \brief Create a default x86 schedule for the given ops.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule default_schedule(const Target &target, const Array<Tensor>& outs) {
return MakeDefaultSchedule(target, outs, false);
}
/*!
* \brief Create a default x86 schedule for the given ops, with auto inline
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule default_schedule_auto_inline(const Target &target, const Array<Tensor>& outs) {
return MakeDefaultSchedule(target, outs, true);
}
} // namespace x86
} // namespace topi
#endif // TOPI_X86_DEFAULT_H_
......@@ -36,7 +36,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
auto fused = Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h
auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h
s[x].parallel(fused);
} else {
s[x].parallel(axis[0]);
......
......@@ -343,7 +343,11 @@ TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
/* Generic schedules */
TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::default_schedule(args[0], args[1], args[2]);
if (args[2]) {
*rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::generic::default_schedule(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_extern")
......@@ -369,7 +373,11 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense")
TVM_REGISTER_GLOBAL("topi.x86.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::default_schedule(args[0], args[1], args[2]);
if (args[2]) {
*rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::x86::default_schedule(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment