From 7f82912bfb3d9af249a6df28998d01dca60dc1e6 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 15 Jan 2017 21:40:21 -0800
Subject: [PATCH] [PASS] Basic storage flatten (#13)

---
 python/tvm/_ctypes/_api.py                |   2 +-
 python/tvm/function.py                    |   3 +-
 src/c_api/c_api_pass.cc                   |   1 +
 src/lang/buffer.cc                        |   2 +-
 src/pass/ir_mutator.cc                    |   2 +-
 src/pass/storage_flatten.cc               | 179 ++++++++++++++++++++++
 tests/python/test_pass_storage_flatten.py |  27 ++++
 7 files changed, 212 insertions(+), 4 deletions(-)
 create mode 100644 src/pass/storage_flatten.cc
 create mode 100644 tests/python/test_pass_storage_flatten.py

diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py
index 4de64f9db..cf5df6190 100644
--- a/python/tvm/_ctypes/_api.py
+++ b/python/tvm/_ctypes/_api.py
@@ -225,7 +225,7 @@ def _make_function(handle, name):
         """TVM function"""
         cargs = []
         for x in args:
-            if isinstance(x, (list, tuple, SliceBase)):
+            if isinstance(x, (list, tuple, dict, SliceBase)):
                 cargs.append(convert(x))
             else:
                 cargs.append(x)
diff --git a/python/tvm/function.py b/python/tvm/function.py
index 78491404d..72ec0d268 100644
--- a/python/tvm/function.py
+++ b/python/tvm/function.py
@@ -133,7 +133,8 @@ def compute(shape, fcompute, name="compute"):
 
 
 def Buffer(shape, dtype=None,
-           name="buffer", ptr=None,
+           name="buffer",
+           ptr=None,
            strides=None):
     """Create a new buffer
 
diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc
index 2d4cb6e3f..e05f696bd 100644
--- a/src/c_api/c_api_pass.cc
+++ b/src/c_api/c_api_pass.cc
@@ -36,6 +36,7 @@ REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
 REGISTER_PASS2(ScheduleOps);
+REGISTER_PASS2(StorageFlatten);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index 02cd05224..b44bca783 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -51,7 +51,7 @@ Expr Buffer::MakeLoad(Array<Expr> index) const {
 Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
   const BufferNode* n = operator->();
   CHECK_EQ(value.type(), n->dtype);
-  return ir::Store::make(n->ptr, BufferOffset(n, index), value);
+  return ir::Store::make(n->ptr, value, BufferOffset(n, index));
 }
 
 Buffer BufferNode::make(std::string name,
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index 2c534a6c1..b2572b88e 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -83,7 +83,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
         body.same_as(op->body)) {
       return s;
     } else {
-      return AttrStmt::make(op->node, op->type_key, op->value, op->body);
+      return AttrStmt::make(op->node, op->type_key, value, body);
     }
   });
 
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
new file mode 100644
index 000000000..6058b6907
--- /dev/null
+++ b/src/pass/storage_flatten.cc
@@ -0,0 +1,179 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file storage_flatten.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+
+namespace tvm {
+namespace ir {
+
+// Key of a function's output buffer: (function, value_index)
+struct TensorKey {
+  FunctionRef f;
+  int value_index;
+
+  inline bool operator==(const TensorKey& other) const {
+    return f == other.f && value_index == other.value_index;
+  }
+  inline std::string GetName() const {
+    if (f->num_outputs() == 1) return f->func_name();
+    std::ostringstream os;
+    os << f->func_name() << ".v" << value_index;
+    return os.str();
+  }
+};
+
+}  // namespace ir
+}  // namespace tvm
+
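+// std::hash specialization so TensorKey can key an unordered_map;
+// mixes the function hash with value_index (hash_combine pattern).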
+namespace std {
+template <>
+struct hash<::tvm::ir::TensorKey> {
+  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
+    size_t lhs = k.f.hash();
+    size_t rhs = static_cast<size_t>(k.value_index);
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+}  // namespace std
+
+namespace tvm {
+namespace ir {
+
+using Halide::Internal::Region;
+
+// Flattens multi-dimensional tensor storage into flat buffer access:
+// Realize becomes Allocate, Provide becomes Store, and Halide
+// calls become Load, with indices made relative to the bounds.
+class StorageFlattener : public IRMutator {
+ public:
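+  // Seed the buffer map with externally bound buffers;
+  // these are supplied by the caller and never allocated by this pass.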
+  explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer) {
+    for (auto kv : extern_buffer) {
+      BufferEntry e;
+      e.buffer = kv.second;
+      e.external = true;
+      buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
+    }
+  }
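+  // Rewrite reads: a Halide call into a tensor becomes a load from
+  // its flattened buffer, with indices made relative to the bounds.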
+  Expr Mutate(Expr expr) final {
+    expr = IRMutator::Mutate(expr);
+    const Call* op = expr.as<Call>();
+    if (op != nullptr && op->call_type == Call::Halide) {
+      TensorKey key{op->func, op->value_index};
+      auto it = buf_map_.find(key);
+      CHECK(it != buf_map_.end())
+          << "Cannot find allocated buffer for " << key.f;
+      const BufferEntry& e = it->second;
+      CHECK(!e.released)
+          << "Read a buffer that is already out of scope";
+      return e.buffer.MakeLoad(e.RelIndex(op->args));
+    } else {
+      return expr;
+    }
+  }
+
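+  // Dispatch statements: Realize becomes a flat allocation,
+  // Provide becomes a store into the flattened buffer.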
+  Stmt Mutate(Stmt stmt) final {
+    const Realize* realize = stmt.as<Realize>();
+    if (realize != nullptr) {
+      return HandleRealize(realize);
+    } else if (stmt.as<Provide>()) {
+      return HandleProvide(stmt);
+    } else {
+      return IRMutator::Mutate(stmt);
+    }
+  }
+
+ private:
+  // The buffer entry in the flatten map
+  struct BufferEntry {
+    // the buffer of storage
+    Buffer buffer;
+    // the bounds of realization, can be null
+    Region bounds;
+    // Whether the buffer is external
+    bool external{false};
+    // Whether the buffer has gone out of scope and been released.
+    bool released{false};
+    // TODO(tqchen) allow permutation and inference of index dimension.
+    // relative index
+    inline Array<Expr> RelIndex(Array<Expr> args) const {
+      if (bounds.size() != 0) {
+        Array<Expr> index;
+        CHECK_EQ(bounds.size(), args.size());
+        for (size_t i = 0; i < bounds.size(); ++i) {
+          index.push_back(args[i] - bounds[i]->min);
+        }
+        return index;
+      } else {
+        return args;
+      }
+    }
+  };
+
+  // The buffer assignment map
+  std::unordered_map<TensorKey, BufferEntry> buf_map_;
+
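+  // Lower a Realize region into a flat Allocate whose shape is the
+  // extents of the bounds, keeping the bounds for relative indexing.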
+  Stmt HandleRealize(const Realize* op) {
+    TensorKey key{op->func, op->value_index};
+    if (buf_map_.count(key)) {
+      CHECK(buf_map_.at(key).external);
+      return this->Mutate(op->body);
+    } else {
+      // create a buffer entry
+      // TODO(tqchen) allow permutation and inference of index dimension.
+      BufferEntry e;
+      e.bounds = op->bounds;
+      Array<Expr> shape;
+      for (auto r : e.bounds) {
+        shape.push_back(r->extent);
+      }
+      e.buffer = Buffer(shape, op->type, key.GetName());
+
+      buf_map_[key] = e;
+      Stmt body = this->Mutate(op->body);
+      buf_map_[key].released = true;
+
+      return Allocate::make(
+          e.buffer->ptr, e.buffer->dtype, e.buffer->shape,
+          make_const(Bool(e.buffer->dtype.lanes()), true), body);
+    }
+  }
+
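+  // Rewrite a Provide into a store through the flattened buffer.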
+  Stmt HandleProvide(Stmt stmt) {
+    stmt = IRMutator::Mutate(stmt);
+    const Provide* op = stmt.as<Provide>();
+    TensorKey key{op->func, op->value_index};
+    auto it = buf_map_.find(key);
+    CHECK(it != buf_map_.end())
+        << "Cannot find allocated buffer for " << key.f;
+    const BufferEntry& e = it->second;
+    CHECK(!e.released)
+        << "Write to a buffer that is already out of scope";
+    return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
+  }
+};
+
+
+Stmt StorageFlatten(Stmt stmt,
+                    Map<Tensor, Buffer> extern_buffer) {
+  stmt = StorageFlattener(extern_buffer).Mutate(stmt);
+  return stmt;
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/python/test_pass_storage_flatten.py b/tests/python/test_pass_storage_flatten.py
new file mode 100644
index 000000000..b7dff05d0
--- /dev/null
+++ b/tests/python/test_pass_storage_flatten.py
@@ -0,0 +1,27 @@
+import tvm
+
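+# Schedule a two-stage compute, lower it with ScheduleOps, then bind
+# A and A2 to explicit buffers and flatten multi-dimensional tensor
+# accesses into flat loads/stores with StorageFlatten.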
+def test_flatten2():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+
+    s = tvm.Schedule(A2.op)
+    xo, xi = s[A2].split(A2.op.axis[0], 8)
+    s[A1].compute_at(s[A2], xo)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.collections.Map)
+    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+
+    print(stmt)
+    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
+    A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
+    print(stmt)
+
+if __name__ == "__main__":
+    test_flatten2()
-- 
GitLab