diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc
index 586f748abef5792e7692026421b7a98b64f4b1e6..67dc0d2f704902f2fcbee497c6acfb8f0e6e18e7 100644
--- a/src/relay/pass/expr_subst.cc
+++ b/src/relay/pass/expr_subst.cc
@@ -18,7 +18,7 @@ class ExprSubstituter : public ExprMutator {
   Expr VisitExpr(const Expr& expr) final {
     auto it = subst_map_.find(expr);
     if (it != subst_map_.end()) {
-      return (*it).second;
+      return ExprMutator::VisitExpr((*it).second);
     }
     return ExprMutator::VisitExpr(expr);
   }
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index 6fea201d64c87849404de5185edf18e3a97eaa81..7d0a5a08555e4eb6e68601c6f9960d46208c17db 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -134,7 +134,46 @@ def test_combine_parallel_conv2d_scale():
 
     check((1, 4, 16, 16), 4, 8)
 
+
+def test_combine_parallel_conv2d_multiple_blocks():
+    def before(x, w, repeat):
+        args = [x, w]
+        y = x
+        for i in range(repeat):
+            y1 = relay.nn.conv2d(y, w)
+            y2 = relay.nn.conv2d(y, w)
+            y = relay.concatenate((y1, y2), axis=1)
+        return relay.Function(args, y)
+
+    def expected(x, w, channels, repeat):
+        args = [x, w]
+        y = x
+        for i in range(repeat):
+            w_concat = relay.concatenate((w, w), axis=0)
+            y = relay.nn.conv2d(y, w_concat, channels=channels*2)
+            y1 = relay.strided_slice(y, [0, 0], [None, channels])
+            y2 = relay.strided_slice(y, [0, channels], [None, channels * 2])
+            y = relay.concatenate((y1, y2), axis=1)
+        return relay.Function(args, y)
+
+    def check(x_shape, repeat):
+        x = relay.var("x", shape=x_shape)
+        in_c = x_shape[1]
+        out_c = in_c // 2
+        w = relay.var("w", shape=(out_c, in_c, 1, 1))
+        y_before = before(x, w, repeat)
+        y = relay.ir_pass.infer_type(y_before)
+        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.infer_type(y)
+        y_expected = expected(x, w, out_c, repeat)
+        y_expected = relay.ir_pass.infer_type(y_expected)
+        assert relay.ir_pass.alpha_equal(y, y_expected)
+
+    check((1, 4, 16, 16), 4)
+
+
 if __name__ == "__main__":
     test_combine_parallel_conv2d()
     test_combine_parallel_conv2d_scale_relu()
     test_combine_parallel_conv2d_scale()
+    test_combine_parallel_conv2d_multiple_blocks()