From 56c50d2d072d976b058ba7e29a276e4e9b1e8839 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Wed, 22 Aug 2018 20:19:43 -0700
Subject: [PATCH] trigger ci (#1620)

---
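Notes: despite the "trigger ci" subject, this patch rewrites the x86 dense
schedule. The hand-rolled traverse() deleted below, which inlined broadcast
ops and tracked visited ops in a scheduled_ops list, is replaced by the
shared helper topi.util.traverse_inline. A minimal sketch of roughly what
that helper does, assuming the 2018-era TVM/topi APIs (it inlines injective
ops, a superset of broadcast ops, and invokes the callback on every op it
visits):

    import tvm
    from .. import tag

    def traverse_inline(s, final_op, callback):
        """Inline injective ops (except outputs), then invoke callback
        on every op reachable from final_op."""
        visited = set()

        def _traverse(op):
            if op in visited:
                return
            visited.add(op)
            if tag.is_injective(op.tag):
                if op not in s.outputs:
                    s[op].compute_inline()
                for t in op.input_tensors:
                    if isinstance(t.op, tvm.tensor.ComputeOp):
                        _traverse(t.op)
            callback(op)

        _traverse(final_op)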
 topi/python/topi/x86/nn.py | 43 +++++++++++++++++---------------------
 1 file changed, 19 insertions(+), 24 deletions(-)
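The new _callback tiles the output 1x4, splits the reduction axis by 4,
unrolls and vectorizes the innermost axes, then fuses the two outer loops
into a single parallel loop at which the write cache is computed. Note the
write cache is now created only when the dense op is itself an output;
otherwise the dense stage is scheduled in place. A quick way to inspect
the resulting loop nest (a sketch, assuming 2018-era TVM/topi APIs; the
shapes are arbitrary):

    import tvm
    import topi

    data = tvm.placeholder((16, 64), name='data')
    weight = tvm.placeholder((32, 64), name='weight')

    with tvm.target.create('llvm'):
        dense = topi.nn.dense(data, weight)
        s = topi.generic.schedule_dense([dense])
        # simple_mode shows the lowered loop nest: one fused, parallel
        # outer loop with unrolled/vectorized inner loops.
        print(tvm.lower(s, [data, weight, dense], simple_mode=True))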

diff --git a/topi/python/topi/x86/nn.py b/topi/python/topi/x86/nn.py
index 03e07222c..6802d4c01 100644
--- a/topi/python/topi/x86/nn.py
+++ b/topi/python/topi/x86/nn.py
@@ -2,8 +2,9 @@
 """x86 nn operators"""
 from __future__ import absolute_import as _abs
 import tvm
+
 from .. import generic
-from .. import tag
+from ..util import traverse_inline
 
 @generic.schedule_softmax.register(["cpu"])
 def schedule_softmax(outs):
@@ -53,44 +54,38 @@ def schedule_dense(outs):
 
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-
-    def traverse(op):
-        """Traverse operators from computation graph"""
-        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
 
+    def _callback(op):
         if 'dense' in op.tag:
-            C = op.output(0)
-            x, y = C.op.axis
+            output = outs[0]
+            dense = op.output(0)
 
             # Write cache for blocks
-            CC = s.cache_write(C, 'global')
+            if dense.op in s.outputs:
+                CC = s.cache_write(dense, 'local')
+            else:
+                CC = dense
 
             # Tile
             bnx = 1
             bny = 4
-            _, yo, _, yi = s[C].tile(x, y, bnx, bny)
-            s[CC].compute_at(s[C], yo)
+            x, y = output.op.axis
+            xo, yo, xi, yi = s[output].tile(x, y, bnx, bny)
+
             xc, yc = s[CC].op.axis
             k, = s[CC].op.reduce_axis
             ko, ki = s[CC].split(k, factor=4)
             s[CC].reorder(ko, xc, ki, yc)
+
             s[CC].unroll(ki)
             s[CC].vectorize(yc)
 
-            # Vectorization
-            s[C].vectorize(yi)
-
-            # Parallelization
-            s[C].parallel(yo)
+            s[output].unroll(xi)
+            s[output].vectorize(yi)
 
-        scheduled_ops.append(op)
+            fused = s[output].fuse(xo, yo)
+            s[output].parallel(fused)
+            s[CC].compute_at(s[output], fused)
 
-    traverse(outs[0].op)
+    traverse_inline(s, outs[0].op, _callback)
     return s
-- 
GitLab