diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc index 586f748abef5792e7692026421b7a98b64f4b1e6..67dc0d2f704902f2fcbee497c6acfb8f0e6e18e7 100644 --- a/src/relay/pass/expr_subst.cc +++ b/src/relay/pass/expr_subst.cc @@ -18,7 +18,7 @@ class ExprSubstituter : public ExprMutator { Expr VisitExpr(const Expr& expr) final { auto it = subst_map_.find(expr); if (it != subst_map_.end()) { - return (*it).second; + return ExprMutator::VisitExpr((*it).second); } return ExprMutator::VisitExpr(expr); } diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 6fea201d64c87849404de5185edf18e3a97eaa81..7d0a5a08555e4eb6e68601c6f9960d46208c17db 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -134,7 +134,46 @@ def test_combine_parallel_conv2d_scale(): check((1, 4, 16, 16), 4, 8) + +def test_combine_parallel_conv2d_multiple_blocks(): + def before(x, w, repeat): + args = [x, w] + y = x + for i in range(repeat): + y1 = relay.nn.conv2d(y, w) + y2 = relay.nn.conv2d(y, w) + y = relay.concatenate((y1, y2), axis=1) + return relay.Function(args, y) + + def expected(x, w, channels, repeat): + args = [x, w] + y = x + for i in range(repeat): + w_concat = relay.concatenate((w, w), axis=0) + y = relay.nn.conv2d(y, w_concat, channels=channels*2) + y1 = relay.strided_slice(y, [0, 0], [None, channels]) + y2 = relay.strided_slice(y, [0, channels], [None, channels * 2]) + y = relay.concatenate((y1, y2), axis=1) + return relay.Function(args, y) + + def check(x_shape, repeat): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + out_c = in_c // 2 + w = relay.var("w", shape=(out_c, in_c, 1, 1)) + y_before = before(x, w, repeat) + y = relay.ir_pass.infer_type(y_before) + y = relay.ir_pass.combine_parallel_conv2d(y) + y = relay.ir_pass.infer_type(y) + y_expected = expected(x, w, out_c, repeat) + y_expected = relay.ir_pass.infer_type(y_expected) + assert relay.ir_pass.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4) + + if __name__ == "__main__": test_combine_parallel_conv2d() test_combine_parallel_conv2d_scale_relu() test_combine_parallel_conv2d_scale() + test_combine_parallel_conv2d_multiple_blocks()