From d5103bbcdfba4170f8ba16f4cb373869cade8cca Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 29 Oct 2018 19:43:37 -0700
Subject: [PATCH] [RELAY][PASS] FoldScaleAxis Backward (#2024)

---
 include/tvm/relay/expr_functor.h              |   6 +-
 python/tvm/relay/ir_pass.py                   |  29 ++
 src/relay/ir/expr_functor.cc                  |  12 +-
 src/relay/pass/fold_scale_axis.cc             | 455 +++++++++++++++++-
 src/relay/pass/pattern_util.h                 |  23 +-
 .../python/relay/test_pass_fold_scale_axis.py | 177 ++++++-
 6 files changed, 667 insertions(+), 35 deletions(-)

diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index bf4025f79..85a6b502d 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -135,9 +135,9 @@ class ExprVisitor
   void VisitExpr_(const TupleGetItemNode* op) override;
   virtual void VisitType(const Type& t);
 
- private:
-  // internal visited flag.
-  std::unordered_set<const Node*> visited_;
+ protected:
+  // Internal visiting counter
+  std::unordered_map<const Node*, size_t> visit_counter_;
 };
 
 /*!
diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 6adfaacdc..82afa83ee 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -31,6 +31,29 @@ def infer_type(expr, env=None):
     return _ir_pass.infer_type(expr, env)
 
 
+def backward_fold_scale_axis(expr):
+    """Backward fold axis scaling into weights of conv2d/dense.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression; its types are expected to be
+        fully inferred by infer_type.
+
+    Returns
+    -------
+    folded_expr : tvm.relay.Expr
+        The folded expression after transformation.
+
+    Note
+    ----
+    It is recommended to call backward_fold_scale_axis
+    before using forward_fold_scale_axis,
+    as backward folding targets the common conv-bn pattern.
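+
+    Examples
+    --------
+    A minimal usage sketch (shapes here are illustrative, assuming
+    ``from tvm import relay``):
+
+    .. code-block:: python
+
+        x = relay.var("x", shape=(1, 4, 8, 8))
+        w = relay.var("w", shape=(8, 4, 3, 3))
+        scale = relay.var("scale", shape=(8, 1, 1))
+        # scale multiplied after the conv2d: the backward pattern
+        y = relay.multiply(
+            relay.nn.conv2d(x, w, channels=8, kernel_size=(3, 3)), scale)
+        f = relay.ir_pass.infer_type(relay.Function([x, w, scale], y))
+        f_folded = relay.ir_pass.backward_fold_scale_axis(f)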
+    """
+    return _ir_pass.backward_fold_scale_axis(expr)
+
+
 def forward_fold_scale_axis(expr):
     """Fold the scaling of axis into weights of conv2d/dense.
 
@@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr):
     -------
     folded_expr : tvm.relay.Expr
         The folded expression after transformation.
+
+    Note
+    ----
+    It is recommended to call backward_fold_scale_axis
+    before using forward_fold_scale_axis,
+    as backward folding targets the common conv-bn pattern.
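+
+    Examples
+    --------
+    A minimal usage sketch for the forward direction (shapes here are
+    illustrative, assuming ``from tvm import relay``):
+
+    .. code-block:: python
+
+        x = relay.var("x", shape=(1, 4, 8, 8))
+        w = relay.var("w", shape=(8, 4, 3, 3))
+        scale = relay.var("scale", shape=(4, 1, 1))
+        # scale multiplied into the conv2d input: the forward pattern
+        y = relay.nn.conv2d(relay.multiply(x, scale), w,
+                            channels=8, kernel_size=(3, 3))
+        f = relay.ir_pass.infer_type(relay.Function([x, w, scale], y))
+        f_folded = relay.ir_pass.forward_fold_scale_axis(f)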
     """
     return _ir_pass.forward_fold_scale_axis(expr)
 
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index b7a752d43..ed7c1d1d1 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
 Type ExprMutator::VisitType(const Type& t) { return t; }
 
 void ExprVisitor::VisitExpr(const Expr& expr) {
-  if (visited_.count(expr.get())) return;
-  using TParent = ExprFunctor<void(const Expr&)>;
-  TParent::VisitExpr(expr);
-  visited_.insert(expr.get());
+  auto it = visit_counter_.find(expr.get());
+  if (it != visit_counter_.end()) {
+    ++it->second;
+  } else {
+    using TParent = ExprFunctor<void(const Expr&)>;
+    TParent::VisitExpr(expr);
+    visit_counter_.insert({expr.get(), 1});
+  }
 }
 
 void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index b1c767704..e757118f3 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -24,9 +24,9 @@ namespace fold_scale_axis {
 using runtime::TypedPackedFunc;
 
 
-// FoldScaleAxisFoward algorithm:
+// FoldScaleAxis algorithm:
 //
-// The general idea is that we transform Expr to tuple of
+// The general idea is to transform Expr to tuple of
 // (value, axes, scale), where the final result satisfies:
 //
 // result = value
@@ -41,9 +41,14 @@ using runtime::TypedPackedFunc;
 // we run a backward "preparation phase", which propagates the demand
 // of the potential axes scaling back to its input.
 //
-// The folding process is done in two steps:
+// The forward folding process is done in two steps:
 // - Prepare phase: backward propagation of demand.
 // - Transform phase: forward transformation.
+//
+// Similarly, the backward folding process is done in two steps:
+// - Prepare phase: forward propagation of demand.
+// - Transform phase: transformation by pushing the axes-scale signal down to the inputs.
+//
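+// As an illustrative sketch of the forward direction, the pattern
+//
+//   conv2d(multiply(x, scale), weight)
+//
+// is traversed as the pair (value = x, axes = {C}, scale = scale);
+// conv2d then consumes the pair by folding scale into weight along
+// the input-channel axis, yielding
+//
+//   conv2d(x, multiply(weight, expanded_scale))
+//
+// The backward direction mirrors this for multiply(conv2d(x, weight), scale),
+// folding the scale into weight along the output-channel axis.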
 
 /*!
  * \brief sorted array axis, can also be nullptr.
@@ -99,7 +104,7 @@ ValueType GetFunc(const OpMap<ValueType>& op_map,
 }
 
 /*!
- * \brief Preparation function for for pass scale forward.
+ * \brief Preparation function for pass scale forward.
  * \param call The call node.
  * \param out_scale_axes Possible scaling on axes of the output.
  * \return The result scaling on axes of the input.
@@ -144,7 +149,7 @@ using FForwardTransform = TypedPackedFunc<
 //----------------------------------------------
 // Generic Visitors for FScaleAxisForward
 //----------------------------------------------
-class FScaleAxisForwardPrep : private ExprVisitor {
+class ForwardPrep : private ExprVisitor {
  public:
   std::unordered_map<const Node*, AxesSet>
   Prepare(const Expr& body) {
@@ -255,12 +260,12 @@ class FScaleAxisForwardPrep : private ExprVisitor {
   }
 };
 
-class FScaleAxisForwardTransform : private ExprMutator {
+class ForwardTransformer : private ExprMutator {
  public:
   // Transform expression.
-  Expr Transform(Expr expr) {
+  Expr Fold(Expr expr) {
     expected_scale_axes_ =
-        FScaleAxisForwardPrep().Prepare(expr);
+        ForwardPrep().Prepare(expr);
     return this->Mutate(expr);
   }
 
@@ -346,13 +351,13 @@ Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
 }
 
 STuple ReluForwardTransform(const Call& ref_call,
-                              const AxesSet& expected_axes,
-                              const Array<STuple>& sargs) {
+                            const AxesSet& expected_axes,
+                            const Array<STuple>& sargs) {
   if (!sargs[0]->axes.defined()) return STuple();
   // return transformed conv2d
   auto rnode = make_node<STupleNode>();
   rnode->value = CallNode::make(
-      ref_call->op, {sargs[0]->value}, ref_call->attrs, {});
+      ref_call->op, {sargs[0]->value}, ref_call->attrs, ref_call->type_args);
   rnode->scale = sargs[0]->scale;
   rnode->axes = sargs[0]->axes;
   return STuple(rnode);
@@ -474,8 +479,6 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
   Layout weight_layout(param->weight_layout);
   int c_big_axis = data_layout.indexof('C');
   int c_small_axis = data_layout.indexof('c');
-  const auto* tdata = call->args[0]->type_as<TensorTypeNode>();
-  CHECK(tdata) << "require checked type";
 
   CHECK_GE(c_big_axis, 0);
   AxesSet data_axes = NullValue<AxesSet>();
@@ -486,8 +489,7 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
   //
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
-  bool is_depthwise_conv2d =
-      is_const_int(tdata->shape[c_big_axis], param->groups);
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
   if (weight_layout.indexof('i') < 0 &&
       c_small_axis < 0 &&
       (param->groups == 1 || is_depthwise_conv2d)) {
@@ -515,18 +517,24 @@ STuple Conv2DForwardTransform(const Call& ref_call,
   CHECK_EQ(weight_layout.indexof('i'), -1);
   CHECK(sdata->axes.size() == 1 &&
         c_big_axis == sdata->axes[0]->value);
+  int big_oc_axis = weight_layout.indexof('O');
   int big_ic_axis = weight_layout.indexof('I');
 
-  const auto* tdata = ref_call->args[0]->type_as<TensorTypeNode>();
   // Check it must be depthwise or full conv2d.
-  bool is_depthwise_conv2d =
-      is_const_int(tdata->shape[c_big_axis], param->groups);
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout);
   CHECK(param->groups == 1 || is_depthwise_conv2d);
+  Expr weight = sargs[1]->value;
 
   // match the oc_axis for depthwise conv2d, otherwise the ic_axis
-  Expr scale = ExpandBiasToMatchAxis(
-      sdata->scale, weight_layout.ndim(), {big_ic_axis});
-  Expr weight = Multiply(sargs[1]->value, scale);
+  if (is_depthwise_conv2d) {
+    Expr scale = ExpandBiasToMatchAxis(
+        sdata->scale, weight_layout.ndim(), {big_oc_axis});
+    weight = Multiply(weight, scale);
+  } else {
+    Expr scale = ExpandBiasToMatchAxis(
+        sdata->scale, weight_layout.ndim(), {big_ic_axis});
+    weight = Multiply(weight, scale);
+  }
   // return transformed conv2d
   auto rnode = make_node<STupleNode>();
   rnode->value = CallNode::make(
@@ -542,13 +550,416 @@ RELAY_REGISTER_OP("nn.conv2d")
 
 
 Expr ForwardFoldScaleAxis(Expr data) {
-  return FScaleAxisForwardTransform().Transform(data);
+  return ForwardTransformer().Fold(data);
 }
 
 // Expose the FoldScaleAxisForward
 TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
 .set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);
 
+//----------------------------------------
+// Implement backward transformations.
+//----------------------------------------
+class BackwardTransformer;
+
+/*!
+ * \brief Preparation function for pass scale backward.
+ * \param call The call node.
+ * \param in_scale_axes Allowed scaling on the axes of each input.
+ * \return The resulting scaling on the axes of the output.
+ */
+using FBackwardPrep = TypedPackedFunc<
+  AxesSet(const Call& call, const Array<AxesSet>& in_scale_axes)>;
+
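+/*!
+ * \brief Transformation function for pass scale backward.
+ * \param call The call node.
+ * \param axes The axes on which the output scaling is applied
+ *             (undefined if there is no pending scaling).
+ * \param scale The scale applied to the axes.
+ * \param transformer The backward transformer, used to recursively
+ *                    rewrite the arguments.
+ * \return The transformed expression.
+ */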
+using FBackwardTransform = TypedPackedFunc<
+  Expr(const Call& call,
+       const AxesSet& axes,
+       const Expr& scale,
+       const BackwardTransformer& transformer)>;
+
+//----------------------------------------------
+// Generic Visitors for FScaleAxisBackward
+//----------------------------------------------
+/*!
+ * \brief Get reference counter of each internal ExprNode in body.
+ * \param body The body expression.
+ * \return The reference count mapping.
+ */
+std::unordered_map<const Node*, size_t>
+GetExprRefCount(const Expr& body) {
+  class ExprRefCounter : private ExprVisitor {
+   public:
+    std::unordered_map<const Node*, size_t>
+    Get(const Expr& body) {
+      this->VisitExpr(body);
+      return std::move(this->visit_counter_);
+    }
+  };
+  return ExprRefCounter().Get(body);
+}
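+// For example (illustrative): with y = relu(x) and body = add(y, y),
+// the returned map is {add: 1, relu: 2, x: 1}. BackwardPrep uses a
+// count other than one to stop backward propagation through nodes
+// that are shared by multiple parents.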
+
+class BackwardPrep : private ExprVisitor {
+ public:
+  // The message on each node.
+  std::unordered_map<const Node*, AxesSet>
+  Prepare(const Expr& body) {
+    ref_counter_ = GetExprRefCount(body);
+    this->VisitExpr(body);
+    return std::move(message_);
+  }
+
+ private:
+  // The message on each node.
+  std::unordered_map<const Node*, AxesSet> message_;
+  // reference counter of an internal expr
+  std::unordered_map<const Node*, size_t> ref_counter_;
+  // Visit the expression.
+  void VisitExpr_(const CallNode* call) {
+    ExprVisitor::VisitExpr_(call);
+    static const auto& fprep =
+        Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
+    auto f = GetFunc(fprep, call->op);
+    if (f == nullptr) return;
+    auto rit = ref_counter_.find(call);
+    CHECK(rit != ref_counter_.end());
+    // We only allow propagation of scale backward
+    // if the expression is referred to by only a single parent.
+    if (rit->second != 1) return;
+    Array<AxesSet> in_axes;
+    for (Expr arg : call->args) {
+      auto it = message_.find(arg.get());
+      if (it != message_.end()) {
+        in_axes.push_back(it->second);
+      } else {
+        in_axes.push_back(NullValue<AxesSet>());
+      }
+    }
+    AxesSet out_axes = f(GetRef<Call>(call), in_axes);
+    if (out_axes.defined()) {
+      message_[call] = out_axes;
+    }
+  }
+};
+
+class BackwardTransformerNode :
+      public Node,
+      private ExprMutator {
+ public:
+  // Run the backward transform.
+  Expr Fold(Expr expr) {
+    expected_scale_axes_ = BackwardPrep().Prepare(expr);
+    return this->Mutate(expr);
+  }
+  /*!
+   * \brief Transform the expr to consider the scaling.
+   *
+   * \param expr The input expression.
+   * \param axes The axes to scale.
+   * \param scale The scale applied to the axes.
+   * \return The result of transformation.
+   */
+  Expr Transform(const Expr& expr, AxesSet axes, Expr scale) {
+    // NOTE: the result of Transform is not memoized.
+    // However, under the current rules, Transform will
+    // only be called on an expr that is referred to once.
+    if (const CallNode* call_node = expr.as<CallNode>()) {
+      return Transform(call_node, axes, scale);
+    } else {
+      CHECK(!axes.defined()) << "outstanding scale";
+      return ExprMutator::VisitExpr(expr);
+    }
+  }
+  /*!
+   * \brief Normal way of mutating call node.
+   * \param call_node The call node to be mutated.
+   * \return the result of the call Mutation.
+   */
+  Expr NormalCallTransform(const CallNode* call_node) {
+    return ExprMutator::VisitExpr_(call_node);
+  }
+  /*!
+   * \brief Get the expected axes on expr.
+   * \param expr The expression.
+   * \return The expected axes.
+   */
+  AxesSet GetExpectedAxes(const Expr& expr) const {
+    auto it = expected_scale_axes_.find(expr.get());
+    if (it != expected_scale_axes_.end()) return it->second;
+    return NullValue<AxesSet>();
+  }
+
+  // solver is not serializable.
+  void VisitAttrs(tvm::AttrVisitor* v) final {}
+
+  static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
+  TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);
+
+ private:
+  // Valid axes on each node.
+  std::unordered_map<const Node*, AxesSet> expected_scale_axes_;
+  // Override mutation of call.
+  Expr VisitExpr_(const CallNode* call_node) final {
+    return Transform(call_node, NullValue<AxesSet>(), NullValue<Expr>());
+  }
+  // Transform of CallNode.
+  Expr Transform(const CallNode* call_node, AxesSet axes, Expr scale);
+};
+
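+// Reference class to BackwardTransformerNode, so the transformer can be
+// passed by value into the TypedPackedFunc-based per-operator callbacks.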
+class BackwardTransformer : public NodeRef {
+ public:
+  BackwardTransformer() {}
+  explicit BackwardTransformer(
+      ::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {
+  }
+  BackwardTransformerNode* operator->() const {
+    return static_cast<BackwardTransformerNode*>(node_.get());
+  }
+  using ContainerType = BackwardTransformerNode;
+};
+
+Expr BackwardTransformerNode::Transform(
+    const CallNode* call_node, AxesSet axes, Expr scale) {
+  static const auto& ftransform =
+      Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
+  auto f = GetFunc(ftransform, call_node->op);
+  if (f != nullptr) {
+    return f(GetRef<Call>(call_node),
+             axes,
+             scale,
+             GetRef<BackwardTransformer>(this));
+  } else {
+    CHECK(!axes.defined()) << "outstanding scale";
+    return NormalCallTransform(call_node);
+  }
+}
+
+
+//----------------------------------------------
+// Per operator defs for FScaleAxisBackward
+//----------------------------------------------
+
+// Intermediate operators
+AxesSet ReluBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
+  return in_axes[0];
+}
+
+Expr ReluBackwardTransform(const Call& call,
+                           const AxesSet& axes,
+                           const Expr& scale,
+                           const BackwardTransformer& transformer) {
+  if (!axes.defined()) {
+    return transformer->NormalCallTransform(call.operator->());
+  }
+  Expr input = transformer->Transform(
+      call->args[0], axes, scale);
+  return CallNode::make(call->op, {input}, call->attrs, call->type_args);
+}
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
+
+RELAY_REGISTER_OP("nn.relu")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
+
+RELAY_REGISTER_OP("nn.leaky_relu")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", ReluBackwardPrep);
+
+RELAY_REGISTER_OP("nn.leaky_relu")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", ReluBackwardTransform);
+
+// AddSub
+AxesSet AddSubBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
+  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+  AttrsEqual equal;
+  if (in_axes[0].defined() &&
+      MatchBroadcastToLeftAxes(tlhs, trhs, in_axes[0])) {
+    return in_axes[0];
+  } else if (in_axes[1].defined() &&
+             MatchBroadcastToLeftAxes(trhs, tlhs, in_axes[1])) {
+    return in_axes[1];
+  } else if (in_axes[0].defined() &&
+             in_axes[1].defined() &&
+             equal(in_axes[0], in_axes[1]) &&
+             equal(tlhs->shape, trhs->shape)) {
+    // add of two elements.
+    return in_axes[0];
+  } else {
+    return NullValue<AxesSet>();
+  }
+}
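+// For example (illustrative): for add(y, bias) where y has shape
+// (n, c, h, w) with in_axes[0] = {C}, and bias has shape (c, 1, 1)
+// broadcasting only along {C}, the scaling demand on {C} can be
+// passed through the add.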
+
+Expr AddSubBackwardTransform(const Call& call,
+                             const AxesSet& axes,
+                             const Expr& scale,
+                             const BackwardTransformer& transformer) {
+  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+  if (!axes.defined()) {
+    return transformer->NormalCallTransform(call.operator->());
+  }
+  AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
+  AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
+  AttrsEqual equal;
+
+  if (lhs_axes.defined() && rhs_axes.defined()) {
+    CHECK(equal(lhs_axes, rhs_axes));
+    CHECK(equal(axes, lhs_axes));
+    Expr lhs = transformer->Transform(call->args[0], axes, scale);
+    Expr rhs = transformer->Transform(call->args[1], axes, scale);
+    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
+  } else if (lhs_axes.defined()) {
+    CHECK(equal(axes, lhs_axes));
+    Expr lhs = transformer->Transform(call->args[0], axes, scale);
+    Expr rhs = transformer->Transform(
+        call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
+    Expr rhs_scale = ExpandBiasToMatchAxis(
+        scale, tlhs->shape.size(), axes);
+    rhs = Multiply(rhs, rhs_scale);
+    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
+  } else if (rhs_axes.defined()) {
+    CHECK(equal(axes, rhs_axes));
+    Expr lhs = transformer->Transform(
+        call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
+    Expr rhs = transformer->Transform(call->args[1], axes, scale);
+    Expr lhs_scale = ExpandBiasToMatchAxis(
+        scale, trhs->shape.size(), axes);
+    lhs = Multiply(lhs, lhs_scale);
+    return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
+  } else {
+    LOG(FATAL) << "outstanding scale";
+    return Expr();
+  }
+}
+
+RELAY_REGISTER_OP("add")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
+
+RELAY_REGISTER_OP("add")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
+
+RELAY_REGISTER_OP("subtract")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", AddSubBackwardPrep);
+
+RELAY_REGISTER_OP("subtract")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
+
+// Producer operators
+// Multiply produces the scale-axis pair.
+Expr MultiplyBackwardTransform(const Call& call,
+                               const AxesSet& axes,
+                               const Expr& scale,
+                               const BackwardTransformer& transformer) {
+  CHECK(!axes.defined()) << "outstanding scale";
+  const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+  const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+  AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]);
+  AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]);
+  if (lhs_axes.defined()) {
+    // NOTE: we won't recursively mutate the scale part,
+    // since there is no scaling opportunity within it.
+    Expr rhs = call->args[1];
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
+      return transformer->Transform(call->args[0], lhs_axes, rhs);
+    }
+  } else if (rhs_axes.defined()) {
+    Expr lhs = call->args[0];
+    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
+      return transformer->Transform(call->args[1], rhs_axes, lhs);
+    }
+  }
+  return transformer->NormalCallTransform(call.operator->());
+}
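+// For example (illustrative): multiply(relu(conv2d(x, w)), scale) with
+// scale of shape (c, 1, 1) becomes relu(conv2d(x, w')), where w' folds
+// scale into w along the output-channel axis and the multiply is removed.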
+
+RELAY_REGISTER_OP("multiply")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", MultiplyBackwardTransform);
+
+// Consumer operators
+// Conv2D send out requirement of axis folding.
+AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
+  const auto* param = call->attrs.as<Conv2DAttrs>();
+  CHECK(param != nullptr);
+  Layout out_layout(param->out_layout);
+  if (!out_layout.defined()) {
+    out_layout = Layout(param->data_layout);
+  }
+  Layout weight_layout(param->weight_layout);
+  int c_big_axis = out_layout.indexof('C');
+  int c_small_axis = out_layout.indexof('c');
+
+  CHECK_GE(c_big_axis, 0);
+  // For now, we only support simple pattern (no folded weight/data)
+  // More general layout can be supported under the current framework.
+  // By using a unified layout transformation.
+  // We only need to change the Prep and Mutate function.
+  //
+  // only handle depthwise or full conv2d.
+  // TODO(tvm-team) handle grouped conv by reshape + bcast
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
+  if (weight_layout.indexof('o') < 0 &&
+      weight_layout.indexof('i') < 0 &&
+      c_small_axis < 0 &&
+      (param->groups == 1 || is_depthwise_conv2d)) {
+    return {c_big_axis};
+  } else {
+    return NullValue<AxesSet>();
+  }
+}
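+// For example (illustrative): with an NCHW output layout and an OIHW
+// weight layout, c_big_axis == 1, so a conv2d with groups == 1 reports
+// that it can absorb scaling on output axes {1}.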
+
+// Conv2D consumes the scale axis during transformation.
+Expr Conv2DBackwardTransform(const Call& call,
+                             const AxesSet& axes,
+                             const Expr& scale,
+                             const BackwardTransformer& transformer) {
+  if (!axes.defined()) {
+    return transformer->NormalCallTransform(call.operator->());
+  }
+  const auto* param = call->attrs.as<Conv2DAttrs>();
+  CHECK(param != nullptr);
+  Layout out_layout(param->out_layout);
+  if (!out_layout.defined()) {
+    out_layout = Layout(param->data_layout);
+  }
+  Layout weight_layout(param->weight_layout);
+  int c_big_axis = out_layout.indexof('C');
+  CHECK_GE(c_big_axis, 0);
+  // For now, we only support simple pattern (no folded weight/data)
+  // TODO(tvm-team) support general data layout
+  CHECK_EQ(weight_layout.indexof('o'), -1);
+  CHECK_EQ(weight_layout.indexof('i'), -1);
+  CHECK(axes.size() == 1 &&
+        c_big_axis == axes[0]->value);
+
+  int big_oc_axis = weight_layout.indexof('O');
+  // Check it must be depthwise or full conv2d.
+  bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
+  CHECK(param->groups == 1 || is_depthwise_conv2d);
+
+  Expr data = transformer->Transform(
+      call->args[0], NullValue<AxesSet>(), NullValue<Expr>());
+  Expr weight = transformer->Transform(
+      call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
+  // fold the scale into the weight along the output channel axis.
+  Expr wscale = ExpandBiasToMatchAxis(
+      scale, weight_layout.ndim(), {big_oc_axis});
+  weight = Multiply(weight, wscale);
+  return CallNode::make(
+      call->op, {data, weight}, call->attrs, call->type_args);
+}
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FBackwardPrep>("FScaleAxisBackwardPrep", Conv2DBackwardPrep);
+
+RELAY_REGISTER_OP("nn.conv2d")
+.set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);
+
+Expr BackwardFoldScaleAxis(Expr data) {
+  return make_node<BackwardTransformerNode>()->Fold(data);
+}
+
+TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis")
+.set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);
+
 }  // namespace fold_scale_axis
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index a395e74cd..a41e6c35b 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -11,6 +11,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/attrs/transform.h>
+#include "../op/nn/layout.h"
 
 namespace tvm {
 namespace relay {
@@ -100,11 +101,31 @@ inline Expr ExpandBiasToMatchAxis(Expr bias,
   return bias;
 }
 
+/*!
+ * \brief Check if the call is depthwise conv2d.
+ *
+ * \param call The conv2d call.
+ * \param param The conv2d attributes.
+ * \param weight_layout The layout of the conv2d weight.
+ * \return Whether it is depthwise_conv2d.
+ */
+inline bool IsDepthwiseConv2D(const Call& call,
+                              const Conv2DAttrs* param,
+                              const Layout& weight_layout) {
+  static const Layout kOIHW("OIHW");
+  auto wshape = ConvertLayout(
+      call->args[1]->type_as<TensorTypeNode>()->shape,
+      weight_layout, kOIHW);
+  return is_const_int(wshape[0], param->groups) &&
+      is_const_int(wshape[1], 1);
+}
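+// For example (illustrative): a conv2d with groups == 32 and an OIHW
+// weight of shape (32, 1, 3, 3) gives wshape[0] == groups and
+// wshape[1] == 1, so the call is classified as depthwise.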
+
+
 inline Expr Multiply(Expr lhs, Expr rhs) {
   static const Op& op = Op::Get("multiply");
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
 }
 
+
 inline Expr Divide(Expr lhs, Expr rhs) {
   static const Op& op = Op::Get("divide");
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
@@ -116,8 +137,6 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
 }
 
-
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index 7ce3b35ef..1b57bdce0 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -62,14 +62,14 @@ def test_fold_fwd_dual_path():
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         y2 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
@@ -85,7 +85,7 @@ def test_fold_fwd_dual_path():
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         y2 = relay.nn.conv2d(x,
@@ -93,7 +93,7 @@ def test_fold_fwd_dual_path():
                              channels=channels,
                              kernel_size=(3, 3),
                              data_layout="NHWC",
-                             weight_layout="HWOI",
+                             weight_layout="HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
@@ -147,7 +147,176 @@ def test_fold_fwd_fail():
     check((2, 11, 10, 4), 4)
 
 
+def test_fold_bwd_simple():
+    """Simple testcase."""
+    def before(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        y = relay.nn.conv2d(x, conv_weight,
+                            channels=channels,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        y = relay.multiply(y, out_scale)
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, out_bias, out_scale, channels):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+        conv_weight = relay.multiply(
+            conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+
+        y = relay.nn.conv2d(x, conv_weight,
+                            channels=channels,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        out_bias = relay.multiply(out_bias,
+                                  relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+        y = relay.add(y, out_bias)
+        y = relay.nn.relu(y)
+        return relay.Function(args, y)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        out_scale = relay.var("out_scale", shape=(channels,))
+
+        y1 = before(x, weight, out_bias, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        y1_expected = expected(x, weight, out_bias, out_scale, channels)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10), 8)
+
+
+def test_fold_bwd_dual_path():
+    """Dual path testcase."""
+    def before(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        y1 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.nn.relu(y1)
+        y2 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.nn.relu(y2)
+        y = relay.add(y1, y2)
+        y = relay.multiply(y, out_scale)
+        return relay.Function(args, y)
+
+    def expected(x, conv_weight, out_bias, out_scale, channels):
+        # use a fixed order of args so alpha equal check can pass
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+        def fold_conv_weight():
+            return relay.multiply(
+                conv_weight,
+                relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+        y1 = relay.nn.conv2d(x, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.nn.relu(y1)
+        y2 = relay.nn.conv2d(x, fold_conv_weight(),
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.nn.relu(y2)
+        y = relay.add(y1, y2)
+        return relay.Function(args, y)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        out_scale = relay.var("out_scale", shape=(channels,))
+
+        y1 = before(x, weight, out_bias, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        type_dict = {x.name_hint: x.checked_type for x in y1.params}
+        weight = relay.var("weight", type_dict["weight"])
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        y1_expected = expected(x, weight, out_bias, out_scale, channels)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
+
+    check((2, 4, 10, 10), 8)
+
+
+def test_fold_bwd_fail():
+    """Dual path testcase."""
+    def fail1(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        y1 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.nn.relu(y1)
+        y2 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1),
+                             out_layout="CNHW")
+        # fold will fail because the channel axes of the two paths
+        # differ from each other (NCHW vs CNHW).
+        y2 = relay.nn.relu(y2)
+        y = relay.add(y1, y2)
+        y = relay.multiply(y, out_scale)
+        return relay.Function(args, y)
+
+    def fail2(x, conv_weight, out_bias, out_scale, channels):
+        args = [x, conv_weight, out_bias, out_scale]
+        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        y1 = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y2 = relay.nn.relu(y1)
+        # fold will fail because y1 is also referred to by y2
+        y1 = relay.multiply(y1, out_scale)
+        y = relay.add(y1, y2)
+        return relay.Function(args, y)
+
+    def check(shape, channels, fbefore):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        out_bias = relay.var("out_bias", shape=(channels,))
+        out_scale = relay.var("out_scale", shape=(channels,))
+        y1 = fbefore(x, weight, out_bias, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        assert relay.ir_pass.alpha_equal(y1_folded, y1)
+
+    check((4, 4, 10, 10), 4, fail1)
+    check((4, 4, 10, 10), 4, fail2)
+
+
 if __name__ == "__main__":
     test_fold_fwd_simple()
     test_fold_fwd_dual_path()
     test_fold_fwd_fail()
+    test_fold_bwd_simple()
+    test_fold_bwd_dual_path()
+    test_fold_bwd_fail()
-- 
GitLab