Commit 6a62beb2 by Ziheng Jiang Committed by Tianqi Chen

[FUSION] add 'void AutoFuseEwise(Schedule sch)' (#36)

* [FUSION] add Fusion(Schedule)

* [FUSION] rename to AutoFuseEwise, detect whether the stage has been scheduled

* [FUSION] change to visitor pattern

* [FUSION] rename filename

* [FUSION] fine-tune the interface

* [FUSION] typo

* move elem_wise to schedule

* rename test function
parent 08505e34
...@@ -167,7 +167,6 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func); ...@@ -167,7 +167,6 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/ */
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -123,6 +123,12 @@ class Stage : public NodeRef { ...@@ -123,6 +123,12 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor); Expr x_factor, Expr y_factor);
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
inline bool is_scheduled() const;
// declare container type // declare container type
using ContainerType = StageNode; using ContainerType = StageNode;
}; };
...@@ -353,6 +359,11 @@ inline StageNode* Stage::operator->() { ...@@ -353,6 +359,11 @@ inline StageNode* Stage::operator->() {
return static_cast<StageNode*>(node_.get()); return static_cast<StageNode*>(node_.get());
} }
inline bool Stage::is_scheduled() const {
const StageNode* n = operator->();
return !(n->relations.empty() && n->attach_type == kNone);
}
inline const ScheduleNode* Schedule::operator->() const { inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get()); return static_cast<const ScheduleNode*>(node_.get());
} }
......
...@@ -33,6 +33,13 @@ Map<IterVar, Range> InferBound(Schedule sch); ...@@ -33,6 +33,13 @@ Map<IterVar, Range> InferBound(Schedule sch);
*/ */
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map); Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
/*!
* \brief To automatically inline the element-wise operations.
*
* \param sch The schedule to be inlined.
*/
void AutoInlineElemWise(Schedule sch);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_ #endif // TVM_SCHEDULE_PASS_H_
...@@ -135,7 +135,7 @@ class Stage(NodeBase): ...@@ -135,7 +135,7 @@ class Stage(NodeBase):
parent : Stage parent : Stage
The parent stage The parent stage
""" """
_api_internal._StageComputeInline(self) _api_internal._StageComputeRoot(self)
def reorder(self, *args): def reorder(self, *args):
"""reorder the arguments in the specified order. """reorder the arguments in the specified order.
......
...@@ -13,6 +13,11 @@ ...@@ -13,6 +13,11 @@
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
TVM_REGISTER_API(_schedule_AutoInlineElemWise)
.set_body([](TVMArgs args, TVMRetValue* ret) {
AutoInlineElemWise(args[0]);
});
#define REGISTER_SCHEDULE_PASS1(PassName) \ #define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \ TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
......
/*!
* Copyright (c) 2016 by Contributors
* \file auto_inline_elem_wise.cc
*/
#include <tvm/schedule_pass.h>
#include <tvm/ir_visitor.h>
namespace tvm {
namespace ir {
class ElemWiseDetector : public IRVisitor {
public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
void Visit(const NodeRef& e) final {
if (!is_elem_wise_) return;
IRVisitor::Visit(e);
}
void Visit_(const Call* op) final {
Array<Expr> axis = op->args;
if (axis_.size() != axis.size()) {
is_elem_wise_ = false;
return;
}
for (size_t i = 0; i < axis_.size(); ++i) {
// const Variable *v1 = axis_[i]->var.as<Variable>();
// const Variable *v2 = axis[i].as<Variable>();
if (!axis[i].same_as(axis_[i]->var)) {
// if (!(v1 && v2) || (v1 != v2)) {
is_elem_wise_ = false;
return;
}
}
IRVisitor::Visit_(op);
}
bool is_elem_wise_{true};
private:
Array<IterVar> axis_;
};
bool IsElemWise(const Operation& op) {
if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
ElemWiseDetector v = ElemWiseDetector(compute->axis);
v.Visit(compute->body);
return v.is_elem_wise_;
}
return false;
}
} // namespace ir
namespace schedule {
void AutoInlineElemWise(Schedule sch) {
for (Stage s : sch->stages) {
if (!s.is_scheduled() && ir::IsElemWise(s->op)) {
bool is_root = false;
for (auto r : sch->roots) {
if (r == s->op) {
is_root = true;
break;
}
}
if (!is_root)
s.compute_inline();
}
}
}
} // namespace schedule
} // namespace tvm
...@@ -42,8 +42,24 @@ def test_schedule2(): ...@@ -42,8 +42,24 @@ def test_schedule2():
stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt) print(stmt)
def test_auto_inline():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.placeholder((m, n), name='C')
T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
s = tvm.Schedule(T2.op)
tvm.schedule.AutoInlineElemWise(s)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
if __name__ == "__main__": if __name__ == "__main__":
test_schedule0() test_schedule0()
test_schedule1() test_schedule1()
test_schedule2() test_schedule2()
test_auto_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment