diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h
index e512921c969e8ec07ebc90ae849c7e027d0eeabe..079ab1efb040c3bd8a5650213a6da39c9369738c 100644
--- a/include/tvm/runtime/c_backend_api.h
+++ b/include/tvm/runtime/c_backend_api.h
@@ -44,14 +44,20 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
  *
  * \note The allocated space is ensured to be aligned to kTempAllocaAlignment.
  *
- * \param size The size of the space requested.
+ * \param nbytes The size of the space requested.
  * \param device_type The device type which the space will be allocated.
  * \param device_id The device id which the space will be allocated.
+ * \param dtype_code_hint The type code of the array elements. Only used in
+ * certain backends such as OpenGL.
+ * \param dtype_bits_hint The type bits of the array elements. Only used in
+ * certain backends such as OpenGL.
  * \return nullptr when error is thrown, a valid ptr if success
  */
 TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
                                        int device_id,
-                                       uint64_t size);
+                                       uint64_t nbytes,
+                                       int dtype_code_hint,
+                                       int dtype_bits_hint);
 
 /*!
  * \brief Backend function to free temporal workspace.
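
For orientation, here is a minimal sketch of how lowered host code would call the extended entry point; the device, size, and dtype values below are illustrative and not taken from this patch:

    // Allocate a temporary 1024-byte float32 workspace on CPU (kDLCPU == 1, kDLFloat == 2).
    void* ws = TVMBackendAllocWorkspace(/*device_type=*/kDLCPU, /*device_id=*/0,
                                        /*nbytes=*/1024,
                                        /*dtype_code_hint=*/kDLFloat,
                                        /*dtype_bits_hint=*/32);
    if (ws == nullptr) {
      // On failure the runtime records a message retrievable via TVMGetLastError().
    }
    // ... use the workspace ...
    TVMBackendFreeWorkspace(/*device_type=*/kDLCPU, /*device_id=*/0, ws);
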
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 9ba08fb868255c2baa87d9ca86a2a1a61e35e2ed..45009f1d3af3eecc4dc89a36b8295a66a40f6e32 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -114,9 +114,13 @@ class DeviceAPI {
    *  - Workspace should not overlap between different threads(i.e. be threadlocal)
    *
    * \param ctx The context of allocation.
-   * \param size The size to be allocated.
+   * \param nbytes The size to be allocated.
+   * \param type_hint The type of elements. Only needed by certain backends such
+   * as OpenGL, as nbytes is sufficient for most backends.
    */
-  TVM_DLL virtual void* AllocWorkspace(TVMContext ctx, size_t size);
+  TVM_DLL virtual void* AllocWorkspace(TVMContext ctx,
+                                       size_t nbytes,
+                                       TVMType type_hint = {});
   /*!
    * \brief Free temporal workspace in backend execution.
    *
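
Backends that need the element type (texture-style storage, for example) can now consume the hint in their override, while byte-addressed backends simply ignore it. A hedged sketch of such an override follows; TextureWorkspace and AllocTexture are hypothetical names, not part of this patch:

    void* TextureWorkspace::AllocWorkspace(TVMContext ctx,
                                           size_t nbytes,
                                           TVMType type_hint) {
      if (type_hint.bits == 0) {
        // No hint supplied: fall back to the plain byte-addressed default.
        return DeviceAPI::AllocWorkspace(ctx, nbytes, type_hint);
      }
      // Texture storage is element-addressed, so turn the byte count into an
      // element count before asking the driver for space.
      size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
      size_t num_elems = (nbytes + elem_bytes - 1) / elem_bytes;
      return AllocTexture(ctx, num_elems, type_hint);
    }
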
diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc
index 496b15b34bfa6f2a727a12dd3532c79e9a254c74..696082749a37d18744ec8d8afd31e55186f41a05 100644
--- a/src/codegen/codegen_opengl.cc
+++ b/src/codegen/codegen_opengl.cc
@@ -24,6 +24,8 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
   inputs_.clear();
   output_iter_var_ = nullptr;
   thread_extent_var_ = "";
+  this->decl_stream.str("");
+  this->stream.str("");
 }
 
 void CodeGenOpenGL::AddFunction(LoweredFunc f) {
diff --git a/src/codegen/stack_vm/codegen_stack_vm.cc b/src/codegen/stack_vm/codegen_stack_vm.cc
index 5b01dae7100ae9865b4d63f4375df57381f4de20..168e411fa6e227476d7436031387a84240fb6169 100644
--- a/src/codegen/stack_vm/codegen_stack_vm.cc
+++ b/src/codegen/stack_vm/codegen_stack_vm.cc
@@ -197,10 +197,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
     vm_.stack_size += size;
     this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
   } else if (op->name == "TVMBackendAllocWorkspace") {
-    CHECK_EQ(op->args.size(), 3U);
+    CHECK_EQ(op->args.size(), 5U);
     this->Push(op->args[0]);
     this->Push(op->args[1]);
     this->Push(op->args[2]);
+    this->Push(op->args[3]);
+    this->Push(op->args[4]);
     this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
   } else if (op->name == "TVMBackendFreeWorkspace") {
     CHECK_EQ(op->args.size(), 3U);
diff --git a/src/codegen/stack_vm/stack_vm.cc b/src/codegen/stack_vm/stack_vm.cc
index a133c9797b1bf643f8dfae79344cb81d806de032..95feeae3679e927bbb4b3c9fdabf511e6b954aa1 100644
--- a/src/codegen/stack_vm/stack_vm.cc
+++ b/src/codegen/stack_vm/stack_vm.cc
@@ -455,12 +455,15 @@ void StackVM::Run(State* s) const {
         break;
       }
       case TVM_DEVICE_ALLOCA: {
-        int device_type = static_cast<int>(stack[sp - 2].v_int64);
-        int device_id = static_cast<int>(stack[sp - 1].v_int64);
-        size_t nbytes = static_cast<size_t>(stack[sp].v_int64);
-        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes);
-        stack[sp - 2].v_handle = ptr;
-        sp = sp - 2;
+        int device_type = static_cast<int>(stack[sp - 4].v_int64);
+        int device_id = static_cast<int>(stack[sp - 3].v_int64);
+        size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
+        int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
+        int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
+        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes,
+                                             dtype_code_hint, dtype_bits_hint);
+        stack[sp - 4].v_handle = ptr;
+        sp = sp - 4;
         pc = pc + 1;
         break;
       }
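
For readability, this is the operand layout the interpreter now expects around TVM_DEVICE_ALLOCA (values illustrative):

    // Stack on entry (sp points at the most recently pushed slot):
    //   stack[sp - 4]  device_type       e.g. 1  (kDLCPU)
    //   stack[sp - 3]  device_id         e.g. 0
    //   stack[sp - 2]  nbytes            e.g. 4096
    //   stack[sp - 1]  dtype_code_hint   e.g. 2  (kDLFloat)
    //   stack[sp]      dtype_bits_hint   e.g. 32
    // On exit the five operands collapse into one result:
    //   stack[sp - 4]  workspace handle, and sp is decremented by 4.
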
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
index 105d58b95829e5d7bc4c87aa8c09e0bedd503365..a63fef07bd12f1eebec60816b802d0f0c7f3a742 100644
--- a/src/pass/lower_tvm_builtin.cc
+++ b/src/pass/lower_tvm_builtin.cc
@@ -96,14 +96,18 @@ class BuiltinLower : public IRMutator {
                                     {op->buffer_var}, Call::PureIntrinsic),
                          throw_last_error),
         op->body);
-    Stmt alloca = LetStmt::make(op->buffer_var,
-                                Call::make(op->buffer_var.type(),
-                                           "TVMBackendAllocWorkspace",
-                                           {cast(Int(32), device_type_),
-                                                 cast(Int(32), device_id_),
-                                                 cast(UInt(64), total_bytes)},
-                                           Call::Extern),
-                                body);
+
+    Stmt alloca = LetStmt::make(
+        op->buffer_var,
+        Call::make(op->buffer_var.type(),
+                   "TVMBackendAllocWorkspace",
+                   {cast(Int(32), device_type_),
+                    cast(Int(32), device_id_),
+                    cast(UInt(64), total_bytes),
+                    IntImm::make(Int(32), op->type.code()),
+                    IntImm::make(Int(32), op->type.bits())},
+                   Call::Extern),
+        body);
 
     Expr free_op = Call::make(Int(32),
                               "TVMBackendFreeWorkspace",
diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc
index 942e70339488255ca211e4f91ff8cc0b3dccebbf..dc326f3cb2f17c48065a1b5f714bd6d0fa17f339 100644
--- a/src/pass/split_host_device.cc
+++ b/src/pass/split_host_device.cc
@@ -146,6 +146,11 @@ class IRUseDefAnalysis : public IRMutator {
 
 class HostDeviceSplitter : public IRMutator {
  public:
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    handle_data_type_[op->buffer_var.get()] = make_const(op->type, 0);
+    return IRMutator::Mutate_(op, s);
+  }
+
   Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::pipeline_exec_scope) {
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 0d0e36f239f28285a637d9fa22facf052599d5e6..2177fc344889bf9c1bbe023ad6890a4d948f16bd 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -95,8 +95,9 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
       static_cast<int>(ctx.device_type), allow_missing);
 }
 
-void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
-  TVMType type_hint{kDLUInt, 8, 1};
+void* DeviceAPI::AllocWorkspace(TVMContext ctx,
+                                size_t size,
+                                TVMType type_hint) {
   return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
 }
 
@@ -220,12 +221,22 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
 }
 
 void* TVMBackendAllocWorkspace(int device_type,
-                             int device_id,
-                             uint64_t size) {
+                               int device_id,
+                               uint64_t size,
+                               int dtype_code_hint,
+                               int dtype_bits_hint) {
   TVMContext ctx;
   ctx.device_type = static_cast<DLDeviceType>(device_type);
   ctx.device_id = device_id;
-  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast<size_t>(size));
+
+  TVMType type_hint;
+  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
+  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
+  type_hint.lanes = 1;
+
+  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
+                                                    static_cast<size_t>(size),
+                                                    type_hint);
 }
 
 int TVMBackendFreeWorkspace(int device_type,
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index 30c3bb7d52df1d17a29428404da67e6ff88dcb87..7486f20a6ae16bf76571faa2e5218f6b546351d6 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -59,7 +59,7 @@ class CPUDeviceAPI final : public DeviceAPI {
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
   }
 
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
 
   static const std::shared_ptr<CPUDeviceAPI>& Global() {
@@ -74,7 +74,9 @@ struct CPUWorkspacePool : public WorkspacePool {
       WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
 };
 
-void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
+void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
+                                   size_t size,
+                                   TVMType type_hint) {
   return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
       ->AllocWorkspace(ctx, size);
 }
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 69b485a423c0a97e3cbb03991bd7ae9967aef84d..7885aa7705ed318da76de3081fbe458a798e0339 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -112,7 +112,7 @@ class CUDADeviceAPI final : public DeviceAPI {
         ->stream = static_cast<cudaStream_t>(stream);
   }
 
-  void* AllocWorkspace(TVMContext ctx, size_t size) final {
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
     return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
   }
 
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 7c2975fe7ccc0de6d38247790e33f2abd173e01c..fa73b8250c339d27a84a64c8d23ac5a85c632d32 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -77,7 +77,7 @@ class MetalWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
   // get the global workspace
   static const std::shared_ptr<MetalWorkspace>& Global();
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 82c52a23e03652fe314a2307d4e8f9944faf3b80..6d376d3144ac2aeb7f8b096acc98990811cb71a3 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -228,7 +228,9 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
   [cb waitUntilCompleted];
 }
 
-void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
+void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
+                                     size_t size,
+                                     TVMType type_hint) {
   return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
 }
 
diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h
index 29e205ced4d7bf61ba92816caa6bdef238fe4924..67934a078665756547dc12b88264183f03a79ef8 100644
--- a/src/runtime/opencl/opencl_common.h
+++ b/src/runtime/opencl/opencl_common.h
@@ -156,7 +156,7 @@ class OpenCLWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
   // get the global workspace
   static const std::shared_ptr<OpenCLWorkspace>& Global();
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index 7518e72f9d9b611067f1dbddf3bd8a5c0b2ef3d1..a07fe15f805fd92a6a69661d9ea83a0bdadf9da7 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -108,7 +108,9 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
   OPENCL_CALL(clFinish(this->GetQueue(ctx)));
 }
 
-void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
+void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
+                                      size_t size,
+                                      TVMType type_hint) {
   return OpenCLThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
 }
 
diff --git a/src/runtime/opengl/opengl_common.h b/src/runtime/opengl/opengl_common.h
index 80b1d9f95c8e0fa24dfda25386b527a552085015..661c987e4b3c509de9712980d685cf59f3b5ffaa 100644
--- a/src/runtime/opengl/opengl_common.h
+++ b/src/runtime/opengl/opengl_common.h
@@ -175,8 +175,6 @@ class OpenGLWorkspace final : public DeviceAPI {
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final;
   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size) final;
-  void FreeWorkspace(TVMContext ctx, void* data) final;
 
   /*!
    * \brief Get the global OpenGL workspace.
diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc
index d90d12034ae6731468b68df6dc1aa6d2b8197747..df2947db625569afa556f1f1ad7580612f555f73 100644
--- a/src/runtime/opengl/opengl_device_api.cc
+++ b/src/runtime/opengl/opengl_device_api.cc
@@ -156,15 +156,6 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
 
 void OpenGLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {}
 
-void* OpenGLWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
-  LOG(FATAL) << "Cannot allocate OpenGL workspace.";
-  return nullptr;
-}
-
-void OpenGLWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
-  LOG(FATAL) << "Cannot free OpenGL workspace.";
-}
-
 OpenGLWorkspace::OpenGLWorkspace() {
   // Set an error handler.
   // This can be called before glfwInit().
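
With the two LOG(FATAL) overrides removed, an OpenGL workspace request follows the default DeviceAPI path, which after this patch carries the hint through to AllocDataSpace. Sketched call chain, not literal code:

    // TVMBackendAllocWorkspace(device_type, device_id, nbytes, code, bits)
    //   -> DeviceAPI::AllocWorkspace(ctx, nbytes, type_hint)          // base-class default
    //     -> AllocDataSpace(ctx, nbytes, kTempAllocaAlignment, type_hint)
    // so the texture-backed OpenGL allocator finally sees the element type it needs.
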
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index 443d76b76eb665f3b4f1fbc92a70e306a6aaf282..877907c7e0924c80150c0a2bfc7f21f4797a4e49 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -110,7 +110,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
         ->stream = static_cast<hipStream_t>(stream);
   }
 
-  void* AllocWorkspace(TVMContext ctx, size_t size) final {
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
     return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
   }
 
diff --git a/tests/webgl/test_local_multi_stage.py b/tests/webgl/test_local_multi_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..47fa5c76c7aa51b9cdc2fd35375b6c4275c5421e
--- /dev/null
+++ b/tests/webgl/test_local_multi_stage.py
@@ -0,0 +1,30 @@
+import tvm
+import numpy as np
+
+def test_local_multi_stage():
+    if not tvm.module.enabled("opengl"):
+        return
+    if not tvm.module.enabled("llvm"):
+        return
+
+    n = tvm.var("n")
+    A = tvm.placeholder((n,), name='A', dtype="int32")
+    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
+    C = tvm.compute((n,), lambda i: B[i] * 2, name="C")
+
+    s = tvm.create_schedule(C.op)
+    s[B].opengl()
+    s[C].opengl()
+
+    f = tvm.build(s, [A, C], "opengl", name="multi_stage")
+
+    ctx = tvm.opengl(0)
+    n = 10
+    a = tvm.nd.array(np.random.randint(low=0, high=100, size=(n,)).astype(A.dtype), ctx)
+    c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), ctx)
+    f(a, c)
+
+    np.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2)
+
+if __name__ == "__main__":
+    test_local_multi_stage()