diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 9e9dd0604916f1971be52a9c89e3d8b0934ec917..760a226a2facc22e72c96adcb6aa2193bac08d0b 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -246,44 +246,9 @@ class ForwardPrep : private ExprVisitor {
 // Per operator defs for FScaleAxisForward
 //----------------------------------------------
 
-// Helper functions
-Expr GetForwardScale(const Expr& expr, AxesSet out) {
-  static const Op& multiply = Op::Get("multiply");
-  static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
-
-  const CallNode* call = expr.as<CallNode>();
-  if (!call) return NullValue<Expr>();
-  auto f = fprep.get(call->op, nullptr);
-
-  if (call->op.same_as(multiply)) {
-    const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
-    const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
-    if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
-      return call->args[1];
-    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
-      return call->args[0];
-    } else {
-      return NullValue<Expr>();
-    }
-  } else if (f != nullptr) {
-    Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
-    for (size_t i = 0; i < call->args.size(); i++) {
-      auto scale = GetForwardScale(call->args[i], in_axes[i]);
-      if (scale.defined()) {
-        return scale;
-      }
-    }
-  }
-  return NullValue<Expr>();
-}
-
 // Intermediate operators
 Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
-  Expr scale = GetForwardScale(call->args[0], out);
-  if (IsPositiveConstant(scale)) {
-    return {out};
-  }
-  return {NullValue<AxesSet>()};
+  return {out};
 }
 
 Expr ReluForwardRewrite(const Call& ref_call,
@@ -391,16 +356,21 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
   Expr lhs = new_args[0];
   Expr rhs = new_args[1];
   auto rnode = make_node<ScaledExprNode>();
-  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) {
+  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
+      IsAllPositiveConstant(rhs)) {
     rnode->value = lhs;
     rnode->scale = rhs;
     rnode->axes = expected_out_axes;
-  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) {
+    return Expr(rnode);
+  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
+             IsAllPositiveConstant(lhs)) {
     rnode->value = rhs;
     rnode->scale = lhs;
     rnode->axes = expected_out_axes;
+    return Expr(rnode);
+  } else {
+    return Expr();
   }
-  return Expr(rnode);
 }
 
 RELAY_REGISTER_OP("multiply")
@@ -790,22 +760,6 @@ RELAY_REGISTER_OP("subtract")
 RELAY_REGISTER_OP("subtract")
 .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
 
-// Find relu in the backward path between multiply and conv2d
-bool FindBackwardRelu(const Expr& expr) {
-  const CallNode* call = expr.as<CallNode>();
-  static const Op& conv2d = Op::Get("nn.conv2d");
-  static const Op& relu = Op::Get("nn.relu");
-
-  if (!call) return false;
-  if (call->op.same_as(relu)) return true;
-  if (call->op.same_as(conv2d)) return false;
-
-  for (size_t i = 0; i < call->args.size(); i++) {
-    if (FindBackwardRelu(call->args[i])) return true;
-  }
-  return false;
-}
-
 // Producer operators
 // Multiply produces the scale-axis pair.
 Expr MultiplyBackwardTransform(const Call& call,
@@ -821,16 +775,16 @@ Expr MultiplyBackwardTransform(const Call& call,
     // NOTE we won't recursively call mutating on scale part.
     // since there  won't be scale chance within scale part.
     Expr rhs = call->args[1];
+    // Only propagate positive scaling.
     if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
-        (!FindBackwardRelu(call->args[0]) ||
-         IsPositiveConstant(call->args[1]))) {
+        IsAllPositiveConstant(rhs)) {
       return transformer->Transform(call->args[0], lhs_axes, rhs);
     }
   } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
+    // Only propagate positive scaling.
     Expr lhs = call->args[0];
     if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
-        (!FindBackwardRelu(call->args[1]) ||
-         IsPositiveConstant(call->args[0]))) {
+        IsAllPositiveConstant(lhs)) {
       return transformer->Transform(call->args[1], rhs_axes, lhs);
     }
   }
diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h
index d42494409b53227f74d1a76537cb411156863756..ddd73901c452ae1e4e0e6f414bbb3beab5acf923 100644
--- a/src/relay/pass/pass_util.h
+++ b/src/relay/pass/pass_util.h
@@ -22,6 +22,15 @@ namespace relay {
 std::unordered_map<const Node*, size_t>
 GetExprRefCount(const Expr& body);
 
+
+/*!
+ * \brief Check whether expr is a constant with all elements non-negative (zero is accepted).
+ * \param expr The expression to be checked.
+ * \return true if every element of expr is a constant >= 0, false otherwise.
+ */
+bool IsAllPositiveConstant(const Expr& expr);
+
+
 /*!
  * \brief Substitute var with subst.
  * \param type The type to be substituted.
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index 5d76efd0124d44ef3a919b315675952e1800b783..e6e8415bd620fef196bf15127681a6969af012f3 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -190,57 +190,6 @@ Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
 
-
-template <typename T>
-bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
-  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
-  CHECK(tensor->strides == nullptr);
-  CHECK_EQ(tensor->byte_offset, 0);
-  const T* data = static_cast<const T*>(tensor->data);
-  int64_t num_elems = 1;
-  for (int i = 0; i < tensor->ndim; ++i) {
-    num_elems *= tensor->shape[i];
-  }
-
-  for (int64_t i = 0; i < num_elems; i++) {
-    if (*data < value) {
-      return false;
-    }
-    data++;
-  }
-  return true;
-}
-
-
-inline bool IsPositiveConstant(const Expr& expr) {
-  const auto* constant = expr.as<ConstantNode>();
-  if (!constant) return false;
-  const auto& tensor = constant->data;
-  const auto& dtype = tensor->dtype;
-
-  if (dtype.lanes != 1) {
-    // pass
-  } else if (dtype.code == kDLFloat && dtype.bits == 32) {
-    return IsNDArrayAllGreaterEqual<float>(tensor, 0);
-  } else if (dtype.code == kDLFloat && dtype.bits == 64) {
-    return IsNDArrayAllGreaterEqual<double>(tensor, 0);
-  } else if (dtype.code == kDLInt && dtype.bits == 8) {
-    return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
-  } else if (dtype.code == kDLInt && dtype.bits == 32) {
-    return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
-  } else if (dtype.code == kDLUInt && dtype.bits == 8) {
-    return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
-  } else if (dtype.code == kDLUInt && dtype.bits == 32) {
-    return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
-  }
-
-  LOG(WARNING) << "Unsupported data type (code = " << dtype.code
-               << ", bits = " << dtype.bits << ", lanes = " << dtype.lanes
-               << ")";
-  return false;
-}
-
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
index 8f7179deea5351a1be3b7c16f4ad427fc1490339..b99d975135bea45484103cf13e5529507021a5c7 100644
--- a/src/relay/pass/util.cc
+++ b/src/relay/pass/util.cc
@@ -146,5 +146,67 @@ GetExprRefCount(const Expr& body) {
   return ExprRefCounter().Get(body);
 }
 
+template <typename T>
+bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
+  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
+  CHECK(tensor->strides == nullptr);
+  CHECK_EQ(tensor->byte_offset, 0);
+  const T* data = static_cast<const T*>(tensor->data);
+  int64_t num_elems = 1;
+  for (int i = 0; i < tensor->ndim; ++i) {
+    num_elems *= tensor->shape[i];
+  }
+
+  for (int64_t i = 0; i < num_elems; i++) {
+    if (*data < value) {
+      return false;
+    }
+    data++;
+  }
+  return true;
+}
+
+bool IsAllPositiveConstant(const Expr& expr) {
+  // peel through a few common transform ops.
+  static const auto& expand_dims = Op::Get("expand_dims");
+  static const auto& reshape = Op::Get("reshape");
+  static const auto& transpose = Op::Get("transpose");
+  static const auto& squeeze = Op::Get("squeeze");
+
+  if (const auto* constant = expr.as<ConstantNode>()) {
+    const auto& tensor = constant->data;
+    const auto& dtype = tensor->dtype;
+    if (dtype.lanes != 1) {
+      return false;
+    } else if (dtype.code == kDLFloat && dtype.bits == 32) {
+      return IsNDArrayAllGreaterEqual<float>(tensor, 0);
+    } else if (dtype.code == kDLFloat && dtype.bits == 64) {
+      return IsNDArrayAllGreaterEqual<double>(tensor, 0);
+    } else if (dtype.code == kDLInt && dtype.bits == 8) {
+      return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
+    } else if (dtype.code == kDLInt && dtype.bits == 32) {
+      return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
+    } else if (dtype.code == kDLUInt && dtype.bits == 8) {
+      return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
+    } else if (dtype.code == kDLUInt && dtype.bits == 32) {
+      return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
+    } else {
+      return false;
+    }
+  } else if (const auto* op = expr.as<CallNode>()) {
+    // these ops only rearrange elements, so recurse into the data argument.
+    if (op->op.same_as(expand_dims) ||
+        op->op.same_as(reshape) ||
+        op->op.same_as(transpose) ||
+        op->op.same_as(squeeze)) {
+      return IsAllPositiveConstant(op->args[0]);
+    } else {
+      return false;
+    }
+  } else {
+    return false;
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index f42aa7b7b8d0bd17bf4a33af523125bb3227e0e3..57cb7c84b10d7cff87bac83cb96c47270b666d45 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -1,6 +1,9 @@
 from tvm import relay
 import numpy as np
 
+def _get_positive_scale(size):
+    return np.random.uniform(0.5, 1, size=size).astype('float32')
+
 
 def test_fold_fwd_simple():
     """Simple testcase."""
@@ -14,6 +17,7 @@ def test_fold_fwd_simple():
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
+
         return relay.Function(args, y)
 
     def expected(x, conv_weight, in_bias, in_scale, channels):
@@ -37,14 +41,14 @@ def test_fold_fwd_simple():
         in_channels = shape[1]
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype('float32'))
-
+        in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
         y1_expected = expected(x, weight, in_bias, in_scale, channels)
+
         y1_folded = relay.ir_pass.infer_type(y1_folded)
         y1_expected = relay.ir_pass.infer_type(y1_expected)
         assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
@@ -107,7 +111,7 @@ def test_fold_fwd_dual_path():
         assert in_channels == channels
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
+        in_scale = relay.const(_get_positive_scale(in_channels,))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
@@ -141,7 +145,7 @@ def test_fold_fwd_fail():
         assert in_channels == channels
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
+        in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)