diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc
index 785b486ddc06f171055cb6858f8c6d21d8e01e59..6acf4e65b1acbd2160b9689b3ff6f1f36220aa42 100644
--- a/src/relay/pass/simplify_inference.cc
+++ b/src/relay/pass/simplify_inference.cc
@@ -15,7 +15,8 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
                             Expr gamma,
                             Expr beta,
                             Expr moving_mean,
-                            Expr moving_var) {
+                            Expr moving_var,
+                            Type tdata) {
   const auto param = attrs.as<BatchNormAttrs>();
   Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
   Expr var_add_eps = Add(moving_var, epsilon);
@@ -32,9 +33,11 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
   }
 
   int axis = param->axis;
-  const auto* tdata = data->type_as<TensorTypeNode>();
-  scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis});
-  shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis});
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
+  auto ndim = ttype->shape.size();
+  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
+  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
 
   Expr out = Multiply(data, scale);
   out = Add(out, shift);
@@ -54,14 +57,26 @@ class InferenceSimplifier : public ExprMutator {
     }
     if (const auto* call = new_n->tuple.as<CallNode>()) {
       if (call->op.same_as(batch_norm)) {
-        return BatchNormToInferUnpack(call->attrs,
-          call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]);
+        return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
+                                      call->args[3], call->args[4], ty_map_.at(call->args[0]));
       } else if (call->op.same_as(dropout)) {
         return call->args[0];
       }
     }
     return new_e;
   }
+
+  Expr VisitExpr_(const CallNode* n) {
+    static const Op& batch_norm = Op::Get("nn.batch_norm");
+    auto new_n = ExprMutator::VisitExpr_(n);
+    if (n->op.same_as(batch_norm)) {
+      ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
+    }
+    return new_n;
+  }
+
+ private:
+  std::unordered_map<Expr, Type, NodeHash, NodeEqual> ty_map_;
 };
 
 Expr SimplifyInference(const Expr& e) {
diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py
index 9830b83dc6e59752338054f83c418a1309f33436..7585a88063abab8639e720347649c3d31362853c 100644
--- a/tests/python/relay/test_pass_simplify_inference.py
+++ b/tests/python/relay/test_pass_simplify_inference.py
@@ -30,12 +30,12 @@ def test_simplify_batchnorm():
             y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
                 gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
             y1 = rly.nn.dropout(y1)
-            y1 = rly.ir_pass.infer_type(y1)
-            y1 = simplify_inference(y1)
-
             y2 = simple_bn(y2 + rly.const(1, 'float32'),
                            gamma, beta, moving_mean, moving_var,
                            epsilon=eps, axis=axis, shape=ttype1.shape)
+        y1 = rly.ir_pass.infer_type(y1)
+        y1 = simplify_inference(y1)
+
         assert rly.ir_pass.graph_equal(y1, y2)
 
     check(2, 1, 1)