From 3b8ad0a22870152d6eb4979d7dcbd55f5b960179 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 17 Apr 2017 18:03:12 -0700
Subject: [PATCH] [SCHEDULE] Normalize returns a new schedule (#94)

---
 include/tvm/schedule.h                        |  7 ++-
 python/tvm/build.py                           |  2 +-
 python/tvm/schedule.py                        | 11 +++--
 src/api/api_lang.cc                           |  2 +-
 src/schedule/schedule_dataflow_rewrite.cc     |  8 ++--
 src/schedule/schedule_lang.cc                 | 46 +++++++++++++++++++
 tests/python/integration/test_dot.py          |  2 +-
 tests/python/integration/test_gemm.py         |  2 +-
 tests/python/unittest/test_codegen_device.py  |  1 +
 .../unittest/test_schedule_bound_inference.py | 20 ++++----
 .../unittest/test_schedule_schedule_ops.py    |  6 +--
 .../integration/test_codegen_verilog.py       |  2 +-
 12 files changed, 86 insertions(+), 23 deletions(-)

diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index ec1ac4c89..ca78e9e51 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -190,6 +190,11 @@ class Schedule : public NodeRef {
    * \param ops The ops to be scheduled.
    */
   explicit Schedule(Array<Operation> ops);
+  /*!
+   * \brief Get a copy of the current schedule.
+   * \return The copied schedule.
+   */
+  Schedule copy() const;
   /*!
    * \brief Get the stage corresponds to the op
    * \param op The operation.
@@ -257,7 +262,7 @@ class Schedule : public NodeRef {
    *
    * \return A normalized schedule, can be same as current one.
    */
-  void normalize();
+  Schedule normalize();
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
diff --git a/python/tvm/build.py b/python/tvm/build.py
index f144dc260..6f1580e94 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -57,7 +57,7 @@ def lower(sch,
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
     # normalize schedule first
-    sch.normalize()
+    sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 96569730d..0783f260b 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -78,12 +78,17 @@ class Schedule(NodeBase):
         return self.stage_map[k]
 
     def normalize(self):
-        """Build a normalized schedule.
+        """Build a normalized schedule from the current schedule.
 
         Insert necessary rebase to make certain iter var to start from 0.
         This is needed before bound inference and followup step.
+
+        Returns
+        -------
+        sch : Schedule
+            The normalized schedule.
         """
-        _api_internal._ScheduleNormalize(self)
+        return _api_internal._ScheduleNormalize(self)
 
     def create_group(self, outputs, inputs, include_inputs=False):
         """Create stage group by giving output and input boundary.
@@ -261,7 +266,7 @@ class Stage(NodeBase):
         threads : list of threads
             The threads to be launched.
         """
-        if isinstance(threads, _collections.IterVar):
+        if isinstance(threads, IterVar):
             threads = [threads]
         _api_internal._StageEnvThreads(self, threads)
 
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 788b7e1ac..15155a232 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -311,7 +311,7 @@ TVM_REGISTER_API(_StageParallel)
 
 TVM_REGISTER_API(_ScheduleNormalize)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Schedule()
+    *ret = args[0].operator Schedule()
         .normalize();
   });
 
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 9705f18de..c0debcc29 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -242,9 +242,11 @@ void InjectInline(ScheduleNode* sch) {
   ReplaceDataFlow(sch->stages, &repl);
 }
 
-void Schedule::normalize() {
-  InjectInline(operator->());
-  RebaseNonZeroMinLoop(*this);
+Schedule Schedule::normalize() {
+  Schedule sn = copy();
+  InjectInline(sn.operator->());
+  RebaseNonZeroMinLoop(sn);
+  return sn;
 }
 
 // Handle reduction factor.
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index e31d2418b..4ef21f780 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -355,6 +355,52 @@ Schedule::Schedule(Array<Operation> ops) {
   }
 }
 
+Stage CopyStage(const Stage& s) {
+  std::shared_ptr<StageNode> n =
+      std::make_shared<StageNode>(*s.operator->());
+  return Stage(n);
+}
+
+Schedule Schedule::copy() const {
+  // map of stages.
+  const ScheduleNode* self = operator->();
+  std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap;
+  std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>();
+  n->outputs = self->outputs;
+  // Copy the stages.
+  for (Stage s : self->stages) {
+    Stage scopy = CopyStage(s);
+    smap[s] = scopy;
+    n->stages.push_back(scopy);
+  }
+  for (Stage g : self->groups) {
+    Stage gcopy = CopyStage(g);
+    smap[g] = gcopy;
+    n->groups.push_back(gcopy);
+  }
+  // Remap the reference relations to point at the copied stages.
+  for (auto kv : self->stage_map) {
+    n->stage_map.Set(kv.first, smap.at(kv.second));
+  }
+  for (Stage s : n->stages) {
+    if (s->attach_stage.defined()) {
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      s->group = smap.at(s->group);
+    }
+  }
+  for (Stage s : n->groups) {
+    if (s->attach_stage.defined()) {
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      s->group = smap.at(s->group);
+    }
+  }
+  return Schedule(n);
+}
+
 Stage Schedule::operator[](const Operation& op) {
   auto it = (*this)->stage_map.find(op);
   CHECK(it != (*this)->stage_map.end())
diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py
index 2aa3d0ddf..ab0b32e2a 100644
--- a/tests/python/integration/test_dot.py
+++ b/tests/python/integration/test_dot.py
@@ -10,7 +10,7 @@ def lower(s, args, name="mydot"):
         buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
         binds[x] = buf
         arg_list.append(buf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py
index 4af11364d..fb66c6011 100644
--- a/tests/python/integration/test_gemm.py
+++ b/tests/python/integration/test_gemm.py
@@ -60,7 +60,7 @@ def test_gemm():
 
     max_auto_unroll_step = 0
     # lowering test
-    s.normalize()
+    s = s.normalize()
 
     # one line to build the function.
     def check_device(device, host="stackvm"):
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index 90c28a8af..8548a33fa 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -16,6 +16,7 @@ def test_add_pipeline():
     s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
 
     # compile to IR
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index 21639399e..c47a29546 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -22,6 +22,8 @@ def test_bound2():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     s = tvm.create_schedule(A2.op)
     xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
+    # check that normalize() returns a new schedule and does not modify this one
+    _ = s.normalize()
     s[A1].compute_at(s[A2], yo)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
@@ -41,6 +43,8 @@ def test_bound3():
     xi0, xi1 = s[A2].split(xi, nparts=16)
     s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
     yo, yi = s[A2].split(A2.op.axis[1], 16)
+    # check that normalize() returns a new schedule and does not modify this one
+    _ = s.normalize()
     s[A2].reorder(xo, xi0, yo, xi1, yi)
     s[A1].compute_at(s[A2], yo)
 
@@ -63,7 +67,7 @@ def test_bound_scan():
     XX = s.cache_read(X, "local", s_update)
     xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
     s[XX].compute_at(s[s_update], xo)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     assert bounds[XX.op.axis[1]].extent.value == 4
@@ -77,7 +81,7 @@ def test_bound_conv1d():
     B = tvm.compute(n, computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[0])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
 
@@ -92,7 +96,7 @@ def test_bound_blur():
     B = tvm.compute((n-2, n-2), computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
     assert(bounds[A.op.axis[1]].extent.value == 3)
@@ -106,7 +110,7 @@ def test_bound_rfactor():
     s = tvm.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
     BF = s.rfactor(B, kf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
 
     assert(bounds[BF.op.axis[0]].extent.value == 4)
@@ -123,7 +127,7 @@ def test_bound_group_schedule():
     g.compute_at(s[x2], x2.op.axis[0])
     assert s[x1].group == g
     assert s[x].group == g
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent == n
@@ -141,7 +145,7 @@ def test_bound_nest_group():
     assert s[x1].group == g2
     g2.compute_at(s[x2], x2.op.axis[0])
     g1.compute_at(s[x1], s[x1].op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent.value == 1
@@ -169,7 +173,7 @@ def test_bound_nest_thread():
     _, xi = s[A2].split(A2.op.axis[0], nparts=1)
     s[A2].bind(xi, thread_x)
     s[A1].compute_at(s[A3], tx)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A1.op.axis[0]].extent.value==1)
     assert(bounds[A2.op.axis[0]].extent.value==32)
@@ -225,7 +229,7 @@ def test_gemm_bound():
     tx, xi = s[BB].split(xi, nparts=num_thread)
     s[BB].bind(ty, thread_y)
     s[BB].bind(tx, thread_x)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BB.op.axis[0]].extent.value==64)
     assert(bounds[AA.op.axis[0]].extent.value==64)
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index f52f9eca3..ea02c60b8 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -51,7 +51,7 @@ def test_schedule_scan():
 
     assert tuple(res.shape) == (m, n)
     s = tvm.create_schedule(res.op)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -68,7 +68,7 @@ def test_auto_inline():
 
     s = tvm.create_schedule(T2.op)
     tvm.schedule.AutoInlineElemWise(s)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
@@ -83,7 +83,7 @@ def test_inline_mixed():
     xo, xi = s[C].split(C.op.axis[0], factor=8)
     s[A1].compute_at(s[C], xo)
     s[A2].compute_inline()
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
diff --git a/tests/verilog/integration/test_codegen_verilog.py b/tests/verilog/integration/test_codegen_verilog.py
index 7003fa936..f1994e38c 100644
--- a/tests/verilog/integration/test_codegen_verilog.py
+++ b/tests/verilog/integration/test_codegen_verilog.py
@@ -11,7 +11,7 @@ def lower(s, args, name):
         buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
         binds[x] = buf
         arg_list.append(buf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
-- 
GitLab