Unverified Commit 27a02844 by Jon Soifer Committed by GitHub

[Relay][Pass] Fix bug in re-processing call node in MergeComposite pass (#4879)

* Fix bug in re-processing call node

* Add test

* Add to main

* temp changes to work from another machine

* fix rest of tests

* fix test_reuse_call_merge

* fix merge

Co-authored-by: Jon Soifer <jonso@microsoft.com>
parent 0b2d11a5
......@@ -87,7 +87,7 @@ class MergeCompositeWrapper : public ExprMutator {
* a new Relay expression ready to be wrapped into a composite function.
*/
Expr ExtractPattern(const Call& pattern, const Call& root,
Map<std::string, Array<Expr>>* var_map) {
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
return Expr();
......@@ -99,14 +99,20 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>()) {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
// if we've already processed this call node, return the previous result
if (call_map->find(arg) != call_map->end()) {
new_arg = (*call_map)[arg];
} else {
// fail if the root argument is not also a call node
if (!root->args[i]->IsInstance<CallNode>()) {
return Expr();
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map, call_map);
call_map->Set(arg, new_arg);
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map);
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
......@@ -155,7 +161,8 @@ class MergeCompositeWrapper : public ExprMutator {
Call pattern = Downcast<Call>(pattern_);
CHECK(pattern.defined());
Map<std::string, Array<Expr>> args_map;
auto extract = ExtractPattern(pattern, call, &args_map);
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) {
auto free_vars = FreeVars(extract);
// make the composite function
......
......@@ -110,6 +110,26 @@ def make_conv_bias_relu_pattern():
return r
def make_add_add_add_pattern():
"""Create a pattern to match the following graph.
Useful for testing re-using a call node.
x y
/ \ /
| add
\ | \
add |
| /
add
"""
x = relay.var('x')
y = relay.var('y')
add_node = relay.add(x, y)
add_node_1 = relay.add(x, add_node)
r = relay.add(add_node_1, add_node)
return r
def test_simple_merge():
"""Test composite function is correctly produced from simple graph.
......@@ -239,6 +259,67 @@ def test_branch_merge():
assert relay.analysis.alpha_equal(result, expected)
def test_reuse_call_merge():
"""Test composite function is correctly produced from simple graph
which re-uses call nodes.
We could expect the pattern `make_add_add_add` to be merged
into a single op `add_add_add`.
x y
\ / \
sub | x y
/ | / \ / |
| add ====> sub |
\ | \ | /
add | add_add_add
| /
add
"""
pattern_table = [
("add_add_add", make_add_add_add_pattern())
]
def before():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
sub_node = relay.subtract(a, b)
# pattern
add_node = relay.add(sub_node, b)
add_node_1 = relay.add(sub_node, add_node)
r = relay.add(add_node_1, add_node)
return relay.Function([a, b], r)
def expected():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
# add_relu_add function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_add_add = add_add_add.set_attribute("Composite",
tir.StringImm("add_add_add"))
# merged function
sub_node = relay.subtract(a, b)
call = relay.Call(add_add_add, [sub_node, b])
return relay.Function([a, b], call)
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
def test_multiple_patterns():
"""Test different patterns are merged correctly in the graph.
......@@ -608,3 +689,4 @@ if __name__ == "__main__":
test_merge_order()
test_parallel_merge()
test_multiple_input_subgraphs()
test_reuse_call_merge()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment