Commit f88f4580 by Zhi Committed by masahi

[RELAY][FUSION] Enhance fusion rule that starts from elemwise and broadcast (#2932)

* [relay][bugfix] fuse injective to elemwise and broadcast

* enhance fusion for parallel injective ops

* check if tensor in schedule

* fix codegen

* fix lint

* update

* lint
parent 977896cb
...@@ -552,6 +552,22 @@ class ScheduleNode : public Node { ...@@ -552,6 +552,22 @@ class ScheduleNode : public Node {
void InvalidateCache(); void InvalidateCache();
/*! /*!
* \brief Check if the schedule contains an Operation.
* \param op The candidate Operation.
* \return true if the schedule has the Operation. Otherwise, false.
*/
EXPORT bool Contain(const Operation& op) const;
/*!
 * \brief Check whether the schedule contains the Operation held in a Tensor's `op` field.
 * \param tensor The candidate tensor; only its `op` member is inspected.
 * \return The result of Contain() on tensor->op: true if the schedule has it, false otherwise.
 */
EXPORT bool Contain(const Tensor& tensor) const {
  // Delegate to the Operation overload using the tensor's op.
  const auto& tensor_op = tensor->op;
  return Contain(tensor_op);
}
/*!
* \brief Create a schedule for array of ops(and their dependencies). * \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled. * \param ops The ops to be scheduled.
* \return sch The created Schedule. * \return sch The created Schedule.
......
...@@ -127,9 +127,11 @@ class ScheduleGetter : ...@@ -127,9 +127,11 @@ class ScheduleGetter :
schedule = schedule =
fschedule[master_op_](master_attrs_, tensor_outs, target_); fschedule[master_op_](master_attrs_, tensor_outs, target_);
for (const auto& scalar : scalars_) { for (const auto& scalar : scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline(); schedule[scalar].compute_inline();
} }
} }
}
return std::make_pair(schedule, cfunc); return std::make_pair(schedule, cfunc);
} }
......
...@@ -715,10 +715,13 @@ class GraphPartitioner { ...@@ -715,10 +715,13 @@ class GraphPartitioner {
// The final terminal node can already be fused to a OutEWiseFusable group. // The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) { auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) { if (!is_sink) {
return kind <= kBroadcast; // Elemwise, broadcast, and injective ops on the parallel branches
// are allowed to be fused to the elemwise/broadcast master.
return kind <= kInjective;
} else { } else {
return (kind <= kBroadcast || return (kind <= kBroadcast ||
kind == kCommReduce || kind == kCommReduce ||
kind == kInjective ||
kind == kOutEWiseFusable); kind == kOutEWiseFusable);
} }
}; };
......
...@@ -712,6 +712,10 @@ void ScheduleNode::InitCache() { ...@@ -712,6 +712,10 @@ void ScheduleNode::InitCache() {
CHECK_EQ(op2stage_cache_.size(), stages.size()); CHECK_EQ(op2stage_cache_.size(), stages.size());
} }
// Returns true when `op` has an entry in stage_map, false otherwise.
bool ScheduleNode::Contain(const Operation& op) const {
  const auto found = stage_map.find(op);
  const bool is_present = found != stage_map.end();
  return is_present;
}
Schedule ScheduleNode::make(Array<Operation> ops) { Schedule ScheduleNode::make(Array<Operation> ops) {
auto n = make_node<ScheduleNode>(); auto n = make_node<ScheduleNode>();
Schedule sch(n); Schedule sch(n);
......
...@@ -23,13 +23,15 @@ def test_fuse_simple(): ...@@ -23,13 +23,15 @@ def test_fuse_simple():
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32")) y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y) z = relay.exp(y)
return relay.Function([x], z) w = relay.squeeze(z)
return relay.Function([x], w)
def expected(): def expected():
x = relay.var("p", shape=(10, 20)) x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32")) y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y) z = relay.exp(y)
f1 = relay.Function([x], z) w = relay.squeeze(z)
f1 = relay.Function([x], w)
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
return relay.Function([x], y) return relay.Function([x], y)
...@@ -503,6 +505,38 @@ def test_inception_like(): ...@@ -503,6 +505,38 @@ def test_inception_like():
assert relay.ir_pass.alpha_equal(zz, after) assert relay.ir_pass.alpha_equal(zz, after)
def test_fuse_parallel_injective():
    """Fuse two parallel injective branches (squeeze, transpose) fed by one add.

    Runs fuse_ops at opt_level 0 and 2, checks neither result has free
    variables, and compares the opt_level=2 result against a hand-built
    function in which all four ops sit inside a single called function.
    """
    def before():
        # add -> {squeeze, transpose} -> left_shift, all at the top level.
        data = relay.var("x", shape=(10, 20))
        plus_one = relay.add(data, relay.const(1, "float32"))
        squeezed = relay.squeeze(plus_one)
        transposed = relay.transpose(plus_one, axes=[0, 1])
        shifted = relay.left_shift(squeezed, transposed)
        return relay.Function([data], shifted)

    def expected():
        # Same four ops, wrapped in one inner function that is then called.
        param = relay.var("p", shape=(10, 20))
        plus_one = relay.add(param, relay.const(1, "float32"))
        squeezed = relay.squeeze(plus_one)
        transposed = relay.transpose(plus_one, axes=[0, 1])
        shifted = relay.left_shift(squeezed, transposed)
        fused_fn = relay.Function([param], shifted)
        data = relay.var("x", shape=(10, 20))
        call = relay.Call(fused_fn, [data])
        return relay.Function([data], call)

    typed = relay.ir_pass.infer_type(before())

    level0 = relay.ir_pass.fuse_ops(typed, opt_level=0)
    assert not relay.ir_pass.free_vars(level0)

    level2 = relay.ir_pass.fuse_ops(typed, opt_level=2)
    level2 = relay.ir_pass.infer_type(level2)
    assert not relay.ir_pass.free_vars(level2)

    reference = relay.ir_pass.infer_type(expected())
    assert relay.ir_pass.alpha_equal(level2, reference)
if __name__ == "__main__": if __name__ == "__main__":
test_fuse_simple() test_fuse_simple()
test_conv2d_fuse() test_conv2d_fuse()
...@@ -515,3 +549,4 @@ if __name__ == "__main__": ...@@ -515,3 +549,4 @@ if __name__ == "__main__":
test_tuple_intermediate() test_tuple_intermediate()
test_tuple_consecutive() test_tuple_consecutive()
test_inception_like() test_inception_like()
test_fuse_parallel_injective()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment