Commit 990521dd by Wuwei Lin Committed by Tianqi Chen

[RELAY][PASS] Fix expr subst and CombineParallelConv2D (#2218)

parent f884c5d8
...@@ -18,7 +18,7 @@ class ExprSubstituter : public ExprMutator { ...@@ -18,7 +18,7 @@ class ExprSubstituter : public ExprMutator {
Expr VisitExpr(const Expr& expr) final { Expr VisitExpr(const Expr& expr) final {
auto it = subst_map_.find(expr); auto it = subst_map_.find(expr);
if (it != subst_map_.end()) { if (it != subst_map_.end()) {
return (*it).second; return ExprMutator::VisitExpr((*it).second);
} }
return ExprMutator::VisitExpr(expr); return ExprMutator::VisitExpr(expr);
} }
......
...@@ -134,7 +134,46 @@ def test_combine_parallel_conv2d_scale(): ...@@ -134,7 +134,46 @@ def test_combine_parallel_conv2d_scale():
check((1, 4, 16, 16), 4, 8) check((1, 4, 16, 16), 4, 8)
def test_combine_parallel_conv2d_multiple_blocks():
def before(x, w, repeat):
args = [x, w]
y = x
for i in range(repeat):
y1 = relay.nn.conv2d(y, w)
y2 = relay.nn.conv2d(y, w)
y = relay.concatenate((y1, y2), axis=1)
return relay.Function(args, y)
def expected(x, w, channels, repeat):
args = [x, w]
y = x
for i in range(repeat):
w_concat = relay.concatenate((w, w), axis=0)
y = relay.nn.conv2d(y, w_concat, channels=channels*2)
y1 = relay.strided_slice(y, [0, 0], [None, channels])
y2 = relay.strided_slice(y, [0, channels], [None, channels * 2])
y = relay.concatenate((y1, y2), axis=1)
return relay.Function(args, y)
def check(x_shape, repeat):
x = relay.var("x", shape=x_shape)
in_c = x_shape[1]
out_c = in_c // 2
w = relay.var("w", shape=(out_c, in_c, 1, 1))
y_before = before(x, w, repeat)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w, out_c, repeat)
y_expected = relay.ir_pass.infer_type(y_expected)
assert relay.ir_pass.alpha_equal(y, y_expected)
check((1, 4, 16, 16), 4)
if __name__ == "__main__": if __name__ == "__main__":
test_combine_parallel_conv2d() test_combine_parallel_conv2d()
test_combine_parallel_conv2d_scale_relu() test_combine_parallel_conv2d_scale_relu()
test_combine_parallel_conv2d_scale() test_combine_parallel_conv2d_scale()
test_combine_parallel_conv2d_multiple_blocks()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment