diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index dfad1013701f935a03c584b2b0347c51bd01e577..cb87d358e966b8ec5f9acf0ad15a12af2bab623a 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
 
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
-  Array<IndexExpr> axes;
+  // use axis to make the name numpy compatible.
+  Array<Integer> axis;
 
   TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
-    TVM_ATTR_FIELD(axes)
-        .describe("The axes to squeeze in the input tensor."
-                  "If `axes = []`, all axis of dimension 1 get squeezed;"
+    TVM_ATTR_FIELD(axis)
+        .describe("The axis to squeeze in the input tensor."
+                  "If `axis = None`, all axis of dimension 1 get squeezed;"
                   "Else, the dimension in axes get squeezed."
-                  "It is an error if an axes does not has dimension 1.")
-        .set_default(Array<IndexExpr>({}));
+                  "It is an error if an axis does not has dimension 1.")
+        .set_default(NullValue<Array<Integer> >());
   }
 };  // struct SqueezeAttrs
 
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 50e5dfa8d89be5a19865214e7c522a5947b30532..2e3bbadb7841bee5347b81450d457c3452f11478 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -40,6 +40,18 @@ class ExprNode : public RelayNode {
                                       "field for this node";
     return this->checked_type_;
   }
+  /*!
+   * \brief Check if the inferred (checked) type of the Expr
+   *  is backed by a TTypeNode and return it.
+   *
+   * \note This function will throw an error if the checked type
+   *       of this Expr is not a TTypeNode.
+   *
+   * \return The corresponding TTypeNode pointer.
+   * \tparam TTypeNode The specific TypeNode we look for.
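+   *
+   * \code
+   *   // Illustrative usage (assuming type inference has already run on expr):
+   *   const TensorTypeNode* ttype = expr->type_as<TensorTypeNode>();
+   * \endcode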
+   */
+  template<typename TTypeNode>
+  inline const TTypeNode* type_as() const;
 
   static constexpr const char* _type_key = "relay.Expr";
   TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
@@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
 
+// implementations
+template<typename TTypeNode>
+inline const TTypeNode* ExprNode::type_as() const {
+  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
+                "TType must be a special case of type");
+  CHECK(checked_type_.defined())
+      << "Type inference for this Expr has not completed";
+  const TTypeNode* node = checked_type_.as<TTypeNode>();
+  CHECK(node != nullptr)
+      << "Expected type to be " << TTypeNode::_type_key
+      << ", but get " << checked_type_->type_key();
+  return node;
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_EXPR_H_
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index c0256cf3a1c37bd81a1a5a2c09c3d9c4c289f125..bf4025f79224c4d2b7c16ee50462fc5bc2a62a7b 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -150,7 +150,14 @@ class ExprVisitor
 class ExprMutator
     : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
  public:
-  Expr Mutate(const Expr& expr);
+  /*!
+   * \brief Mutate is an alias for VisitExpr.
+   * \return The mutated expression.
+   */
+  Expr Mutate(const Expr& expr) {
+    return this->VisitExpr(expr);
+  }
+  Expr VisitExpr(const Expr& expr) override;
   Expr VisitExpr_(const VarNode* op) override;
   Expr VisitExpr_(const ConstantNode* op) override;
   Expr VisitExpr_(const GlobalVarNode* op) override;
@@ -161,7 +168,8 @@ class ExprMutator
   Expr VisitExpr_(const LetNode* op) override;
   Expr VisitExpr_(const IfNode* op) override;
   Expr VisitExpr_(const TupleGetItemNode* op) override;
-  /*! \brief Used to visit the types inside of expressions.
+  /*!
+   * \brief Used to visit the types inside of expressions.
    *
    * Can be overloaded to transform the types in arbitrary
    * ways, one way would be to define a sub-class of type
@@ -169,7 +177,7 @@ class ExprMutator
    */
   virtual Type VisitType(const Type& t);
 
- private:
+ protected:
   /*! \brief Internal map used for memoization. */
   std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
 };
diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h
index 9f28fbebccfcc227f6d894b2f9561722b6337dbd..ad447ad13cee4d4b7ee53bb9dffb48d70dde3e40 100644
--- a/include/tvm/relay/op.h
+++ b/include/tvm/relay/op.h
@@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode {
     v->Visit("support_level", &support_level);
   }
 
+  /*!
+   * \brief Check if the current op is a "primitive operator".
+   *  That is, its arguments are all type variables, and there is a single
+   *  type relation applied to the input and output types.
+   * \return Whether the op is a primitive operator.
+   */
+  bool IsPrimitiveOp() const {
+    if (is_primitive_ != -1) return is_primitive_ != 0;
+    is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
+    return is_primitive_ != 0;
+  }
+
   static constexpr const char* _type_key = "relay.Op";
   TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
 
@@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode {
   // friend class
   friend class GenericOpMap;
   friend class OpRegistry;
+  friend bool IsPrimitiveOp(const Expr&);
   // Program internal unique index of operator.
   // Used to help index the program.
   uint32_t index_{0};
+  // whether this is a primitive op. -1 means unknown.
+  mutable int is_primitive_{-1};
+  // Internal function to compute whether this is a primitive op.
+  bool IsPrimitiveOp_() const {
+    const auto& fn_ty = this->op_type;
+    if (fn_ty->type_constraints.size() != 1) return false;
+    const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
+    if (rel == nullptr) return false;
+    // validate if the type parameter matches up
+    for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
+      if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
+    }
+    return true;
+  }
 };
 
 /*!
@@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
  */
 inline bool IsPrimitiveOp(const Expr& expr) {
   const auto* op = expr.as<OpNode>();
-
-  if (!op) {
-    return false;
-  }
-
-  const auto& fn_ty = op->op_type;
-  if (fn_ty->type_constraints.size() != 1) return false;
-
-  const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
-  if (rel == nullptr) return false;
-  // validate if the type parameter matches up
-  for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
-    if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
-  }
-
-  return true;
+  return op != nullptr && op->IsPrimitiveOp();
 }
 
 }  // namespace relay
diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index f930751c41a7d6c5a479d5eddb6021515ec082e8..6adfaacdc86d3c4a6a13ccc0fbfbb84020536ad5 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -10,6 +10,7 @@ from . import _make
 from .expr import Expr
 from .ty import Type
 
+
 def infer_type(expr, env=None):
     """Infer the type of expr under the context of env.
 
@@ -30,6 +31,23 @@ def infer_type(expr, env=None):
     return _ir_pass.infer_type(expr, env)
 
 
+def forward_fold_scale_axis(expr):
+    """Fold the scaling of axis into weights of conv2d/dense.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression; its types should have been fully
+        inferred by infer_type beforehand.
+
+    Returns
+    -------
+    folded_expr : tvm.relay.Expr
+        The folded expression after transformation.
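+
+    Examples
+    --------
+    A minimal sketch (shapes and names are illustrative):
+
+    .. code-block:: python
+
+        x = relay.var("x", shape=(1, 4, 8, 8))
+        weight = relay.var("weight")
+        in_scale = relay.var("in_scale", shape=(4, 1, 1))
+        y = relay.nn.conv2d(relay.multiply(x, in_scale), weight,
+                            channels=2, kernel_size=(3, 3))
+        func = relay.Function([x, weight, in_scale], y)
+        func = relay.ir_pass.infer_type(func)
+        # the per-channel scale is folded into the conv2d weight
+        folded = relay.ir_pass.forward_fold_scale_axis(func)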
+    """
+    return _ir_pass.forward_fold_scale_axis(expr)
+
+
 def well_formed(expr):
     """Check that each Var is only bound once (well formed).
 
@@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs):
     """
     return bool(_make._alpha_equal(lhs, rhs))
 
+
 def graph_equal(lhs, rhs):
     """Compare two Relay expr for data-flow equivalence.
     The difference between this and alpha-equality is that
@@ -170,6 +189,7 @@ def graph_equal(lhs, rhs):
     """
     return bool(_make._graph_equal(lhs, rhs))
 
+
 def structural_hash(value):
     """Hash a Relay expression structurally.
 
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 9d14463a530c7c548158cb2338861b1fb2d2ea2b..909b175f08ca650626f66f842a1b320a71a644ba 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -49,27 +49,25 @@ def transpose(data, axes=None):
     return _make.transpose(data, list(axes))
 
 
-def squeeze(data, axes=None):
+def squeeze(data, axis=None):
     """Squeeze axes in the array.
 
     Parameters
     ----------
-    data : relay.Expr
+    data : tvm.relay.Expr
         The input data to the operator.
 
-    axes : None or List[int]
-        Axes to remove.
-        If axes = [] or = None, remove all axis of dimensions 1.
-        Otherwise, remove all axis in axes.
-        If any axis in axes has dimension that does not equal 1, it is an error.
+    axis : None or List[int]
+        The set of axes to remove.
+        If axis = None, remove all axes of dimension 1.
+        If any specified axis has a dimension that does not equal 1, it is an error.
 
     Returns
     -------
-    result : relay.Expr
+    result : tvm.relay.Expr
         The squeezed result.
     """
-    axes = axes or []
-    return _make.squeeze(data, list(axes))
+    return _make.squeeze(data, axis)
 
 
 def reshape(data, newshape):
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 7aab9bb3223b5edf8c64a1132f3fec4bae8f8c41..8409581b53bf45e68301bde55449f9c3699678c2 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -296,13 +296,23 @@ class AlphaEqualHandler:
     if (const CallNode* rhs = other.as<CallNode>()) {
       if (!ExprEqual(lhs->op, rhs->op)) return false;
       if (lhs->args.size() != rhs->args.size()) return false;
-      if (lhs->type_args.size() != rhs->type_args.size()) return false;
-
+      // skip type_args check for primitive ops.
+      bool is_primitive = IsPrimitiveOp(lhs->op);
+      if (!is_primitive) {
+        if (lhs->type_args.size() != rhs->type_args.size()) {
+          return false;
+        }
+      }
       for (size_t i = 0; i < lhs->args.size(); ++i) {
-        if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
+        if (!ExprEqual(lhs->args[i], rhs->args[i])) {
+          return false;
+        }
       }
-      for (size_t i = 0; i < lhs->type_args.size(); ++i) {
-        if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
+
+      if (!is_primitive) {
+        for (size_t i = 0; i < lhs->type_args.size(); ++i) {
+          if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
+        }
       }
       return AttrEqual(lhs->attrs, rhs->attrs);
     } else {
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index 557daa98e89988c1f4f32a844029e1f6147bc356..b7a752d43a5c3c4019fefb6ea126b0bc3d2fa573 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -12,12 +12,12 @@
 namespace tvm {
 namespace relay {
 
-Expr ExprMutator::Mutate(const Expr& expr) {
+Expr ExprMutator::VisitExpr(const Expr& expr) {
   auto it = this->memo_.find(expr);
   if (it != this->memo_.end()) {
     return it->second;
   } else {
-    Expr new_expr = ExprMutator::VisitExpr(expr);
+    Expr new_expr = ExprFunctor::VisitExpr(expr);
     memo_[expr] = new_expr;
     return new_expr;
   }
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 5faa0805426a83194c898cebed25232a05d562c0..635f04668f3317b070ca7cedbcd8d2bd46d3f063 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -761,9 +761,9 @@ Examples::
 TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
 
 Expr MakeSqueeze(Expr data,
-                 Array<IndexExpr> axes) {
+                 Array<Integer> axis) {
   auto attrs = make_node<SqueezeAttrs>();
-  attrs->axes = std::move(axes);
+  attrs->axis = std::move(axis);
   static const Op& op = Op::Get("squeeze");
   return CallNode::make(op, {data}, Attrs(attrs), {});
 }
@@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types,
   const auto* param = attrs.as<SqueezeAttrs>();
   CHECK(param != nullptr);
   std::vector<IndexExpr> result_shape;
-  // if axes is empty, squeeze all axes of dimension 1
-  if (param->axes.size() == 0) {
+  // if axis is not defined, squeeze all axes of dimension 1
+  if (!param->axis.defined()) {
     for (const auto& e : data->shape) {
       const int64_t* axis_ptr = as_const_int(e);
       CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
@@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types,
     for (const auto& e : data->shape) {
       original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
     }
-    for (const auto& e : param->axes) {
-      const int64_t* axis_ptr = as_const_int(e);
-      CHECK(axis_ptr != nullptr);
-      original_shape.at(*axis_ptr).second = false;
+    for (const auto& e : param->axis) {
+      original_shape.at(e->value).second = false;
     }
     for (const auto p : original_shape) {
       if (p.second) {
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b1c767704372e6ef65d1b95fe20be87d937a6dc5
--- /dev/null
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -0,0 +1,554 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file fold_scale_axis.cc
+ *
+ * \brief Fold axis scaling into weights of
+ *  conv/dense operators.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include "pattern_util.h"
+#include "../op/nn/layout.h"
+
+namespace tvm {
+namespace relay {
+/*!
+ * \brief namespace of fold scale axis
+ *
+ * Use namespace to reduce potential naming conflict.
+ */
+namespace fold_scale_axis {
+
+using runtime::TypedPackedFunc;
+
+
+// FoldScaleAxisForward algorithm:
+//
+// The general idea is that we transform an Expr into a tuple of
+// (value, axes, scale), where the final result satisfies:
+//
+// result = value
+// for i, k in enumerate(axes):
+//    k-th dimension of result *= i-th dimension of scale
+//
+// Then we can propagate this signal along and fold the scale if necessary.
+// However, certain scales may never be consumed
+// if no dense/conv2d follows the multiplication.
+//
+// In order to make sure all the scales we send out can be consumed eventually,
+// we run a backward "preparation phase", which propagates the demand
+// for potential axis scaling back to the inputs.
+//
+// The folding process is done in two steps:
+// - Prepare phase: backward propagation of demand.
+// - Transform phase: forward transformation.
+
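+// For example (illustrative), in NCHW layout:
+//
+//   %1 = multiply(%x, %scale)   // %scale broadcasts as (c, 1, 1)
+//   %2 = conv2d(%1, %w)
+//
+// the forward pass represents %1 as (value=%x, axes=[1], scale=%scale'),
+// where %scale' is %scale squeezed to shape (c,). conv2d then consumes the
+// pair by multiplying %scale' into %w along its input-channel axis, so the
+// explicit multiply disappears from the data path.
+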
+/*!
+ * \brief A sorted array of axes; can also be nullptr (undefined).
+ *
+ *  An undefined value means no scaling request can be made.
+ */
+using AxesSet = Array<Integer>;
+
+/*!
+ * \brief Merge two axis sets together by taking
+ *  their intersection.
+ *
+ * \note The axes in an AxesSet should be sorted.
+ *
+ * \param lhs The left axis set.
+ * \param rhs The right axis set.
+ * \return The result of the intersection.
+ */
+AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
+  if (!lhs.defined()) return lhs;
+  if (!rhs.defined()) return rhs;
+  // This code relies on the axes in an AxesSet being sorted.
+  AxesSet ret;
+  size_t i = 0, j = 0;
+  while (i < lhs.size() && j < rhs.size()) {
+    if (lhs[i]->value < rhs[j]->value) {
+      ++i;
+    } else if (lhs[i]->value > rhs[j]->value) {
+      ++j;
+    } else {
+      ret.push_back(lhs[i]);
+      ++i; ++j;
+    }
+  }
+  return ret;
+}
+
+/*!
+ * \brief Get the registered function for an op from an OpMap.
+ * \param op_map The OpMap.
+ * \param op The operator being called.
+ * \tparam ValueType The content value type.
+ * \return The registered value, or a default-constructed ValueType
+ *  if op is not an OpNode or has no entry.
+ */
+template<typename ValueType>
+ValueType GetFunc(const OpMap<ValueType>& op_map,
+                  const Expr& op) {
+  if (const OpNode* opnode = op.as<OpNode>()) {
+    return op_map.get(GetRef<Op>(opnode), ValueType());
+  } else {
+    return ValueType();
+  }
+}
+
+/*!
+ * \brief Preparation function for the forward scale-axis folding pass.
+ * \param call The call node.
+ * \param out_scale_axes Possible scaling on the axes of the output.
+ * \return The requested scaling on the axes of each input.
+ */
+using FForwardPrep = runtime::TypedPackedFunc<
+  Array<AxesSet> (const Call& call, const AxesSet& out_scale_axes)>;
+
+/*! \brief Axis scale tuple.  */
+class STupleNode : public Node {
+ public:
+  /*! \brief The value */
+  Expr value;
+  /*! \brief The axes to scale, can be nullptr (means no scaling) */
+  AxesSet axes = NullValue<AxesSet>();
+  /*! \brief The scaling factor */
+  Expr scale = NullValue<Expr>();
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("value", &value);
+    v->Visit("axes", &axes);
+    v->Visit("scale", &scale);
+  }
+
+  static constexpr const char* _type_key = "relay.fold_scale_axis.STupleNode";
+  TVM_DECLARE_NODE_TYPE_INFO(STupleNode, Node);
+};
+
+RELAY_DEFINE_NODE_REF(STuple, STupleNode, NodeRef);
+
+/*!
+ * \brief The transform function, which transforms an old call into
+ *  a new one given the new arguments.
+ * \param ref_call Reference call node that represents the op and the types.
+ * \param expected_out_axes The scale axes allowed in the output.
+ * \param sargs The input arguments.
+ */
+using FForwardTransform = TypedPackedFunc<
+  STuple(const Call& ref_call,
+         const AxesSet& expected_out_axes,
+         const Array<STuple>& sargs)>;
+
+//----------------------------------------------
+// Generic Visitors for FScaleAxisForward
+//----------------------------------------------
+class FScaleAxisForwardPrep : private ExprVisitor {
+ public:
+  std::unordered_map<const Node*, AxesSet>
+  Prepare(const Expr& body) {
+    this->Update(body, NullValue<AxesSet>());
+    this->VisitExpr(body);
+    // flist_ is built in post-DFS order,
+    // which is a special case of topological order.
+    // We traverse the list in reverse to invoke the lazy functions.
+    // This acts like a back-propagation of valid scale-axis messages.
+    for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) {
+      (*it)();
+    }
+    // return the created messages.
+    return std::move(message_);
+  }
+
+ private:
+  // The invoke list
+  std::vector<std::function<void()> > flist_;
+  // The message on each node.
+  std::unordered_map<const Node*, AxesSet> message_;
+  // Update the message stored at node.
+  void Update(const Expr& node, const AxesSet& axes) {
+    // We run intersection of messages:
+    //
+    // %y = multiply(%x, %scale)
+    // %z1 = conv2d(%y, %w)
+    // %z2 = exp(%y)
+    //
+    // Consider the above code example,
+    // because %z2 will propagate null to %y,
+    // the AxesSet on %y is also null,
+    // and the forward folding won't be triggered.
+    const Node* key = node.get();
+    if (message_.count(key)) {
+      message_[key] = Intersect(message_[key], axes);
+    } else {
+      message_[key] = axes;
+    }
+  }
+  // Visitor pattern override.
+  void VisitExpr_(const LetNode* call) {
+    LOG(FATAL) << "FoldScaleAxis only accepts dataflow form";
+  }
+
+  void VisitExpr_(const FunctionNode* op) {
+    ExprVisitor::VisitExpr_(op);
+    auto flazy = [this, op] {
+      this->Update(op->body, NullValue<AxesSet>());
+    };
+    flist_.push_back(flazy);
+  }
+
+  void VisitExpr_(const CallNode* call) {
+    ExprVisitor::VisitExpr_(call);
+    // function to be lazily invoked
+    auto flazy = [this, call]() {
+      static const auto& fprep =
+        Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
+      // find the message sent to this node.
+      auto it = message_.find(call);
+      AxesSet out_axes;
+      if (it != message_.end()) {
+        out_axes = it->second;
+      } else {
+        out_axes = NullValue<AxesSet>();
+      }
+      // pass the message back to all the children it references.
+      auto f = GetFunc(fprep, call->op);
+      if (f != nullptr) {
+        Array<AxesSet> in_axes = f(GetRef<Call>(call), out_axes);
+        CHECK_EQ(in_axes.size(), call->args.size());
+        for (size_t i = 0; i < call->args.size(); ++i) {
+          this->Update(call->args[i], in_axes[i]);
+        }
+      } else {
+        for (size_t i = 0; i < call->args.size(); ++i) {
+          this->Update(call->args[i], NullValue<AxesSet>());
+        }
+      }
+    };
+    flist_.push_back(flazy);
+  }
+
+  void VisitExpr_(const TupleNode* op) {
+    ExprVisitor::VisitExpr_(op);
+    // Passing scale through a tuple is not supported for now.
+    auto flazy = [this, op]() {
+      for (const Expr& field : op->fields) {
+        this->Update(field, NullValue<AxesSet>());
+      }
+    };
+    flist_.push_back(flazy);
+  }
+
+  void VisitExpr_(const IfNode* op) {
+    ExprVisitor::VisitExpr_(op);
+    // Do not pass the scale signal through the condition or branches.
+    // By assigning NullValue<AxesSet> we indicate that the scale signal
+    // cannot pass into these subexpressions.
+    auto flazy = [this, op]() {
+      this->Update(op->cond, NullValue<AxesSet>());
+      this->Update(op->true_branch, NullValue<AxesSet>());
+      this->Update(op->false_branch, NullValue<AxesSet>());
+    };
+    flist_.push_back(flazy);
+  }
+};
+
+class FScaleAxisForwardTransform : private ExprMutator {
+ public:
+  // Transform expression.
+  Expr Transform(Expr expr) {
+    expected_scale_axes_ =
+        FScaleAxisForwardPrep().Prepare(expr);
+    return this->Mutate(expr);
+  }
+
+ private:
+  // Valid axes on each node.
+  std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
+  std::unordered_map<const Node*, STuple> scale_memo_;
+  // If the user simply calls Mutate,
+  // then only an Expr is returned and we cannot
+  // accept outstanding scales.
+  Expr VisitExpr(const Expr& expr) final {
+    Expr res = ExprMutator::VisitExpr(expr);
+    CHECK(!scale_memo_.count(expr.get()))
+        << "Outstanding scale";
+    return res;
+  }
+
+  STuple GetSTuple(const Expr& expr) {
+    Expr res = ExprMutator::VisitExpr(expr);
+    auto it = scale_memo_.find(expr.get());
+    if (it != scale_memo_.end()) {
+      CHECK(it->second->value.same_as(res));
+      return it->second;
+    } else {
+      auto node = make_node<STupleNode>();
+      node->value = res;
+      return STuple(node);
+    }
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {
+    static const auto& ftransform =
+        Op::GetAttr<FForwardTransform>("FScaleAxisForwardTransform");
+    auto new_op = this->Mutate(call_node->op);
+    bool has_scale = false;
+    bool unchanged = call_node->op.same_as(new_op);
+
+    Array<STuple> call_sargs;
+    Array<Expr> call_args;
+    for (auto arg : call_node->args) {
+      STuple new_sarg = this->GetSTuple(arg);
+      unchanged &= new_sarg->value.same_as(arg);
+      if (new_sarg->axes.defined()) has_scale = true;
+      call_sargs.push_back(new_sarg);
+      call_args.push_back(new_sarg->value);
+    }
+
+    // get expected scale axes.
+    AxesSet expected_out_axes;
+    auto axis_it = expected_scale_axes_.find(call_node);
+    if (axis_it != expected_scale_axes_.end()) {
+      expected_out_axes = axis_it->second;
+    }
+    // propagation function
+    auto f = GetFunc(ftransform, call_node->op);
+    if (f != nullptr) {
+      STuple sret = f(GetRef<Call>(call_node), expected_out_axes, call_sargs);
+      if (sret.defined()) {
+        if (sret->axes.defined()) {
+          scale_memo_[call_node] = sret;
+        }
+        return sret->value;
+      }
+    }
+    // normal path
+    CHECK(!has_scale) << "Outstanding scale, on op=" << call_node->op;
+    if (unchanged) {
+      return GetRef<Expr>(call_node);
+    } else {
+      return CallNode::make(
+          new_op, call_args, call_node->attrs, call_node->type_args);
+    }
+  }
+};
+
+//----------------------------------------------
+// Per operator defs for FScaleAxisForward
+//----------------------------------------------
+
+// Intermediate operators
+Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
+  return {out};
+}
+
+STuple ReluForwardTransform(const Call& ref_call,
+                            const AxesSet& expected_axes,
+                            const Array<STuple>& sargs) {
+  if (!sargs[0]->axes.defined()) return STuple();
+  // return transformed relu
+  auto rnode = make_node<STupleNode>();
+  rnode->value = CallNode::make(
+      ref_call->op, {sargs[0]->value}, ref_call->attrs, {});
+  rnode->scale = sargs[0]->scale;
+  rnode->axes = sargs[0]->axes;
+  return STuple(rnode);
+}
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
+
+RELAY_REGISTER_OP("nn.leaky_relu")
+.set_attr<FForwardPrep>("FScaleAxisForwardPrep", ReluForwardPrep);
+
+RELAY_REGISTER_OP("nn.leaky_relu")
+.set_attr<FForwardTransform>("FScaleAxisForwardTransform", ReluForwardTransform);
+
+// AddSub
+Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
+  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+
+  auto none = NullValue<AxesSet>();
+  if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) {
+    return {out_axes, none};
+  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) {
+    return {none, out_axes};
+  } else {
+    return {none, none};
+  }
+}
+
+STuple AddSubForwardTransform(const Call& ref_call,
+                              const AxesSet& expected_out_axes,
+                              const Array<STuple>& sargs) {
+  if (!sargs[0]->axes.defined() && !sargs[1]->axes.defined()) {
+    return STuple();
+  }
+  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
+
+  auto rnode = make_node<STupleNode>();
+  if (sargs[0]->axes.defined()) {
+    CHECK(!sargs[1]->axes.defined());
+    CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, sargs[0]->axes));
+    Expr scale = ExpandBiasToMatchAxis(
+        sargs[0]->scale, tlhs->shape.size(), sargs[0]->axes);
+    Expr rhs = Divide(sargs[1]->value, scale);
+    rnode->value = CallNode::make(ref_call->op, {sargs[0]->value, rhs},
+                                  ref_call->attrs, ref_call->type_args);
+    rnode->scale = sargs[0]->scale;
+    rnode->axes = sargs[0]->axes;
+  } else {
+    CHECK(sargs[1]->axes.defined());
+    CHECK(!sargs[0]->axes.defined());
+    CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, sargs[1]->axes));
+    Expr scale = ExpandBiasToMatchAxis(
+        sargs[1]->scale, trhs->shape.size(), sargs[1]->axes);
+    Expr lhs = Divide(sargs[0]->value, scale);
+    rnode->value = CallNode::make(ref_call->op, {lhs, sargs[1]->value},
+                                  ref_call->attrs, ref_call->type_args);
+    rnode->scale = sargs[1]->scale;
+    rnode->axes = sargs[1]->axes;
+  }
+  return STuple(rnode);
+}
+
+RELAY_REGISTER_OP("add")
+.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
+
+RELAY_REGISTER_OP("add")
+.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
+
+RELAY_REGISTER_OP("subtract")
+.set_attr<FForwardPrep>("FScaleAxisForwardPrep", AddSubForwardPrep);
+
+RELAY_REGISTER_OP("subtract")
+.set_attr<FForwardTransform>("FScaleAxisForwardTransform", AddSubForwardTransform);
+
+// Producer operators
+// Multiply produces the scale-axis pair.
+STuple MultiplyForwardTransform(const Call& ref_call,
+                                const AxesSet& expected_out_axes,
+                                const Array<STuple>& sargs) {
+  if (!expected_out_axes.defined()) return STuple();
+  // TODO(tvm-team) allow same axes accumulation
+  // not as important because it is less common in nn.
+  CHECK(!sargs[0]->axes.defined());
+  CHECK(!sargs[1]->axes.defined());
+  const auto* tlhs = ref_call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = ref_call->args[1]->type_as<TensorTypeNode>();
+
+  Expr lhs = sargs[0]->value;
+  Expr rhs = sargs[1]->value;
+  auto rnode = make_node<STupleNode>();
+  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) {
+    rnode->value = lhs;
+    rnode->scale = rhs;
+    rnode->axes = expected_out_axes;
+  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) {
+    rnode->value = rhs;
+    rnode->scale = lhs;
+    rnode->axes = expected_out_axes;
+  }
+  return STuple(rnode);
+}
+
+RELAY_REGISTER_OP("multiply")
+.set_attr<FForwardTransform>("FScaleAxisForwardTransform", MultiplyForwardTransform);
+
+// Consumer operators
+// Conv2D send out requirement of axis folding.
+Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
+  // TODO(tvm-team) support general data layout
+  // by transforming weight
+  const auto* param = call->attrs.as<Conv2DAttrs>();
+  CHECK(param != nullptr);
+  Layout data_layout(param->data_layout);
+  Layout weight_layout(param->weight_layout);
+  int c_big_axis = data_layout.indexof('C');
+  int c_small_axis = data_layout.indexof('c');
+  const auto* tdata = call->args[0]->type_as<TensorTypeNode>();
+  CHECK(tdata) << "require checked type";
+
+  CHECK_GE(c_big_axis, 0);
+  AxesSet data_axes = NullValue<AxesSet>();
+  // For now, we only support the simple pattern (no folded weight/data).
+  // More general layouts can be supported under the current framework
+  // by using a unified layout transformation; we only need to change
+  // the Prep and Mutate functions.
+  //
+  // only handle depthwise or full conv2d.
+  // TODO(tvm-team) handle grouped conv by reshape + bcast
+  bool is_depthwise_conv2d =
+      is_const_int(tdata->shape[c_big_axis], param->groups);
+  if (weight_layout.indexof('i') < 0 &&
+      c_small_axis < 0 &&
+      (param->groups == 1 || is_depthwise_conv2d)) {
+    data_axes = {c_big_axis};
+  }
+  return {data_axes, NullValue<AxesSet>()};
+}
+
+// Conv2D consumes the scale axis during transformation.
+STuple Conv2DForwardTransform(const Call& ref_call,
+                              const AxesSet& expected_axes,
+                              const Array<STuple>& sargs) {
+  // If the data does not carry a scale, take the normal transform path.
+  STuple sdata = sargs[0];
+  if (!sdata->scale.defined()) return STuple();
+  CHECK(sdata->axes.defined());
+  const auto* param = ref_call->attrs.as<Conv2DAttrs>();
+  CHECK(param != nullptr);
+  Layout data_layout(param->data_layout);
+  Layout weight_layout(param->weight_layout);
+  int c_big_axis = data_layout.indexof('C');
+  CHECK_GE(c_big_axis, 0);
+  // For now, we only support simple pattern (no folded weight/data)
+  // TODO(tvm-team) support general data layout
+  CHECK_EQ(weight_layout.indexof('i'), -1);
+  CHECK(sdata->axes.size() == 1 &&
+        c_big_axis == sdata->axes[0]->value);
+  int big_ic_axis = weight_layout.indexof('I');
+
+  const auto* tdata = ref_call->args[0]->type_as<TensorTypeNode>();
+  // Check it must be depthwise or full conv2d.
+  bool is_depthwise_conv2d =
+      is_const_int(tdata->shape[c_big_axis], param->groups);
+  CHECK(param->groups == 1 || is_depthwise_conv2d);
+
+  // match the ic_axis
+  Expr scale = ExpandBiasToMatchAxis(
+      sdata->scale, weight_layout.ndim(), {big_ic_axis});
+  Expr weight = Multiply(sargs[1]->value, scale);
+  // return transformed conv2d
+  auto rnode = make_node<STupleNode>();
+  rnode->value = CallNode::make(
+      ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
+  return STuple(rnode);
+}
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FForwardPrep>("FScaleAxisForwardPrep", Conv2DForwardPrep);
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FForwardTransform>("FScaleAxisForwardTransform", Conv2DForwardTransform);
+
+
+Expr ForwardFoldScaleAxis(Expr data) {
+  return FScaleAxisForwardTransform().Transform(data);
+}
+
+// Expose FoldScaleAxisForward.
+TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
+.set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);
+
+}  // namespace fold_scale_axis
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..a395e74cdf0b887f1b8197c1ffd5beeff0016384
--- /dev/null
+++ b/src/relay/pass/pattern_util.h
@@ -0,0 +1,123 @@
+/*!
+ *  Copyright (c) 2018 by Contributors.
+ *
+ * \file tvm/relay/pass/pattern_util.h
+ * \brief Header of internal operator functions
+ *  These can be used for writing passes.
+ */
+#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
+#define TVM_RELAY_PASS_PATTERN_UTIL_H_
+
+#include <tvm/relay/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/attrs/transform.h>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Try to match lhs and rhs via the broadcasting rule, such that:
+ *
+ *  rhs matches the dimensions of lhs specified by lhs_axes;
+ *  rhs's dimension equals 1 on the rest of the dimensions.
+ *
+ * \param tlhs The type of the left operand (data).
+ * \param trhs The type of the right operand (bias).
+ * \param lhs_axes The axes on lhs to match.
+ * \param rhs_value If non-null, set to a squeezed version of rhs that
+ *  only contains the matched dimensions.
+ * \return Whether match is successful.
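+ *
+ * \note Illustrative example: with tlhs of shape (n, c, h, w), trhs of
+ *  shape (c, 1, 1) and lhs_axes = {1}, the match succeeds and rhs_value
+ *  (if requested) is set to rhs squeezed to shape (c,).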
+ */
+inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
+                                     const TensorTypeNode* trhs,
+                                     const Array<Integer>& lhs_axes,
+                                     Expr* rhs_value = nullptr) {
+  if (tlhs->shape.size() < trhs->shape.size()) return false;
+  AttrsEqual equal;
+  size_t base = tlhs->shape.size() - trhs->shape.size();
+  size_t j = 0;
+
+  NodePtr<SqueezeAttrs> squeeze_attrs;
+  if (rhs_value != nullptr) {
+    squeeze_attrs = make_node<SqueezeAttrs>();
+  }
+
+  for (size_t i = 0; i < tlhs->shape.size(); ++i) {
+    if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
+      if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
+        return false;
+      }
+      ++j;
+    } else if (i >= base) {
+      if (!is_const_int(trhs->shape[i - base], 1)) {
+        return false;
+      }
+      if (rhs_value != nullptr) {
+        squeeze_attrs->axis.push_back(static_cast<int>(i - base));
+      }
+    }
+  }
+  if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
+    static const Op& squeeze_op = Op::Get("squeeze");
+    *rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
+  }
+  return true;
+}
+
+/*!
+ * \brief Expand a 1D Tensor to match the given axes.
+ *
+ * The resulting bias can be added to or multiplied with
+ * the target Tensor on the specified axes via the broadcasting rule.
+ *
+ * \param bias The bias.
+ * \param target_ndim The target number of dimensions.
+ * \param axes The axes of the target we want to match on.
+ * \return The expanded bias expression.
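+ *
+ * \note Illustrative example: expanding a bias of shape (c,) to match
+ *  axes {1} of a 4-D NCHW target appends two new trailing axes, giving
+ *  shape (c, 1, 1), which broadcasts against (n, c, h, w) on axis 1.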
+ */
+inline Expr ExpandBiasToMatchAxis(Expr bias,
+                                  int target_ndim,
+                                  const Array<Integer>& axes) {
+  static const Op& expand_dims = Op::Get("expand_dims");
+  for (size_t i = axes.size(); i != 0; --i) {
+    if (i == axes.size()) {
+      int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
+      if (num_pad_axis > 0) {
+        auto attrs = make_node<ExpandDimsAttrs>();
+        attrs->axis = i;
+        attrs->num_newaxis = static_cast<int>(num_pad_axis);
+        bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
+      }
+    } else {
+      int64_t diff = axes[i]->value - axes[i - 1]->value;
+      CHECK_GE(diff, 0L);
+      if (diff > 0) {
+        auto attrs = make_node<ExpandDimsAttrs>();
+        attrs->axis = i;
+        attrs->num_newaxis = static_cast<int>(diff);
+        bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
+      }
+    }
+  }
+  return bias;
+}
+
+inline Expr Multiply(Expr lhs, Expr rhs) {
+  static const Op& op = Op::Get("multiply");
+  return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+}
+
+inline Expr Divide(Expr lhs, Expr rhs) {
+  static const Op& op = Op::Get("divide");
+  return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+}
+
+
+inline Expr ReshapeLike(Expr lhs, Expr rhs) {
+  static const Op& op = Op::Get("reshape_like");
+  return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+}
+
+
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 7c8eeef92c5d59c5ebf03e48a9e953900721bbdc..c1f6cdc639740ce7eeca50a2c43354ee9855b452 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator {
     CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
         << "Cannot resolve type of " << GetRef<Expr>(op)
         << " at " << op->span;
+
     Expr new_e = ExprMutator::VisitExpr_(op);
-    if (!checked_type.same_as(new_e->checked_type_)) {
+    // new_call and new_var are only valid for CallNode/VarNode respectively.
+    // Compiler optimizations will likely fold these away for other nodes.
+    CallNode* new_call = (
+        std::is_base_of<CallNode, T>::value ?
+        static_cast<CallNode*>(new_e.node_.get()) : nullptr);
+    VarNode* new_var = (
+        std::is_base_of<VarNode, T>::value ?
+        static_cast<VarNode*>(new_e.node_.get()) : nullptr);
+
+    // check if we need to update new_e
+    bool need_update_type = !checked_type.same_as(new_e->checked_type_);
+    bool need_update_call = (
+        std::is_base_of<CallNode, T>::value &&
+        it->second.type_args.defined() &&
+        !it->second.type_args.same_as(new_call->type_args));
+    bool need_update_var = (
+        std::is_base_of<VarNode, T>::value &&
+        update_missing_type_annotation_ &&
+        !new_var->type_annotation.defined());
+
+    if (!need_update_type && !need_update_var && !need_update_call) return new_e;
+
+    if (!new_e.node_.unique()) {
       // Copy on write optimization
       // If new_e is an old expression,
       // we make a copy mutating an existing reference.
-      if (!new_e.node_.unique()) {
-        new_e = Expr(make_node<T>(*new_e.as<T>()));
-      }
-      new_e->checked_type_ = checked_type;
+      new_e = Expr(make_node<T>(*new_e.as<T>()));
+      new_call = (
+          std::is_base_of<CallNode, T>::value ?
+          static_cast<CallNode*>(new_e.node_.get()) : nullptr);
+      new_var = (
+          std::is_base_of<VarNode, T>::value ?
+          static_cast<VarNode*>(new_e.node_.get()) : nullptr);
     }
 
-    if (it->second.type_args.defined()) {
-      Call call = Downcast<Call>(new_e);
-      const CallNode* const_call_ref = call.operator->();
-      CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
-      call_ref->type_args = it->second.type_args;
+    // attach the information.
+    if (need_update_type) {
+      new_e->checked_type_ = checked_type;
+    }
 
-      for (size_t i = 0; i < call->type_args.size(); i++) {
-        call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
+    if (need_update_call) {
+      new_call->type_args = it->second.type_args;
+      for (size_t i = 0; i < new_call->type_args.size(); i++) {
+        new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i]));
       }
     }
-
+    if (need_update_var) {
+      new_var->type_annotation = checked_type;
+    }
     return new_e;
   }
 
@@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator {
  private:
   const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
   TypeSolver* solver_;
+  // whether to attach the checked type as type_annotation
+  // if the original type annotation is missing.
+  bool update_missing_type_annotation_{true};
 };
 
 
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 2ee6f758f1004341725ff14404607c9b26f7dcf7..427ac562fbc7c9a7baa92c4c06832c5880423da7 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -55,8 +55,8 @@ def test_transpose_infer_type():
 def test_squeeze_infer_type():
     n, t, d = 1, 4, 1
     x = relay.var("x", relay.TensorType((n, t, d), "float32"))
-    y = relay.squeeze(x, axes=(2,))
-    assert "axes=" in y.astext()
+    y = relay.squeeze(x, axis=(2,))
+    assert "axis=" in y.astext()
     yy = relay.ir_pass.infer_type(y)
     assert yy.checked_type == relay.TensorType(
         (1, 4), "float32")
@@ -64,7 +64,7 @@ def test_squeeze_infer_type():
     n, t, d = 1, 4, 1
     x = relay.var("x", relay.TensorType((n, t, d), "float32"))
     y = relay.squeeze(x)
-    assert "axes=" not in y.astext()
+    assert "axis=" not in y.astext()
     yy = relay.ir_pass.infer_type(y)
     assert yy.checked_type == relay.TensorType(
         (4,), "float32")
@@ -74,7 +74,7 @@ def test_squeeze_infer_type():
 def test_squeeze_bad_axes_infer_type():
     n, t, d = 1, 4, 1
     x = relay.var("x", relay.TensorType((n, t, d), "float32"))
-    y = relay.squeeze(x, axes=(1,))
+    y = relay.squeeze(x, axis=(1,))
     yy = relay.ir_pass.infer_type(y)
 
 
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce3b35efe460906a439f192abfc1a75dc656d39
--- /dev/null
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -0,0 +1,153 @@
+from tvm import relay
+
+
+def test_fold_fwd_simple():
+    """Simple testcase."""
+    def before(x, conv_weight, in_bias, in_scale, channels):
+        args = [x, conv_weight, in_bias, in_scale]
+        in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
+        in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
+        x = relay.multiply(x, in_scale)
+        x = relay.nn.relu(x)
+        x = relay.add(x, in_bias)
+        y = relay.nn.conv2d(x, conv_weight,
+                            channels=channels,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, in_bias, in_scale, channels):
+        # use a fixed order of args so the alpha-equal check can pass
+        args = [x, conv_weight, in_bias, in_scale]
+        in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
+        in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
+        squeezed_scale = relay.squeeze(in_scale, axis=[1, 2])
+        x = relay.nn.relu(x)
+        in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+        x = relay.add(x, in_bias)
+        conv_weight = relay.multiply(
+            conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+        y = relay.nn.conv2d(x, conv_weight,
+                            channels=channels,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        return relay.Function(args, y)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        in_bias = relay.var("in_bias", shape=(in_channels,))
+        in_scale = relay.var("in_scale", shape=(in_channels,))
+
+        y1 = before(x, weight, in_bias, in_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
+        y1_expected = expected(x, weight, in_bias, in_scale, channels)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10), 2)
+
+
+def test_fold_fwd_dual_path():
+    """scale axis being consumed by two consumers"""
+    def before(x, conv_weight, in_bias, in_scale, channels):
+        args = [x, conv_weight, in_bias, in_scale]
+        x = relay.multiply(in_scale, x)
+        x = relay.nn.relu(x)
+        x = relay.subtract(x, in_bias)
+        y1 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NHWC",
+                             weight_layout="HWOI",
+                             groups=channels,
+                             padding=(1, 1))
+        y2 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NHWC",
+                             weight_layout="HWOI",
+                             groups=channels,
+                             padding=(1, 1))
+        z = relay.add(y1, y2)
+        return relay.Function(args, z)
+
+    def expected(x, conv_weight, in_bias, in_scale, channels):
+        args = [x, conv_weight, in_bias, in_scale]
+        x = relay.nn.relu(x)
+        in_bias = relay.divide(in_bias, in_scale)
+        x = relay.subtract(x, in_bias)
+        y1 = relay.nn.conv2d(x,
+                             relay.multiply(conv_weight, in_scale),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NHWC",
+                             weight_layout="HWOI",
+                             groups=channels,
+                             padding=(1, 1))
+        y2 = relay.nn.conv2d(x,
+                             relay.multiply(conv_weight, in_scale),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NHWC",
+                             weight_layout="HWOI",
+                             groups=channels,
+                             padding=(1, 1))
+        z = relay.add(y1, y2)
+        return relay.Function(args, z)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[-1]
+        # test depthwise
+        assert in_channels == channels
+        weight = relay.var("weight")
+        in_bias = relay.var("in_bias", shape=(in_channels,))
+        in_scale = relay.var("in_scale", shape=(in_channels,))
+        y1 = before(x, weight, in_bias, in_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
+        type_dict = {x.name_hint:x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_expected = expected(x, weight, in_bias, in_scale, channels)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 3), 3)
+
+
+def test_fold_fwd_fail():
+    """testcase where we canont fold"""
+    def before(x, conv_weight, in_bias, in_scale, channels):
+        x = relay.multiply(x, in_scale)
+        xx = relay.nn.leaky_relu(x, alpha=0.1)
+        y1 = relay.nn.conv2d(xx, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NHWC",
+                             padding=(1, 1))
+        z = relay.add(y1, x)
+        return relay.Function(relay.ir_pass.free_vars(z), z)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[-1]
+        # test depthwise
+        assert in_channels == channels
+        weight = relay.var("weight")
+        in_bias = relay.var("in_bias", shape=(in_channels,))
+        in_scale = relay.var("in_scale", shape=(in_channels,))
+        y1 = before(x, weight, in_bias, in_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
+        assert relay.ir_pass.alpha_equal(y1, y1_folded)
+
+    check((2, 11, 10, 4), 4)
+
+
+if __name__ == "__main__":
+    test_fold_fwd_simple()
+    test_fold_fwd_dual_path()
+    test_fold_fwd_fail()