From de6dd0cb10b56e3a3186e900553e6811a8806e12 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 7 May 2017 08:44:53 -0700
Subject: [PATCH] [BUGFIX] Fix schedule dataflow rewrite with multiple scan
 states (#126)

---
 python/tvm/intrin.py                          | 32 +++++++++++++++++++
 src/pass/storage_flatten.cc                   |  4 ++-
 src/schedule/schedule_dataflow_rewrite.cc     | 15 +++++----
 src/schedule/schedule_ops.cc                  |  8 +++--
 tests/python/perf/rnn_matexp.py               |  4 +--
 .../unittest/test_schedule_schedule_ops.py    | 26 ++++++++++++++-
 6 files changed, 75 insertions(+), 14 deletions(-)

diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py
index 4476fe0c3..2900cfdff 100644
--- a/python/tvm/intrin.py
+++ b/python/tvm/intrin.py
@@ -118,6 +118,38 @@ def exp(x):
     return call_pure_intrin(x.dtype, "exp", x)
 
 
+def tanh(x):
+    """Take hyperbolic tangent of input x.
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    Returns
+    -------
+    y : Expr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "tanh", x)
+
+
+def sigmoid(x):
+    """Compute the sigmoid of input x, i.e. 1 / (1 + exp(-x)).
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    Returns
+    -------
+    y : Expr
+        The result.
+    """
+    return 1.0 / (1.0 + exp(-x))
+
+
 def log(x):
     """Take log of input x.
 
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index a90c86758..ff4c5912d 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -90,7 +90,9 @@ class StorageFlattener : public IRMutator {
       buf_map_[key].released = true;
       // deduce current storage scope.
       auto it = storage_scope_.find(op->func.get());
-      CHECK(it != storage_scope_.end());
+      CHECK(it != storage_scope_.end())
+          << "Cannot find storage scope of " << op->func
+          << " value_index=" << op->value_index;
       StorageScope skey;
       const std::string& strkey = it->second;
       if (strkey.length() == 0) {
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 7534c7019..70b68e358 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -231,15 +231,16 @@ void InjectInline(ScheduleNode* sch) {
   std::unordered_map<Tensor, Tensor> repl;
   // rewrite dataflow
   for (size_t i = 0; i < sch->stages.size(); ++i) {
-    if (new_body[i].defined() &&
-        !new_body[i].same_as(sch->stages[i]->op)) {
+    if (new_body[i].defined()) {
       const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
       CHECK(compute);
-      Operation op = ComputeOpNode::make(
-          compute->name, compute->axis, new_body[i]);
-      repl[sch->stages[i]->op.output(0)] = op.output(0);
-      Stage s = sch->stages[i];
-      s->op = op;
+      if (!new_body[i].same_as(compute->body)) {
+        Operation op = ComputeOpNode::make(
+            compute->name, compute->axis, new_body[i]);
+        Stage s = sch->stages[i];
+        repl[s->op.output(0)] = op.output(0);
+        s->op = op;
+      }
     }
   }
   ReplaceDataFlow(sch->stages, &repl);
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index dab1318c1..edbc2878a 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -252,9 +252,11 @@ class SchedulePostProc : public IRMutator {
       }
       // This must be checked for all ops, including scan.
       if (!s->op.same_as(s->origin_op)) {
-        Tensor target = s->origin_op.output(0);
-        AddReplace(s->op.output(0), target,
-                   target, s->origin_op);
+        for (int i = 0; i < s->op->num_outputs(); ++i) {
+          Tensor target = s->origin_op.output(i);  // pair by index
+          AddReplace(s->op.output(i), target,
+                     target, s->origin_op);
+        }
       }
       // Specially add replacements for scan op.
       if (s->op.as<ScanOpNode>()) {
diff --git a/tests/python/perf/rnn_matexp.py b/tests/python/perf/rnn_matexp.py
index 1d58d1faf..c7a939ef9 100644
--- a/tests/python/perf/rnn_matexp.py
+++ b/tests/python/perf/rnn_matexp.py
@@ -126,11 +126,11 @@ def rnn_matexp():
         Whh_a = tvm.nd.array(Whh_np, ctx)
         # Skip first pass as it is compilation
         f(res_a, Whh_a)
-        tvm.nd.sync(ctx)
+        ctx.sync()
         # measure time cost of second step.
         tstart = time.time()
         f(res_a, Whh_a)
-        tvm.nd.sync(ctx)
+        ctx.sync()
         tgap = time.time() - tstart
         print("Time cost=%g" % tgap)
         # correctness
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index ea02c60b8..76b3d9622 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -86,7 +86,30 @@ def test_inline_mixed():
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    print(stmt)
+    def check(x):
+        if isinstance(x, tvm.expr.Call):
+            assert x.func != A2
+    tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
+
+
+def test_scan_inline():
+    m = tvm.var("m")
+    n = tvm.var("n")
+    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
+    s_state1 = tvm.placeholder((m, n))
+    s_state2 = tvm.placeholder((m, n))
+    s_init1 = tvm.compute((1, n), lambda _, i: x[0, i])
+    s_init2 = tvm.compute((1, n), lambda _, i: x[0, i])
+    s_x1 = tvm.compute((m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="x1")
+    s_x2 = tvm.compute((m, n), lambda t, i: s_state2[t-1, i] + 1 , name="x2")
+    s_update1 = tvm.compute((m, n), lambda t, i: s_x1[t, i], "u1")
+    s_update2 = tvm.compute((m, n), lambda t, i: s_x2[t, i], "u2")
+    res1, res2 = tvm.scan([s_init1, s_init2],  # scan carrying two states
+                          [s_update1, s_update2],
+                          [s_state1, s_state2])
+    s = tvm.create_schedule(res1.op)
+    s[s_x1].compute_inline()  # inline the stage feeding the first update
+    stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)  # no asserts: passes if lowering does not raise
 
 
 def test_schedule_cache():
@@ -105,6 +128,7 @@ def test_schedule_cache():
 
 
 if __name__ == "__main__":
+    test_scan_inline()
     test_inline_mixed()
     test_auto_inline()
     test_schedule_scan()
-- 
GitLab