Commit f88f4580 by Zhi Committed by masahi

[RELAY][FUSION] Enhance fusion rule that starts from elemwise and broadcast (#2932)

* [relay][bugfix] fuse injective to elemwise and broadcast

* enhance fusion for prarllel injectiveOD

* check if tensor in schedule

* fix codegen

* fix lint

* update

* lint
parent 977896cb
......@@ -552,6 +552,22 @@ class ScheduleNode : public Node {
void InvalidateCache();
/*!
* \brief Check if the schedule contains an Operation.
* \param op The candidate Operation.
* \return true if the schedule has the Operation. Otherwise, false.
*/
EXPORT bool Contain(const Operation& op) const;
/*!
* \brief Check if the schedule contains a Tensor.
* \param tensor The candidate tensor.
* \return true if the schedule has the tensor. Otherwise, false.
*/
EXPORT bool Contain(const Tensor& tensor) const {
return Contain(tensor->op);
}
/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
......
......@@ -127,9 +127,11 @@ class ScheduleGetter :
schedule =
fschedule[master_op_](master_attrs_, tensor_outs, target_);
for (const auto& scalar : scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
return std::make_pair(schedule, cfunc);
}
......
......@@ -715,10 +715,13 @@ class GraphPartitioner {
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
return kind <= kBroadcast;
// Elemwise, broadcast, and injective ops on the parallel branches
// are allowed be fused to the elemwise/broadcast master.
return kind <= kInjective;
} else {
return (kind <= kBroadcast ||
kind == kCommReduce ||
kind == kInjective ||
kind == kOutEWiseFusable);
}
};
......
......@@ -712,6 +712,10 @@ void ScheduleNode::InitCache() {
CHECK_EQ(op2stage_cache_.size(), stages.size());
}
bool ScheduleNode::Contain(const Operation& op) const {
return stage_map.find(op) != stage_map.end();
}
Schedule ScheduleNode::make(Array<Operation> ops) {
auto n = make_node<ScheduleNode>();
Schedule sch(n);
......
......@@ -23,13 +23,15 @@ def test_fuse_simple():
x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
return relay.Function([x], z)
w = relay.squeeze(z)
return relay.Function([x], w)
def expected():
x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
f1 = relay.Function([x], z)
w = relay.squeeze(z)
f1 = relay.Function([x], w)
x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x])
return relay.Function([x], y)
......@@ -503,6 +505,38 @@ def test_inception_like():
assert relay.ir_pass.alpha_equal(zz, after)
def test_fuse_parallel_injective():
"""Test fusing parallel injective ops to an elemwise op."""
def before():
x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.squeeze(y)
u = relay.transpose(y, axes=[0, 1])
w = relay.left_shift(z, u)
return relay.Function([x], w)
def expected():
x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.squeeze(y)
u = relay.transpose(y, axes=[0, 1])
w = relay.left_shift(z, u)
f1 = relay.Function([x], w)
x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x])
return relay.Function([x], y)
z = before()
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(zz, after)
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
......@@ -515,3 +549,4 @@ if __name__ == "__main__":
test_tuple_intermediate()
test_tuple_consecutive()
test_inception_like()
test_fuse_parallel_injective()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment