diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 9b8995bf5516ef1d1ab93a4c52d12b9c6c2a7176..fa42cefa07a7fd6f049dfeed4db59aa8c7c0500c 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -44,7 +44,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
       if (info.alignment > 16) {
         info.alignment = 16;
       }
-      if (info.scope.rank == 2) {
+      if (info.scope.rank == runtime::StorageRank::kLocal) {
         // const int local_address_space = 5;
         // TODO(tqchen): for higher version of LLVM, local address space can be set.
         llvm::AllocaInst* alloca = builder_->CreateAlloca(
@@ -54,7 +54,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
         }
         buf = alloca;
       } else {
-        CHECK_EQ(info.scope.rank, 1)
+        CHECK(info.scope.rank == runtime::StorageRank::kShared)
             << "Can only allocate shared or local memory inside kernel";
         // Shared memory: address space  == 3
         const unsigned shared_address_space = 3;
diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc
index c0002873d5fc3a0447c0201fea6fc1697824cc96..d354e3b9eaf01da03637c1596c21edc0f59f5ed0 100644
--- a/src/codegen/llvm/codegen_nvptx.cc
+++ b/src/codegen/llvm/codegen_nvptx.cc
@@ -47,7 +47,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
       if (info.alignment > 16) {
         info.alignment = 16;
       }
-      if (info.scope.rank == 2) {
+      if (info.scope.rank == runtime::StorageRank::kLocal) {
         // const int local_address_space = 5;
         // TODO(tqchen): for higher version of LLVM, local address space can be set.
         llvm::AllocaInst* alloca = builder_->CreateAlloca(
@@ -57,7 +57,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
         }
         buf = alloca;
       } else {
-        CHECK_EQ(info.scope.rank, 1)
+        CHECK(info.scope.rank == runtime::StorageRank::kShared)
             << "Can only allocate shared or local memory inside kernel";
         // Shared memory: address space  == 3
         const unsigned shared_address_space = 3;
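Note on the two LLVM backends above (the SPIR-V backend below follows the same pattern): the allocation check now branches on the scoped enum instead of the magic integers 1 and 2. Below is a minimal standalone sketch of that dispatch, assuming only the enum values introduced in thread_storage_scope.h later in this patch; the helper name AddressSpaceForRank is hypothetical, and the address-space numbers just restate the comments in the hunks above (3 for shared, 5 for the still-commented-out local address space).

    // Illustrative sketch only; mirrors the allocation checks above.
    #include <stdexcept>

    enum class StorageRank { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };

    inline unsigned AddressSpaceForRank(StorageRank rank) {
      switch (rank) {
        case StorageRank::kShared: return 3;  // shared memory address space
        case StorageRank::kLocal:  return 5;  // local address space, still a TODO above
        default:
          throw std::invalid_argument(
              "can only allocate shared or local memory inside kernel");
      }
    }
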
diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc
index 4d7a9b21ba5b5971dea9afb1c3d717df539eb172..d844a7b1139095b0e2f01f55351f8eaec04d0772 100644
--- a/src/codegen/spirv/codegen_spirv.cc
+++ b/src/codegen/spirv/codegen_spirv.cc
@@ -561,13 +561,13 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
   spirv::Value buf;
   StorageInfo& info = storage_info_[op->buffer_var.get()];
   spirv::SType etype = builder_->GetSType(op->type);
-  if (info.scope.rank == 2) {
+  if (info.scope.rank == runtime::StorageRank::kLocal) {
     buf = builder_->Allocate(
         etype, static_cast<uint32_t>(constant_size),
         spv::StorageClassFunction);
   } else {
     // shared memory
-    CHECK_EQ(info.scope.rank, 1)
+    CHECK(info.scope.rank == runtime::StorageRank::kShared)
         << "Can only allocate shared or local memory inside kernel";
     // Shared memory
     buf = builder_->Allocate(
diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc
index 9211f3f71de07a47fa5d357964c098ba224da700..09be1a53da42bd1398f70147acf0830c906f2ff6 100644
--- a/src/pass/storage_access.cc
+++ b/src/pass/storage_access.cc
@@ -210,7 +210,8 @@ void StorageAccessVisitor::Visit_(const Call* op) {
 
 StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
   auto it = storage_scope_.find(buf);
-  StorageScope s; s.rank = 0;
+  StorageScope s;
+  s.rank = StorageRank::kGlobal;
   if (it == storage_scope_.end()) return s;
   return it->second;
 }
diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h
index 7268bb6683421b1d3b5130321e1ad19b19b2d374..4f313f8e7c24d6ed95edb40451ade06b7866f8a2 100644
--- a/src/pass/storage_access.h
+++ b/src/pass/storage_access.h
@@ -17,6 +17,7 @@ namespace tvm {
 namespace ir {
 
 using runtime::StorageScope;
+using runtime::StorageRank;
 /*!
  * \brief Base class of storage access analysis
  */
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 94332ff5cb7e0a27b86a2073090f5d0f655e7dca..f5cb98495ff9ce6b59b495bbc6986d418ce41197 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -23,6 +23,7 @@ namespace tvm {
 namespace ir {
 
 using HalideIR::Internal::Region;
+using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 using intrinsic::tvm_address_of;
@@ -141,7 +142,8 @@ class StorageFlattener : public IRMutator {
       const std::string& strkey = it->second;
       if (strkey.length() == 0) {
         if (curr_thread_scope_.size() != 0) {
-          skey.rank = curr_thread_scope_.back().rank + 1;
+          skey.rank = runtime::DefaultStorageRank(
+              curr_thread_scope_.back().rank);
         }
       } else {
         skey = StorageScope::make(strkey);
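Previously this hunk used `curr_thread_scope_.back().rank + 1`; with kWarp inserted at value 2, that arithmetic would map a threadIdx scope to warp rather than local, so the patch routes the conversion through runtime::DefaultStorageRank. A minimal sketch of the intended correspondence, assuming the usual ThreadScope ranks (0 for blockIdx, 1 for threadIdx, with -1 used when nothing is attached under a thread) and the relative include path used by the pass sources; the function name CheckDefaultStorageRank is only for illustration.

    // Illustration only: pins down the DefaultStorageRank mapping this hunk relies on.
    #include <cassert>
    #include "../runtime/thread_storage_scope.h"  // path assumed, as in the pass sources

    void CheckDefaultStorageRank() {
      using tvm::runtime::DefaultStorageRank;
      using tvm::runtime::StorageRank;
      assert(DefaultStorageRank(-1) == StorageRank::kGlobal);  // no thread scope at all
      assert(DefaultStorageRank(0)  == StorageRank::kShared);  // inside a blockIdx scope
      assert(DefaultStorageRank(1)  == StorageRank::kLocal);   // inside a threadIdx scope
    }
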
diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 0a8782366193c714543c0b3e1c92ff4661693891..998df034e5a1522422dc3def723557324a9da996 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -19,6 +19,7 @@
 namespace tvm {
 namespace ir {
 
+using runtime::StorageRank;
 using runtime::StorageScope;
 
 // Find a linear pattern of storage acess
@@ -794,7 +795,7 @@ class StoragePlanRewriter : public IRMutator {
     // disable reuse of small arrays, they will be lowered to registers in LLVM
     // This rules only apply if we are using non special memory
     if (scope.tag.length() == 0) {
-      if (scope.rank > 1 || op->type.is_handle()) {
+      if (scope.rank >= StorageRank::kWarp || op->type.is_handle()) {
         return NewAlloc(op, attach_scope, scope, const_nbits);
       }
       if (const_nbits > 0  &&  const_nbits <= 32) {
@@ -853,7 +854,8 @@ class StoragePlanRewriter : public IRMutator {
     // This rules only apply if we are using non special memory
     if (e->scope.tag.length() == 0) {
       // Disable sharing of local memory.
-      if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
+      if (e->scope.rank >= StorageRank::kWarp ||
+          e->allocs[0]->type.is_handle()) return;
       // disable reuse of small arrays
       if (e->const_nbits > 0 && e->const_nbits <= 32) return;
     }
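The comparisons above change from `rank > 1` to `rank >= StorageRank::kWarp`, which under the new numbering exempts both warp- and local-scoped (as well as handle-typed) allocations from reuse and sharing. A tiny standalone check of the ordering this relies on; the enum is re-declared here only to keep the snippet self-contained.

    // The reuse rules above depend on the declaration order of the enum:
    // kGlobal(0) < kShared(1) < kWarp(2) < kLocal(3).
    enum class StorageRank { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };

    static_assert(StorageRank::kShared < StorageRank::kWarp &&
                  StorageRank::kWarp < StorageRank::kLocal,
                  "rank >= kWarp selects warp or local scope");
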
diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc
index af3dc1f128e5e49c8fe704ce65ee091505b5d035..6e2d1020a6b539a54da0316dfae6ea86b40dc883 100644
--- a/src/pass/storage_sync.cc
+++ b/src/pass/storage_sync.cc
@@ -189,7 +189,7 @@ class ThreadSyncInserter : public IRMutator {
     if (syncs_.size() == 0) return stmt;
     if (syncs_.count(stmt.get())) {
       Stmt barrier;
-      if (sync_scope_.rank == 0) {
+      if (sync_scope_.rank == StorageRank::kGlobal) {
         barrier = MakeGlobalBarrier();
       } else {
         barrier = Evaluate::make(
@@ -206,15 +206,15 @@ class ThreadSyncInserter : public IRMutator {
     return stmt;
   }
   Expr Mutate_(const Load* op, const Expr& e) final {
-    if (sync_scope_.rank == 0 &&
-        GetScope(op->buffer_var.get()).rank == 0) {
+    if (sync_scope_.rank == StorageRank::kGlobal &&
+        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].read_count;
     }
     return IRMutator::Mutate_(op, e);
   }
   Stmt Mutate_(const Store* op, const Stmt& s) final {
-    if (sync_scope_.rank == 0 &&
-        GetScope(op->buffer_var.get()).rank == 0) {
+    if (sync_scope_.rank == StorageRank::kGlobal &&
+        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].write_count;
     }
     return IRMutator::Mutate_(op, s);
@@ -228,7 +228,7 @@ class ThreadSyncInserter : public IRMutator {
       thread_extents_.pop_back();
       std::swap(temp, in_thread_env_);
       // first thread scope.
-      if (!in_thread_env_ && sync_scope_.rank == 0) {
+      if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
         ret = InitGlobalBarrier(ret.as<AttrStmt>());
         num_blocks_ = Expr();
         is_lead_ = Expr();
@@ -253,7 +253,8 @@ class ThreadSyncInserter : public IRMutator {
   // Get current storage scope.
   StorageScope GetScope(const Variable* buf) const {
     auto it = storage_scope_.find(buf);
-    StorageScope s; s.rank = 0;
+    StorageScope s;
+    s.rank = StorageRank::kGlobal;
     if (it == storage_scope_.end()) return s;
     return it->second;
   }
@@ -279,7 +280,7 @@ class ThreadSyncInserter : public IRMutator {
     return Block::make(prep, body);
   }
   Stmt MakeGlobalBarrier() {
-    CHECK_EQ(sync_scope_.rank, 0);
+    CHECK(sync_scope_.rank == StorageRank::kGlobal);
     if (!num_blocks_.defined()) {
       CHECK(!is_lead_.defined());
       num_work_dim_ = thread_extents_.size();
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 48b5e8f1ef16f5fe9c8d194c258a256baf68e5b3..647bbb82ea345d704eb4bbb8bb8a0690d444fb80 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -13,10 +13,47 @@
 namespace tvm {
 namespace runtime {
 
+/*!
+ * \brief Memory hierarchy rank in the storage system.
+ * \note The global and shared ranks have a one-to-one
+ *       correspondence with the thread ranks.
+ */
+enum class StorageRank {
+  /*! \brief global memory */
+  kGlobal = 0,
+  /*! \brief shared memory among a thread group */
+  kShared = 1,
+  /*!
+   * \brief Reserved for warp memory.
+   *  This is only used by the programming model;
+   *  usually there is no such physical memory on the GPU.
+   *  Instead, it can be simulated with registers and shuffles.
+   */
+  kWarp = 2,
+  /*! \brief thread local memory */
+  kLocal = 3
+};
+
+/*!
+ * \param thread_scope_rank The thread scope rank
+ * \return default storage rank given the thread scope
+ */
+inline StorageRank DefaultStorageRank(int thread_scope_rank) {
+  switch (thread_scope_rank) {
+    case -1: return StorageRank::kGlobal;
+    case 0: return StorageRank::kShared;
+    case 1: return StorageRank::kLocal;
+    default: {
+      LOG(FATAL) << "unknown rank";
+      return StorageRank::kGlobal;
+    }
+  }
+}
+
 /*! \brief class to represent storage scope */
 struct StorageScope {
   /*! \brief The rank of the storage */
-  int rank{0};
+  StorageRank rank{StorageRank::kGlobal};
   /*! \brief tag for special purpose memory. */
   std::string tag;
   // comparator
@@ -29,9 +66,10 @@ struct StorageScope {
   inline std::string to_string() const {
     std::string ret;
     switch (rank) {
-      case 0: return "global" + tag;
-      case 1: return "shared" + tag;
-      case 2: return "local" + tag;
+      case StorageRank::kGlobal: return "global" + tag;
+      case StorageRank::kShared: return "shared" + tag;
+      case StorageRank::kWarp: return "warp" + tag;
+      case StorageRank::kLocal: return "local" + tag;
       default: LOG(FATAL) << "unknown storage scope"; return "";
     }
   }
@@ -43,13 +81,16 @@ struct StorageScope {
   static StorageScope make(const std::string& s) {
     StorageScope r;
     if (s.compare(0, 6, "global")  == 0) {
-      r.rank = 0;
+      r.rank = StorageRank::kGlobal;
       r.tag = s.substr(6, std::string::npos);
     } else if (s.compare(0, 6, "shared") == 0) {
-      r.rank = 1;
+      r.rank = StorageRank::kShared;
       r.tag = s.substr(6, std::string::npos);
+    } else if (s.compare(0, 4, "warp") == 0) {
+      r.rank = StorageRank::kWarp;
+      r.tag = s.substr(4, std::string::npos);
     } else if (s.compare(0, 5, "local") == 0) {
-      r.rank = 2;
+      r.rank = StorageRank::kLocal;
       r.tag = s.substr(5, std::string::npos);
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 1a06970e52e4ae67699d0af1762ba9a1fe5fc515..908b579ec9a41d7d54cbeca748ce50b5aa9a168f 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -16,8 +16,9 @@
 namespace tvm {
 namespace schedule {
 
-using runtime::ThreadScope;
+using runtime::StorageRank;
 using runtime::StorageScope;
+using runtime::ThreadScope;
 
 /*! \brief The graph context used during bound inference. */
 struct GraphContext {
@@ -41,7 +42,7 @@ bool NeedRelax(const IterVar& iv,
   if (tag.length() == 0 || tag == "pipeline") {
     return !found_attach;
   }
-  return scope.rank <= ThreadScope::make(tag).rank;
+  return static_cast<int>(scope.rank) <= ThreadScope::make(tag).rank;
 }
 
 // infer storage scope, if not given
@@ -50,16 +51,17 @@ StorageScope InferStorageScope(
   if (stage->scope.length() != 0) {
     return StorageScope::make(stage->scope);
   }
-  int max_rank = 0;
+  int max_rank = -1;
   for (IterVar iv : ctx.attach_path.at(stage->op)) {
     auto it = ctx.bind_map.find(iv);
     const std::string& tag = (
         it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
     if (tag != "pipeline" && tag.length() != 0) {
-      max_rank = std::max(max_rank, ThreadScope::make(tag).rank + 1);
+      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
     }
   }
-  StorageScope s; s.rank = max_rank;
+  StorageScope s;
+  s.rank = runtime::DefaultStorageRank(max_rank);
   return s;
 }
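InferStorageScope now seeds max_rank with -1 (nothing attached under a thread iter var) and keeps the raw thread rank, deferring the conversion to DefaultStorageRank instead of the old rank-plus-one arithmetic. A sketch of that decision in isolation, assuming the same thread-rank convention as above; InferScopeFromThreadRanks is a hypothetical free function, not part of the patch.

    #include <algorithm>
    #include <vector>
    #include "../runtime/thread_storage_scope.h"  // path assumed, as in the schedule sources

    // Mirrors the rank computation in InferStorageScope above, but takes the
    // already-extracted thread ranks instead of walking the attach path.
    tvm::runtime::StorageScope InferScopeFromThreadRanks(
        const std::vector<int>& thread_ranks) {
      int max_rank = -1;  // -1: not attached under any thread iter var
      for (int r : thread_ranks) max_rank = std::max(max_rank, r);
      tvm::runtime::StorageScope s;
      s.rank = tvm::runtime::DefaultStorageRank(max_rank);
      return s;
    }
    // e.g. {} -> "global", {0} (blockIdx) -> "shared", {0, 1} (threadIdx) -> "local"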