From 944de73b4b6a659a7b231461b04df7a8a255a52c Mon Sep 17 00:00:00 2001
From: Zhixun Tan <phisiart@gmail.com>
Date: Sun, 28 Jan 2018 00:50:52 -0500
Subject: [PATCH] Add type code and bits to AllocWorkspace. (#831)

---
 include/tvm/runtime/c_backend_api.h      | 10 ++++++--
 include/tvm/runtime/device_api.h         |  8 +++++--
 src/codegen/codegen_opengl.cc            |  2 ++
 src/codegen/stack_vm/codegen_stack_vm.cc |  4 +++-
 src/codegen/stack_vm/stack_vm.cc         | 15 +++++++-----
 src/pass/lower_tvm_builtin.cc            | 20 +++++++++-------
 src/pass/split_host_device.cc            |  5 ++++
 src/runtime/c_runtime_api.cc             | 21 +++++++++++++----
 src/runtime/cpu_device_api.cc            |  6 +++--
 src/runtime/cuda/cuda_device_api.cc      |  2 +-
 src/runtime/metal/metal_common.h         |  2 +-
 src/runtime/metal/metal_device_api.mm    |  4 +++-
 src/runtime/opencl/opencl_common.h       |  2 +-
 src/runtime/opencl/opencl_device_api.cc  |  4 +++-
 src/runtime/opengl/opengl_common.h       |  2 --
 src/runtime/opengl/opengl_device_api.cc  |  9 -------
 src/runtime/rocm/rocm_device_api.cc      |  2 +-
 tests/webgl/test_local_multi_stage.py    | 30 ++++++++++++++++++++++++
 18 files changed, 105 insertions(+), 43 deletions(-)
 create mode 100644 tests/webgl/test_local_multi_stage.py

diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h
index e512921c9..079ab1efb 100644
--- a/include/tvm/runtime/c_backend_api.h
+++ b/include/tvm/runtime/c_backend_api.h
@@ -44,14 +44,20 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
  *
  * \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
  *
- * \param size The size of the space requested.
+ * \param nbytes The size of the space requested.
  * \param device_type The device type which the space will be allocated.
  * \param device_id The device id which the space will be allocated.
+ * \param dtype_code_hint The type code of the array elements. Only used in
+ * certain backends such as OpenGL.
+ * \param dtype_bits_hint The type bits of the array elements. Only used in
+ * certain backends such as OpenGL.
  * \return nullptr when error is thrown, a valid ptr if success
  */
 TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
                                        int device_id,
-                                       uint64_t size);
+                                       uint64_t nbytes,
+                                       int dtype_code_hint,
+                                       int dtype_bits_hint);
 
 /*!
  * \brief Backend function to free temporal workspace.
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 9ba08fb86..45009f1d3 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -114,9 +114,13 @@ class DeviceAPI {
    *  - Workspace should not overlap between different threads(i.e. be threadlocal)
    *
    * \param ctx The context of allocation.
-   * \param size The size to be allocated.
+   * \param nbytes The size to be allocated.
+   * \param type_hint The type of elements. Only needed by certain backends such
+   * as OpenGL, as nbytes is sufficient for most backends.
    */
-  TVM_DLL virtual void* AllocWorkspace(TVMContext ctx, size_t size);
+  TVM_DLL virtual void* AllocWorkspace(TVMContext ctx,
+                                       size_t nbytes,
+                                       TVMType type_hint = {});
   /*!
    * \brief Free temporal workspace in backend execution.
    *
diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc
index 496b15b34..696082749 100644
--- a/src/codegen/codegen_opengl.cc
+++ b/src/codegen/codegen_opengl.cc
@@ -24,6 +24,8 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
   inputs_.clear();
   output_iter_var_ = nullptr;
   thread_extent_var_ = "";
+  this->decl_stream.str("");
+  this->stream.str("");
 }
 
 void CodeGenOpenGL::AddFunction(LoweredFunc f) {
diff --git a/src/codegen/stack_vm/codegen_stack_vm.cc b/src/codegen/stack_vm/codegen_stack_vm.cc
index 5b01dae71..168e411fa 100644
--- a/src/codegen/stack_vm/codegen_stack_vm.cc
+++ b/src/codegen/stack_vm/codegen_stack_vm.cc
@@ -197,10 +197,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
     vm_.stack_size += size;
     this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
   } else if (op->name == "TVMBackendAllocWorkspace") {
-    CHECK_EQ(op->args.size(), 3U);
+    CHECK_EQ(op->args.size(), 5U);
     this->Push(op->args[0]);
     this->Push(op->args[1]);
     this->Push(op->args[2]);
+    this->Push(op->args[3]);
+    this->Push(op->args[4]);
     this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
   } else if (op->name == "TVMBackendFreeWorkspace") {
     CHECK_EQ(op->args.size(), 3U);
diff --git a/src/codegen/stack_vm/stack_vm.cc b/src/codegen/stack_vm/stack_vm.cc
index a133c9797..95feeae36 100644
--- a/src/codegen/stack_vm/stack_vm.cc
+++ b/src/codegen/stack_vm/stack_vm.cc
@@ -455,12 +455,15 @@ void StackVM::Run(State* s) const {
         break;
       }
       case TVM_DEVICE_ALLOCA: {
-        int device_type = static_cast<int>(stack[sp - 2].v_int64);
-        int device_id = static_cast<int>(stack[sp - 1].v_int64);
-        size_t nbytes = static_cast<size_t>(stack[sp].v_int64);
-        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes);
-        stack[sp - 2].v_handle = ptr;
-        sp = sp - 2;
+        int device_type = static_cast<int>(stack[sp - 4].v_int64);
+        int device_id = static_cast<int>(stack[sp - 3].v_int64);
+        size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
+        int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
+        int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
+        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes,
+                                             dtype_code_hint, dtype_bits_hint);
+        stack[sp - 4].v_handle = ptr;
+        sp = sp - 4;
         pc = pc + 1;
         break;
       }
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
index 105d58b95..a63fef07b 100644
--- a/src/pass/lower_tvm_builtin.cc
+++ b/src/pass/lower_tvm_builtin.cc
@@ -96,14 +96,18 @@ class BuiltinLower : public IRMutator {
                                     {op->buffer_var}, Call::PureIntrinsic),
                          throw_last_error),
         op->body);
-    Stmt alloca = LetStmt::make(op->buffer_var,
-                                Call::make(op->buffer_var.type(),
-                                           "TVMBackendAllocWorkspace",
-                                           {cast(Int(32), device_type_),
-                                                 cast(Int(32), device_id_),
-                                                 cast(UInt(64), total_bytes)},
-                                           Call::Extern),
-                                body);
+
+    Stmt alloca = LetStmt::make(
+        op->buffer_var,
+        Call::make(op->buffer_var.type(),
+                   "TVMBackendAllocWorkspace",
+                   {cast(Int(32), device_type_),
+                    cast(Int(32), device_id_),
+                    cast(UInt(64), total_bytes),
+                    IntImm::make(Int(32), op->type.code()),
+                    IntImm::make(Int(32), op->type.bits())},
+                   Call::Extern),
+        body);
 
     Expr free_op = Call::make(Int(32),
                               "TVMBackendFreeWorkspace",
diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc
index 942e70339..dc326f3cb 100644
--- a/src/pass/split_host_device.cc
+++ b/src/pass/split_host_device.cc
@@ -146,6 +146,11 @@ class IRUseDefAnalysis : public IRMutator {
 
 class HostDeviceSplitter : public IRMutator {
  public:
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    handle_data_type_[op->buffer_var.get()] = make_const(op->type, 0);
+    return IRMutator::Mutate_(op, s);
+  }
+
   Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::pipeline_exec_scope) {
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 0d0e36f23..2177fc344 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -95,8 +95,9 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
       static_cast<int>(ctx.device_type), allow_missing);
 }
 
-void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
-  TVMType type_hint{kDLUInt, 8, 1};
+void* DeviceAPI::AllocWorkspace(TVMContext ctx,
+                                size_t size,
+                                TVMType type_hint) {
   return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
 }
 
@@ -220,12 +221,22 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
 }
 
 void* TVMBackendAllocWorkspace(int device_type,
-                             int device_id,
-                             uint64_t size) {
+                               int device_id,
+                               uint64_t size,
+                               int dtype_code_hint,
+                               int dtype_bits_hint) {
   TVMContext ctx;
   ctx.device_type = static_cast<DLDeviceType>(device_type);
   ctx.device_id = device_id;
-  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast<size_t>(size));
+
+  TVMType type_hint;
+  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
+  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
+  type_hint.lanes = 1;
+
+  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
+                                                    static_cast<size_t>(size),
+                                                    type_hint);
 }
 
 int TVMBackendFreeWorkspace(int device_type,
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index 30c3bb7d5..7486f20a6 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -59,7 +59,7 @@ class CPUDeviceAPI final : public DeviceAPI {
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
   }
 
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
 
   static const std::shared_ptr<CPUDeviceAPI>& Global() {
@@ -74,7 +74,9 @@ struct CPUWorkspacePool : public WorkspacePool {
       WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
 };
 
-void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
+void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
+                                   size_t size,
+                                   TVMType type_hint) {
   return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
       ->AllocWorkspace(ctx, size);
 }
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 69b485a42..7885aa770 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -112,7 +112,7 @@ class CUDADeviceAPI final : public DeviceAPI {
         ->stream = static_cast<cudaStream_t>(stream);
   }
 
-  void* AllocWorkspace(TVMContext ctx, size_t size) final {
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
     return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
   }
 
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 7c2975fe7..fa73b8250 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -77,7 +77,7 @@ class MetalWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
   // get the global workspace
   static const std::shared_ptr<MetalWorkspace>& Global();
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 82c52a23e..6d376d314 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -228,7 +228,9 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
   [cb waitUntilCompleted];
 }
 
-void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
+void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
+                                     size_t size,
+                                     TVMType type_hint) {
   return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
 }
 
diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h
index 29e205ced..67934a078 100644
--- a/src/runtime/opencl/opencl_common.h
+++ b/src/runtime/opencl/opencl_common.h
@@ -156,7 +156,7 @@ class OpenCLWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
   // get the global workspace
   static const std::shared_ptr<OpenCLWorkspace>& Global();
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index 7518e72f9..a07fe15f8 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -108,7 +108,9 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
   OPENCL_CALL(clFinish(this->GetQueue(ctx)));
 }
 
-void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
+void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
+                                      size_t size,
+                                      TVMType type_hint) {
   return OpenCLThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
 }
 
diff --git a/src/runtime/opengl/opengl_common.h b/src/runtime/opengl/opengl_common.h
index 80b1d9f95..661c987e4 100644
--- a/src/runtime/opengl/opengl_common.h
+++ b/src/runtime/opengl/opengl_common.h
@@ -175,8 +175,6 @@ class OpenGLWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
-  void FreeWorkspace(TVMContext ctx, void* data) final;
 
   /*!
    * \brief Get the global OpenGL workspace.
diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc
index d90d12034..df2947db6 100644
--- a/src/runtime/opengl/opengl_device_api.cc
+++ b/src/runtime/opengl/opengl_device_api.cc
@@ -156,15 +156,6 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
 
 void OpenGLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {}
 
-void* OpenGLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
-  LOG(FATAL) << "Cannot allocate OpenGL workspace.";
-  return nullptr;
-}
-
-void OpenGLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
-  LOG(FATAL) << "Cannot free OpenGL workspace.";
-}
-
 OpenGLWorkspace::OpenGLWorkspace() {
   // Set an error handler.
   // This can be called before glfwInit().
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index 443d76b76..877907c7e 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -110,7 +110,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
         ->stream = static_cast<hipStream_t>(stream);
   }
 
-  void* AllocWorkspace(TVMContext ctx, size_t size) final {
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
     return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
   }
 
diff --git a/tests/webgl/test_local_multi_stage.py b/tests/webgl/test_local_multi_stage.py
new file mode 100644
index 000000000..47fa5c76c
--- /dev/null
+++ b/tests/webgl/test_local_multi_stage.py
@@ -0,0 +1,30 @@
+import tvm
+import numpy as np
+
+def test_local_multi_stage():
+    if not tvm.module.enabled("opengl"):
+        return
+    if not tvm.module.enabled("llvm"):
+        return
+
+    n = tvm.var("n")
+    A = tvm.placeholder((n,), name='A', dtype="int32")
+    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
+    C = tvm.compute((n,), lambda i: B[i] * 2, name="C")
+
+    s = tvm.create_schedule(C.op)
+    s[B].opengl()
+    s[C].opengl()
+
+    f = tvm.build(s, [A, C], "opengl", name="multi_stage")
+
+    ctx = tvm.opengl(0)
+    n = 10
+    a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
+    c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
+    f(a, c)
+
+    np.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2)
+
+if __name__ == "__main__":
+    test_local_multi_stage()
-- 
GitLab