Unverified Commit 68046ba3 by Trevor Morris Committed by GitHub

[Relay][MergeComposite] Support TupleGetItem in body of pattern (#5106)

* Support TupleGetItemNode in body of pattern only

* Add bn_relu test case for MergeComposite with TupleGetItem

* formatting

* TupleGetItemNode::make -> TupleGetItem()
parent 63937a01
......@@ -66,6 +66,31 @@ class MergeCompositeWrapper : public ExprMutator {
return root;
}
Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
if (!root->IsInstance<TupleGetItemNode>()) {
return Expr();
}
auto root_node = Downcast<TupleGetItem>(root);
if (pattern->index != root_node->index) {
return Expr();
}
if (pattern->tuple->IsInstance<CallNode>() &&
root_node->tuple->IsInstance<CallNode>()) {
Expr new_arg;
if (call_map->find(pattern->tuple) != call_map->end()) {
new_arg = (*call_map)[pattern->tuple];
} else {
new_arg = ExtractPattern(Downcast<Call>(pattern->tuple),
Downcast<Call>(root_node->tuple),
var_map, call_map);
call_map->Set(pattern->tuple, new_arg);
}
return TupleGetItem(new_arg, root_node->index);
}
return Expr();
}
/*!
* \brief Try and extract a given pattern from a graph as a subgraph.
* \param pattern The pattern to extract.
......@@ -125,6 +150,10 @@ class MergeCompositeWrapper : public ExprMutator {
new_arg = ExtractPattern(Downcast<Constant>(arg),
root->args[i],
var_map);
} else if (arg->IsInstance<TupleGetItemNode>()) {
new_arg = ExtractPattern(Downcast<TupleGetItem>(arg),
root->args[i],
var_map, call_map);
}
if (!new_arg.defined()) {
return Expr();
......
......@@ -129,6 +129,25 @@ def make_add_add_add_pattern():
r = relay.add(add_node_1, add_node)
return r
def make_bn_relu_pattern():
"""Create a pattern to match the following graph.
batch_norm
|
TupleGetItem(0)
|
relu
"""
x = relay.var('x')
gamma = relay.var("gamma")
beta = relay.var("beta")
moving_mean = relay.var("moving_mean")
moving_var = relay.var("moving_var")
bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var)
tuple_get_item_node = bn_node[0]
r = relay.nn.relu(tuple_get_item_node)
return r
def test_simple_merge():
"""Test composite function is correctly produced from simple graph.
......@@ -666,6 +685,52 @@ def test_multiple_input_subgraphs():
assert relay.analysis.alpha_equal(result, expected)
def test_tuple_get_item_merge():
"""Test composite function can be merged from pattern containing TupleGetItem nodes."""
pattern_table = [
("bn_relu", make_bn_relu_pattern())
]
def before():
x = relay.var('x', shape=(1, 8))
gamma = relay.var("gamma", shape=(8,))
beta = relay.var("beta", shape=(8,))
moving_mean = relay.var("moving_mean", shape=(8,))
moving_var = relay.var("moving_var", shape=(8,))
bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var)
tuple_get_item_node = bn_node[0]
r = relay.nn.relu(tuple_get_item_node)
return relay.Function([x, gamma, beta, moving_mean, moving_var], r)
def expected():
x = relay.var('x', shape=(1, 8))
beta = relay.var("beta", shape=(8,))
gamma = relay.var("gamma", shape=(8,))
moving_mean = relay.var("moving_mean", shape=(8,))
moving_var = relay.var("moving_var", shape=(8,))
# bn_relu function
in_1 = relay.var('x1', shape=(1, 8))
in_2 = relay.var('gamma1', shape=(8,))
in_3 = relay.var('beta1', shape=(8,))
in_4 = relay.var('moving_mean1', shape=(8,))
in_5 = relay.var('moving_var1', shape=(8,))
bn_node = relay.nn.batch_norm(in_1, in_2, in_3, in_4, in_5)
tuple_get_item_node = bn_node[0]
relu_node = relay.nn.relu(tuple_get_item_node)
bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
bn_relu = bn_relu.with_attr("Composite", tir.StringImm("bn_relu"))
# merged function
r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])
return relay.Function([x, gamma, beta, moving_mean, moving_var], r)
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected)
if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
......@@ -674,3 +739,4 @@ if __name__ == "__main__":
test_parallel_merge()
test_multiple_input_subgraphs()
test_reuse_call_merge()
test_tuple_get_item_merge()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment