From b19e01bf27ecc0c4c0b0f461588bcddd58dd7d65 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sat, 11 Mar 2017 20:12:42 -0800
Subject: [PATCH] [PASS] RemoveNoOp. (#68)

---
 include/tvm/buffer.h                          |   1 -
 include/tvm/channel.h                         |  54 +++++
 include/tvm/expr.h                            |   1 +
 include/tvm/ir.h                              |  11 +-
 include/tvm/ir_pass.h                         |  14 ++
 src/api/api_pass.cc                           |   2 +
 src/lang/channel.cc                           |  22 ++
 src/pass/remove_no_op.cc                      | 111 ++++++++++
 src/pass/simple_passes.cc                     |   1 +
 src/pass/split_host_device.cc                 |  10 +-
 src/pass/split_pipeline.cc                    | 194 ++++++++++++++++++
 src/schedule/schedule_ops.cc                  |   6 +-
 .../python/unittest/test_pass_remove_no_op.py |  29 +++
 .../unittest/test_pass_split_pipeline.py      |  31 +++
 14 files changed, 477 insertions(+), 10 deletions(-)
 create mode 100644 include/tvm/channel.h
 create mode 100644 src/lang/channel.cc
 create mode 100644 src/pass/remove_no_op.cc
 create mode 100644 src/pass/split_pipeline.cc
 create mode 100644 tests/python/unittest/test_pass_remove_no_op.py
 create mode 100644 tests/python/unittest/test_pass_split_pipeline.py

diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h
index 141e4b68f..9f266844f 100644
--- a/include/tvm/buffer.h
+++ b/include/tvm/buffer.h
@@ -1,4 +1,3 @@
-
 /*!
  *  Copyright (c) 2016 by Contributors
  * \file buffer.h
diff --git a/include/tvm/channel.h b/include/tvm/channel.h
new file mode 100644
index 000000000..81f9e5248
--- /dev/null
+++ b/include/tvm/channel.h
@@ -0,0 +1,54 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file channel.h
+ * \brief Channel object for pipeline.
+ */
+#ifndef TVM_CHANNEL_H_
+#define TVM_CHANNEL_H_
+
+#include <tvm/expr.h>
+
+namespace tvm {
+// Node container of channel
+struct ChannelNode;
+
+/*! \brief The data channel. */
+class Channel : public NodeRef {
+ public:
+  /*! \brief default constructor  */
+  Channel() {}
+  explicit Channel(std::shared_ptr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const ChannelNode* operator->() const;
+};
+
+/*!
+ * \brief Generalized FIFO channel.
+ */
+struct ChannelNode : public Node {
+  /*! \brief Variable to channel handle */
+  Var handle_var;
+  /*! \brief default data type in read/write */
+  Type dtype;
+
+  // visit all attributes
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("handle_var", &handle_var);
+    v->Visit("dtype", &dtype);
+  }
+
+  static Channel make(Var handle_var, Type dtype);
+  static constexpr const char* _type_key = "Channel";
+
+  TVM_DECLARE_NODE_TYPE_INFO(ChannelNode, Node);
+};
+
+// Inline implementations
+inline const ChannelNode* Channel::operator->() const {
+  return static_cast<const ChannelNode*>(node_.get());
+}
+}  // namespace tvm
+#endif  // TVM_CHANNEL_H_
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 761cd2b04..8d100d272 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -39,6 +39,7 @@ using Halide::Internal::as_const_int;
 using Halide::Internal::as_const_uint;
 using Halide::Internal::const_true;
 using Halide::Internal::const_false;
+using Halide::Internal::is_no_op;
 
 inline Type TVMType2Type(TVMType t) {
   return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 153d3105f..6b7ba2927 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -90,9 +90,7 @@ constexpr const char* virtual_thread = "virtual_thread";
  * \brief Mark storage scope of buffers
  */
 constexpr const char* storage_scope = "storage_scope";
-/*!
- * \brief Mark storage scope of realizations
- */
+/*! \brief Mark storage scope of realization */
 constexpr const char* realize_scope = "realize_scope";
 /*! \brief Mark of loop scope */
 constexpr const char* loop_scope = "loop_scope";
@@ -100,6 +98,13 @@ constexpr const char* loop_scope = "loop_scope";
 constexpr const char* scan_update_scope = "scan_update_scope";
 /*! \brief Mark of scan init scope */
 constexpr const char* scan_init_scope = "scan_init_scope";
+// Pipeline related attributes
+/*! \brief channel read scope */
+constexpr const char* channel_read_scope = "channel_read_scope";
+/*! \brief channel write scope */
+constexpr const char* channel_write_scope = "channel_write_scope";
+/*! \brief pipeline module scope */
+constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
 }  // namespace attr
 
 /*! \brief namespace of TVM Intrinsic functions */
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index f1ee06188..8f71ad145 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -105,6 +105,20 @@ Stmt Inline(Stmt stmt,
 Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer);
 
+/*!
+ * \brief Remove No Op from the Stmt.
+ * \param stmt The stmt to be trasnformed
+ * \return Transformed stmt.
+ */
+Stmt RemoveNoOp(Stmt stmt);
+
+/*!
+ * \brief Split statement into pipeine stages.
+ * \param stmt The stmt to be splitted
+ * \return Transformed stmt.
+ */
+Stmt SplitPipeline(Stmt stmt);
+
 /*!
  * \brief unroll the constant loops
  * \param stmt The statment to be unrolled.
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index f995f13d1..82de8addf 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -70,6 +70,8 @@ REGISTER_PASS1(SplitHostDevice);
 REGISTER_PASS1(LiftAllocate);
 REGISTER_PASS1(InjectVirtualThread);
 REGISTER_PASS1(LoopPartition);
+REGISTER_PASS1(RemoveNoOp);
+REGISTER_PASS1(SplitPipeline);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/lang/channel.cc b/src/lang/channel.cc
new file mode 100644
index 000000000..dd850becf
--- /dev/null
+++ b/src/lang/channel.cc
@@ -0,0 +1,22 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file channel.cc
+ */
+#include <tvm/channel.h>
+
+namespace tvm {
+
+Channel ChannelNode::make(Var handle_var, Type dtype) {
+  auto n = std::make_shared<ChannelNode>();
+  n->handle_var = handle_var;
+  n->dtype = dtype;
+  return Channel(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ChannelNode>([](const ChannelNode *op, IRPrinter *p) {
+    p->stream << "channel(" << op->handle_var << ", " << op->dtype << ")";
+});
+
+TVM_REGISTER_NODE_TYPE(ChannelNode);
+}  // namespace tvm
diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc
new file mode 100644
index 000000000..9709ae1b8
--- /dev/null
+++ b/src/pass/remove_no_op.cc
@@ -0,0 +1,111 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file remove_no_op.cc
+ * \brief Remove no op from the stmt
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_mutator.h>
+#include <unordered_map>
+
+namespace tvm {
+namespace ir {
+
+// Mark the statment of each stage.
+class NoOpRemover : public IRMutator {
+ public:
+  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<LetStmt>();
+    return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
+  }
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<AttrStmt>();
+    return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
+  }
+  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<IfThenElse>();
+    if (op->else_case.defined()) {
+      if (is_no_op(op->else_case)) {
+        if (is_no_op(op->then_case)) {
+          return MakeEvaluate(op->condition);
+        } else {
+          return IfThenElse::make(op->condition, op->then_case);
+        }
+      } else {
+        return stmt;
+      }
+    } else {
+      if (is_no_op(op->then_case)) {
+        return MakeEvaluate(op->condition);
+      } else {
+        return stmt;
+      }
+    }
+  }
+  Stmt Mutate_(const For* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<For>();
+    return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
+  }
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Allocate>();
+    return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
+  }
+  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<ProducerConsumer>();
+    return is_no_op(op->body) ? op->body : stmt;
+  }
+  Stmt Mutate_(const Realize* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Realize>();
+    return is_no_op(op->body) ? op->body : stmt;
+  }
+  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
+    if (HasSideEffect(op->value)) return s;
+    return Evaluate::make(0);
+  }
+  Stmt Mutate_(const Block* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Block>();
+    if (is_no_op(op->first)) {
+      return op->rest;
+    } else if (is_no_op(op->rest)) {
+      return op->first;
+    } else {
+      return stmt;
+    }
+  }
+
+ private:
+  Stmt MakeEvaluate(Expr value) {
+    if (HasSideEffect(value)) {
+      return Evaluate::make(value);
+    } else {
+      return Evaluate::make(0);
+    }
+  }
+  Stmt MakeEvaluate(const Array<Expr>& values) {
+    Stmt stmt;
+    for (Expr e : values) {
+      if (HasSideEffect(e)) {
+        if (stmt.defined()) {
+          stmt = Block::make(stmt, Evaluate::make(e));
+        } else {
+          stmt = Evaluate::make(e);
+        }
+      }
+    }
+    return stmt.defined() ? stmt : Evaluate::make(0);
+  }
+};
+
+Stmt RemoveNoOp(Stmt stmt) {
+  return NoOpRemover().Mutate(stmt);
+}
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc
index 5fc928cdd..70af63ce7 100644
--- a/src/pass/simple_passes.cc
+++ b/src/pass/simple_passes.cc
@@ -48,6 +48,7 @@ class IRSubstitue : public IRMutator {
 };
 
 Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
+  if (value_map.size() == 0) return stmt;
   IRSubstitue m;
   for (auto kv : value_map) {
     m.smap[kv.first.get()] = kv.second;
diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc
index 642c1ed12..d64eff792 100644
--- a/src/pass/split_host_device.cc
+++ b/src/pass/split_host_device.cc
@@ -5,6 +5,7 @@
  */
 #include <tvm/ir.h>
 #include <tvm/lowered_func.h>
+#include <tvm/channel.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_mutator.h>
 #include <tvm/runtime/module.h>
@@ -17,7 +18,7 @@ namespace ir {
 class IRUseDefAnalysis : public IRMutator {
  public:
   Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
-    if (op->type_key == "thread_extent") {
+    if (op->type_key == attr::thread_extent) {
       IterVar iv(op->node.node_);
       CHECK_NE(iv->thread_tag.length(), 0U);
       // thread_extent can appear multiple times
@@ -35,6 +36,13 @@ class IRUseDefAnalysis : public IRMutator {
       Stmt body = this->Mutate(op->body);
       if (value.same_as(value) && body.same_as(body)) return s;
       return AttrStmt::make(op->node, op->type_key, value, body);
+    } else if (op->type_key == attr::channel_write_scope ||
+               op->type_key == attr::channel_read_scope) {
+      Channel ch(op->node.node_);
+      if (!use_count_.count(ch->handle_var.get())) {
+        this->HandleDef(ch->handle_var.get());
+      }
+      return IRMutator::Mutate_(op, s);
     } else {
       return IRMutator::Mutate_(op, s);
     }
diff --git a/src/pass/split_pipeline.cc b/src/pass/split_pipeline.cc
new file mode 100644
index 000000000..93b3b86ed
--- /dev/null
+++ b/src/pass/split_pipeline.cc
@@ -0,0 +1,194 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file split_pipeline.cc
+ * \brief Split statement into pipeline stage modules.
+ */
+#include <tvm/ir.h>
+#include <tvm/expr.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/channel.h>
+#include <unordered_map>
+#include "./ir_util.h"
+
+namespace tvm {
+namespace ir {
+
+class MarkChannelAccess : public IRMutator {
+ public:
+  MarkChannelAccess(
+      const std::unordered_map<const Variable*, Channel>& cmap)
+      : cmap_(cmap) {}
+
+  Expr Mutate_(const Load *op, const Expr& e) final {
+    auto it = rmap_.find(op->buffer_var.get());
+    if (it != rmap_.end()) {
+      ++it->second.read_count;
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+  Stmt Mutate_(const Store *op, const Stmt& s) final {
+    auto it = rmap_.find(op->buffer_var.get());
+    if (it != rmap_.end()) {
+      ++it->second.write_count;
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    if (cmap_.count(op->buffer_var.get())) {
+      CHECK(!rmap_.count(op->buffer_var.get()));
+      rmap_[op->buffer_var.get()] = Entry();
+      Stmt body = Mutate(op->body);
+      body = CreateChannelAccess(op, body);
+      rmap_.erase(op->buffer_var.get());
+      return body;
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->type_key == ir::attr::storage_scope) {
+      Var buf_var(op->node.node_);
+      if (cmap_.count(buf_var.get())) return Mutate(op->body);
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+ private:
+  // Create channel access wrap
+  Stmt CreateChannelAccess(const Allocate* op, Stmt body) {
+    const Entry& rw = rmap_.at(op->buffer_var.get());
+    CHECK(rw.write_count == 0 || rw.read_count == 0)
+        << "Cannot read/write to the same channel " << op->buffer_var
+        <<  " body:" << body;
+    if (rw.write_count == 0 && rw.read_count == 0) {
+      return body;
+    }
+    const Channel& ch = cmap_.at(op->buffer_var.get());
+    int32_t csize = op->constant_allocation_size();
+    Expr alloc_size;
+    if (csize > 0) {
+      alloc_size = IntImm::make(Int(32), csize);
+    } else {
+      alloc_size = op->extents[0];
+      for (size_t i = 1; i < op->extents.size(); ++i) {
+        alloc_size *= op->extents[i];
+      }
+      alloc_size = ir::Simplify(alloc_size);
+    }
+
+    if (rw.write_count) {
+      return AttrStmt::make(
+          ch, ir::attr::channel_write_scope, alloc_size, body);
+    } else {
+      CHECK(rw.read_count);
+      return AttrStmt::make(
+          ch, ir::attr::channel_read_scope, alloc_size, body);
+    }
+  }
+  struct Entry {
+    int read_count{0};
+    int write_count{0};
+  };
+  // The channels of each allocation.
+  const std::unordered_map<const Variable*, Channel>& cmap_;
+  // the result.
+  std::unordered_map<const Variable*, Entry> rmap_;
+};
+
+
+// Mark the statment of each stage.
+class StageSplitter : public IRMutator {
+ public:
+  Stmt Mutate(Stmt stmt) final {
+    nest_.push_back(stmt);
+    Stmt ret = IRMutator::Mutate(stmt);
+    nest_.pop_back();
+    return ret;
+  }
+  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) {
+    if (!op->is_producer) return IRMutator::Mutate_(op, s);
+    Stmt body = Mutate(op->body);
+    stages_.emplace_back(BuildStage(body, op->func));
+    return Evaluate::make(0);
+  }
+
+  Stmt Split(Stmt stmt) {
+    stmt = Mutate(stmt);
+    stmt = RemoveNoOp(stmt);
+    CHECK(is_no_op(stmt));
+    CHECK_NE(stages_.size(), 0);
+    stmt = stages_.back();
+    for (size_t i = stages_.size() - 1; i != 0; --i) {
+      stmt = Block::make(stages_[i - 1], stmt);
+    }
+    stmt = MarkChannelAccess(cmap_).Mutate(stmt);
+    return RemoveNoOp(stmt);
+  }
+
+ private:
+  // Build the stage.
+  Stmt BuildStage(Stmt body, NodeRef target) {
+    int stage_index = static_cast<size_t>(stages_.size());
+    std::string stage_suffix = "." + std::to_string(stage_index);
+    // The Substitute
+    Map<Var, Expr> subst;
+    std::vector<Stmt> nest;
+    Stmt no_op = Evaluate::make(0);
+
+    for (const Stmt& s : nest_) {
+      if (const For* op = s.as<For>()) {
+        Var loop_var(op->loop_var);
+        Var new_var = loop_var.copy_with_suffix(stage_suffix);
+        subst.Set(loop_var, new_var);
+        nest.emplace_back(For::make(
+            new_var, op->min, op->extent,
+            op->for_type, op->device_api, no_op));
+      } else if (const LetStmt* op = s.as<LetStmt>()) {
+        Var var(op->var);
+        Var new_var = var.copy_with_suffix(stage_suffix);
+        subst.Set(var, new_var);
+        nest.emplace_back(LetStmt::make(new_var, op->value, no_op));
+      } else if (const IfThenElse* op = s.as<IfThenElse>()) {
+        CHECK(!op->else_case.defined());
+        nest.emplace_back(IfThenElse::make(op->condition, no_op));
+      } else if (const AttrStmt* op = s.as<AttrStmt>()) {
+        nest.emplace_back(AttrStmt::make(
+            op->node, op->type_key, op->value, no_op));
+      } else if (s.as<ProducerConsumer>()) {
+      } else if (s.as<Block>()) {
+      } else if (const Allocate* op = s.as<Allocate>()) {
+        nest.emplace_back(Allocate::make(
+            op->buffer_var, op->type, op->extents,
+            op->condition, no_op, op->new_expr, op->free_function));
+        MarkChannel(op);
+      } else {
+        LOG(FATAL) << "not supported nest type " << s->type_key();
+      }
+    }
+    body = Substitute(MergeNest(nest, body), subst);
+    return AttrStmt::make(
+        target, ir::attr::pipeline_stage_scope,
+        make_const(Int(32), stage_index), body);
+  }
+  void MarkChannel(const Allocate* op) {
+    if (!cmap_.count(op->buffer_var.get())) {
+      Channel ch = ChannelNode::make(Var(op->buffer_var), op->type);
+      cmap_[op->buffer_var.get()] = ch;
+    }
+  }
+  // The stack
+  std::vector<Stmt> nest_;
+  // The stages
+  std::vector<Stmt> stages_;
+  // channel map
+  std::unordered_map<const Variable*, Channel> cmap_;
+};
+
+Stmt SplitPipeline(Stmt stmt) {
+  return StageSplitter().Split(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index 43081d8b3..6489f21e7 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -26,12 +26,8 @@ Stmt MakePipeline(const Stage& s,
     producer = ProducerConsumer::make(s->op, true, producer);
   }
   Stmt pipeline = producer;
-  // check if consumer is nop.
-  bool is_no_op{false};
-  const Evaluate* ev = consumer.as<Evaluate>();
-  if (ev && ev->value.as<IntImm>()) is_no_op = true;
 
-  if (consumer.defined() && !is_no_op) {
+  if (consumer.defined() && !is_no_op(consumer)) {
     consumer = ProducerConsumer::make(s->op, false, consumer);
     pipeline = Block::make(producer, consumer);
   }
diff --git a/tests/python/unittest/test_pass_remove_no_op.py b/tests/python/unittest/test_pass_remove_no_op.py
new file mode 100644
index 000000000..8aadaf8c0
--- /dev/null
+++ b/tests/python/unittest/test_pass_remove_no_op.py
@@ -0,0 +1,29 @@
+import tvm
+
+def test_remove_no_op():
+    i = tvm.Var('i')
+    j = tvm.Var('j')
+    k = tvm.Var('k')
+    m = tvm.Var('m')
+    n = tvm.Var('n')
+    dtype = 'int64'
+    Ab = tvm.Buffer((n, ), dtype)
+    stmt = tvm.make.For(
+        i, 0, 4, 0, 0,
+        tvm.make.For(
+            j, 0, n, 0, 0,
+            tvm.make.For(
+                k, 0, m, 0, 0,
+                tvm.make.IfThenElse(
+                    (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)))))
+    ret = tvm.ir_pass.RemoveNoOp(stmt)
+    assert(isinstance(ret, tvm.stmt.Evaluate))
+    store = tvm.make.Store(Ab.data,
+                           tvm.make.Load(dtype, Ab.data, i) + 1,
+                           i + 1)
+    stmt2 = tvm.make.Block(stmt, store)
+    assert(tvm.ir_pass.RemoveNoOp(stmt2) == store)
+
+
+if __name__ == "__main__":
+    test_remove_no_op()
diff --git a/tests/python/unittest/test_pass_split_pipeline.py b/tests/python/unittest/test_pass_split_pipeline.py
new file mode 100644
index 000000000..86beb5eee
--- /dev/null
+++ b/tests/python/unittest/test_pass_split_pipeline.py
@@ -0,0 +1,31 @@
+import tvm
+
+def test_basic_pipeline():
+    n = tvm.convert(128)
+    A = tvm.placeholder((n,), name='A')
+    stages = []
+    num_stage = 3
+
+    B = A
+    for k in range(num_stage):
+        stages.append(B)
+        B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k)
+
+    s = tvm.Schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=4)
+    for S in stages:
+        s[S].compute_at(s[B], xo)
+
+    # Lowering
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.Buffer(B.shape, B.dtype, name='B')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb})
+    stmt = tvm.ir_pass.Simplify(stmt)
+    stmt = tvm.ir_pass.SplitPipeline(stmt)
+    print(stmt)
+    assert(tvm.ir_pass.VerifySSA(stmt))
+
+if __name__ == "__main__":
+    test_basic_pipeline()
-- 
GitLab