From 6cd5a8f991eb8ba3cdd27a063067af733aecc9a2 Mon Sep 17 00:00:00 2001
From: masahi <masahi129@gmail.com>
Date: Fri, 17 Aug 2018 11:37:56 +0900
Subject: [PATCH] [NNVM] Bug fix: prevent fusing convolution with injective op 
 (#1608)
 (#1608)

---
 nnvm/src/compiler/graph_fuse.cc              | 31 +++++++++++++++++-
 nnvm/tests/python/compiler/test_op_fusion.py | 34 ++++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index 52a8ae44f..f65312be1 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -63,12 +63,16 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
       // Check if we can fuse to the master.
       int chosen_master = -1;
       bool ewise = inode.source->num_outputs() == 1;
+      bool mark_as_injective = false;
       for (const auto& e : inode.inputs) {
         if (fuse_vec[e.node_id] == FuseRule::kUknown) {
           TOpPattern ipt = pattern_vec[e.node_id];
           if (ipt != kElemWise) ewise = false;
-          if (ipt <= kInjective) {
+          if (ipt <= kBroadcast) {
+            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
+          } else if (ipt == kInjective) {
             fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
+            mark_as_injective = true;
           } else if (ipt == kOutEWiseFusable &&
                      chosen_master == -1 &&
                      shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
@@ -87,6 +91,8 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
       master_vec[nid] = chosen_master;
       if (chosen_master != -1) {
         pt = kOutEWiseFusable;
+      } else if (mark_as_injective) {
+        pt = kInjective;
       } else {
         pt = ewise ? kElemWise : kBroadcast;
       }
@@ -135,8 +141,31 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
     if (group_vec[nid] == -1) {
       group_vec[nid] = nid;
     }
+
+    // Check if injective op and out_ewise_fusable op (e.g. conv2d) are in the same group.
+    bool parent_out_ewise = false;
+    bool parent_injective = false;
+    for (const auto& e : inode.inputs) {
+      TOpPattern pt = pattern_vec[e.node_id];
+      if (pt == kOutEWiseFusable) {
+        parent_out_ewise = true;
+      } else if (pt == kInjective) {
+        parent_injective = true;
+      }
+    }
+    // Change the master node from out_ewise_fusable op to itself
+    if (parent_injective && parent_out_ewise) master_vec[nid] = nid;
+
     // Propagate the group id.
     for (const auto& e : inode.inputs) {
+      TOpPattern pt = pattern_vec[e.node_id];
+      if (parent_out_ewise && parent_injective) {
+        if (pt == kOutEWiseFusable) {
+          continue;  // Do not fuse out_ewise_fusable op
+        } else if (pt == kInjective) {
+          master_vec[e.node_id] = nid;
+        }
+      }
       if (fuse_vec[e.node_id] == FuseRule::kFuseToMaster) {
         CHECK(group_vec[e.node_id] == -1||
               group_vec[e.node_id] == group_vec[nid]);
diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index 8d05ae02c..5f4da3865 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -77,6 +77,39 @@ def test_injective_reduce_injective():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
+def test_injective_conv2d():
+    channels = 16
+    data = sym.Variable(name="data")
+    pool = sym.global_avg_pool2d(data=data)
+    weight = sym.reshape(pool, shape=[1, channels, 1, 1])
+    residual = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1),
+                          layout="NCHW", kernel_layout="OIHW", use_bias=False, name="conv")
+    net = weight * data + residual
+    size = 56
+    dtype="float32"
+    dshape = (1, channels, size, size)
+    kshape = (channels, channels, 3, 3)
+    oshape = dshape
+    shape_dict = {"data": dshape}
+
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(net, target, shape_dict)
+        # data, global_avg_pool, conv weight, conv op, fused elemwise add
+        assert graph.index.num_nodes == 5
+
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(data=data, conv_weight=kernel)
+        # get output
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        residual = topi.testing.conv2d_nchw_python(
+            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        weight = np.mean(data.asnumpy(), axis=(2, 3))
+        c_np = weight[:, :, np.newaxis, np.newaxis] * data.asnumpy() + residual
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
+
 def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
     with nnvm.compiler.build_config(opt_level=opt_level):
         graph, lib, params = nnvm.compiler.build(sym, target, shape={"data":data.shape}, params=params)
@@ -123,3 +156,4 @@ if __name__ == "__main__":
     test_ewise_injective()
     test_conv_ewise_injective()
     test_fuse_conv2d_elu()
+    test_injective_conv2d()
-- 
GitLab