From 93949456eedee33b1a39f59a30189bbce6732247 Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Wed, 5 Dec 2018 04:58:16 +0800
Subject: [PATCH] [RELAY][PASS] Check Positiveness in FoldScaleAxis (#2220)

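Folding a scale axis through relu is only valid when the scale is known
to be non-negative, because relu(x * s) == relu(x) * s only holds for
s >= 0.  This patch adds an IsPositiveConstant helper (all elements of a
constant tensor >= 0) and uses it to gate both directions: the forward
pass only lets scale axes pass relu when the scale is a non-negative
constant, and the backward pass only folds a multiply past a relu on the
path to conv2d under the same condition.  FoldConstant is moved ahead of
FoldScaleAxis (and run once more after the forward fold) so the scale is
a constant by the time it is inspected.

A minimal numpy sketch of the identity being protected (illustrative
only, not part of the patch):

    import numpy as np
    x = np.array([-2.0, 3.0], dtype="float32")
    s = -1.0                          # negative scale
    print(np.maximum(x * s, 0.0))     # [2. 0.]
    print(np.maximum(x, 0.0) * s)     # [-0. -3.]  -> results differ, so folding would be wrong
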
---
 python/tvm/relay/build_module.py              |   5 +-
 src/relay/pass/fold_scale_axis.cc             |  61 +++++++++-
 src/relay/pass/pattern_util.h                 |  51 ++++++++
 .../python/relay/test_pass_fold_scale_axis.py | 112 +++++++++++++-----
 4 files changed, 193 insertions(+), 36 deletions(-)

diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 7af22431a..5b05bc445 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -150,13 +150,14 @@ def optimize(func, params=None):
         func = ir_pass.infer_type(func)
         func = ir_pass.combine_parallel_conv2d(func)
 
+    if cfg.pass_enabled("FoldConstant"):
+        func = ir_pass.fold_constant(func)
+
     if cfg.pass_enabled("FoldScaleAxis"):
         func = ir_pass.infer_type(func)
         func = ir_pass.backward_fold_scale_axis(func)
         func = ir_pass.infer_type(func)
         func = ir_pass.forward_fold_scale_axis(func)
-
-    if cfg.pass_enabled("FoldConstant"):
         func = ir_pass.fold_constant(func)
 
     if cfg.pass_enabled("AlterOpLayout"):
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 1cd6606bd..9e9dd0604 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor {
 // Per operator defs for FScaleAxisForward
 //----------------------------------------------
 
+// Helper: find the scale expression that would be folded forward into `expr` along the `out` axes.
+Expr GetForwardScale(const Expr& expr, AxesSet out) {
+  static const Op& multiply = Op::Get("multiply");
+  static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
+
+  const CallNode* call = expr.as<CallNode>();
+  if (!call) return NullValue<Expr>();
+  auto f = fprep.get(call->op, nullptr);
+
+  if (call->op.same_as(multiply)) {
+    const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+    const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
+      return call->args[1];
+    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
+      return call->args[0];
+    } else {
+      return NullValue<Expr>();
+    }
+  } else if (f != nullptr) {
+    Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
+    for (size_t i = 0; i < call->args.size(); i++) {
+      auto scale = GetForwardScale(call->args[i], in_axes[i]);
+      if (scale.defined()) {
+        return scale;
+      }
+    }
+  }
+  return NullValue<Expr>();
+}
+
 // Intermediate operators
 Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
-  return {out};
+  Expr scale = GetForwardScale(call->args[0], out);
+  if (IsPositiveConstant(scale)) {
+    return {out};
+  }
+  return {NullValue<AxesSet>()};
 }
 
 Expr ReluForwardRewrite(const Call& ref_call,
@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract")
 RELAY_REGISTER_OP("subtract")
 .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
 
+// Check whether a relu appears on the backward path between the multiply and the conv2d.
+bool FindBackwardRelu(const Expr& expr) {
+  const CallNode* call = expr.as<CallNode>();
+  static const Op& conv2d = Op::Get("nn.conv2d");
+  static const Op& relu = Op::Get("nn.relu");
+
+  if (!call) return false;
+  if (call->op.same_as(relu)) return true;
+  if (call->op.same_as(conv2d)) return false;
+
+  for (size_t i = 0; i < call->args.size(); i++) {
+    if (FindBackwardRelu(call->args[i])) return true;
+  }
+  return false;
+}
+
 // Producer operators
 // Multiply produces the scale-axis pair.
 Expr MultiplyBackwardTransform(const Call& call,
@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call,
     // NOTE we won't recursively call mutating on scale part.
     // since there  won't be scale chance within scale part.
     Expr rhs = call->args[1];
-    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
+        (!FindBackwardRelu(call->args[0]) ||
+         IsPositiveConstant(call->args[1]))) {
       return transformer->Transform(call->args[0], lhs_axes, rhs);
     }
   } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
     Expr lhs = call->args[0];
-    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
+    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
+        (!FindBackwardRelu(call->args[1]) ||
+         IsPositiveConstant(call->args[0]))) {
       return transformer->Transform(call->args[1], rhs_axes, lhs);
     }
   }
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index e6e8415bd..5d76efd01 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
 
+
+template <typename T>
+bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
+  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
+  CHECK(tensor->strides == nullptr);
+  CHECK_EQ(tensor->byte_offset, 0);
+  const T* data = static_cast<const T*>(tensor->data);
+  int64_t num_elems = 1;
+  for (int i = 0; i < tensor->ndim; ++i) {
+    num_elems *= tensor->shape[i];
+  }
+
+  for (int64_t i = 0; i < num_elems; i++) {
+    if (*data < value) {
+      return false;
+    }
+    data++;
+  }
+  return true;
+}
+
+
+inline bool IsPositiveConstant(const Expr& expr) {
+  const auto* constant = expr.as<ConstantNode>();
+  if (!constant) return false;
+  const auto& tensor = constant->data;
+  const auto& dtype = tensor->dtype;
+
+  if (dtype.lanes != 1) {
+    // vector (lanes > 1) constants are not supported; fall through to the warning below
+  } else if (dtype.code == kDLFloat && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<float>(tensor, 0);
+  } else if (dtype.code == kDLFloat && dtype.bits == 64) {
+    return IsNDArrayAllGreaterEqual<double>(tensor, 0);
+  } else if (dtype.code == kDLInt && dtype.bits == 8) {
+    return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
+  } else if (dtype.code == kDLInt && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
+  } else if (dtype.code == kDLUInt && dtype.bits == 8) {
+    return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
+  } else if (dtype.code == kDLUInt && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
+  }
+
+  LOG(WARNING) << "Unsupported data type (code = " << static_cast<int>(dtype.code)
+               << ", bits = " << static_cast<int>(dtype.bits)
+               << ", lanes = " << dtype.lanes << ")";
+  return false;
+}
+
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index e6e008f80..f42aa7b7b 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -1,11 +1,11 @@
 from tvm import relay
+import numpy as np
 
 
 def test_fold_fwd_simple():
     """Simple testcase."""
     def before(x, conv_weight, in_bias, in_scale, channels):
-        args = [x, conv_weight, in_bias, in_scale]
-        in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, in_bias]
         in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
         x = relay.multiply(x, in_scale)
         x = relay.nn.relu(x)
@@ -18,8 +18,7 @@ def test_fold_fwd_simple():
 
     def expected(x, conv_weight, in_bias, in_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, in_bias, in_scale]
-        in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, in_bias]
         in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
         squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
         x = relay.nn.relu(x)
@@ -38,7 +37,7 @@ def test_fold_fwd_simple():
         in_channels = shape[1]
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.var("in_scale", shape=(in_channels,))
+        in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype('float32'))
 
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -56,7 +55,7 @@ def test_fold_fwd_simple():
 def test_fold_fwd_dual_path():
     """scale axis being consumed by two consumers"""
     def before(x, conv_weight, in_bias, in_scale, channels):
-        args = [x, conv_weight, in_bias, in_scale]
+        args = [x, conv_weight, in_bias]
         x = relay.multiply(in_scale, x)
         x = relay.nn.relu(x)
         x = relay.subtract(x, in_bias)
@@ -78,7 +77,7 @@ def test_fold_fwd_dual_path():
         return relay.Function(args, z)
 
     def expected(x, conv_weight, in_bias, in_scale, channels):
-        args = [x, conv_weight, in_bias, in_scale]
+        args = [x, conv_weight, in_bias]
         x = relay.nn.relu(x)
         in_bias = relay.divide(in_bias, in_scale)
         x = relay.subtract(x, in_bias)
@@ -108,7 +107,7 @@ def test_fold_fwd_dual_path():
         assert in_channels == channels
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.var("in_scale", shape=(in_channels,))
+        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
@@ -142,7 +141,7 @@ def test_fold_fwd_fail():
         assert in_channels == channels
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.var("in_scale", shape=(in_channels,))
+        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
@@ -151,11 +150,42 @@ def test_fold_fwd_fail():
     check((2, 11, 10, 4), 4)
 
 
+def test_fold_fwd_relu_fail():
+    """testcase where we canont fold because scale can not pass relu"""
+    def before(x, conv_weight, in_bias, in_scale, channels):
+        x = relay.multiply(x, in_scale)
+        xx = relay.nn.relu(x)
+        y1 = relay.nn.conv2d(xx, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NHWC",
+                             padding=(1, 1))
+        z = relay.add(y1, x)
+        return relay.Function(relay.ir_pass.free_vars(z), z)
+
+    def check(shape, channels, in_scale):
+        x =  relay.var("x", shape=shape)
+        in_channels = shape[-1]
+        # test depthwise
+        assert in_channels == channels
+        weight = relay.var("weight")
+        in_bias = relay.var("in_bias", shape=(in_channels,))
+        # in_scale is supplied by the caller (a var or a negative constant)
+        y1 = before(x, weight, in_bias, in_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
+        assert relay.ir_pass.alpha_equal(y1, y1_folded)
+
+    in_scale = relay.var("in_scale", shape=(4,))
+    check((2, 11, 10, 4), 4, in_scale)
+    in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0).astype("float32"))
+    check((2, 11, 10, 4), 4, in_scale)
+
+
 def test_fold_bwd_simple():
     """Simple testcase."""
     def before(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
@@ -168,8 +198,7 @@ def test_fold_bwd_simple():
 
     def expected(x, conv_weight, out_bias, out_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
         conv_weight = relay.multiply(
@@ -190,7 +219,7 @@ def test_fold_bwd_simple():
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
 
         y1 = before(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -208,9 +237,7 @@ def test_fold_bwd_simple():
 def test_fold_bwd_dual_path():
     """Dual path testcase."""
     def before(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
-        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
@@ -227,8 +254,7 @@ def test_fold_bwd_dual_path():
 
     def expected(x, conv_weight, out_bias, out_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
         def fold_conv_weight():
@@ -253,7 +279,7 @@ def test_fold_bwd_dual_path():
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
 
         y1 = before(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -270,8 +296,7 @@ def test_fold_bwd_dual_path():
 
 def test_fold_bwd_dual_consumer():
     def before(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         y0 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
@@ -298,8 +323,7 @@ def test_fold_bwd_dual_consumer():
 
     def expected(x, conv_weight, out_bias, out_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         def fold_conv_weight():
             squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
             return  relay.multiply(
@@ -328,7 +352,7 @@ def test_fold_bwd_dual_consumer():
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32"))
 
         y1 = before(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -346,8 +370,7 @@ def test_fold_bwd_dual_consumer():
 def test_fold_bwd_fail():
     """Dual path testcase."""
     def fail1(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
@@ -367,8 +390,7 @@ def test_fold_bwd_fail():
         return relay.Function(args, y)
 
     def fail2(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
@@ -380,13 +402,12 @@ def test_fold_bwd_fail():
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
-
     def check(shape, channels, fbefore):
         x =  relay.var("x", shape=shape)
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
         y1 = fbefore(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
@@ -396,11 +417,40 @@ def test_fold_bwd_fail():
     check((4, 4, 10, 10), 4, fail2)
 
 
+def test_fold_bwd_relu_fail():
+    """testcase where we canont fold because scale can not pass relu"""
+    def before(x, conv_weight, out_scale, channels):
+        y = relay.nn.conv2d(x, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NCHW",
+                             padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.multiply(y, out_scale)
+        return relay.Function(relay.ir_pass.free_vars(y), y)
+
+    def check(shape, channels, out_scale):
+        x =  relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        y1 = before(x, weight, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        assert relay.ir_pass.alpha_equal(y1, y1_folded)
+
+    out_scale = relay.var("out_scale", shape=(4, 1, 1))
+    check((4, 4, 10, 10), 4, out_scale)
+    out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0).astype("float32"))
+    check((4, 4, 10, 10), 4, out_scale)
+
+
 if __name__ == "__main__":
     test_fold_fwd_simple()
     test_fold_fwd_dual_path()
     test_fold_fwd_fail()
+    test_fold_fwd_relu_fail()
     test_fold_bwd_simple()
     test_fold_bwd_dual_path()
     test_fold_bwd_dual_consumer()
     test_fold_bwd_fail()
+    test_fold_bwd_relu_fail()
-- 
GitLab