diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 9b8995bf5516ef1d1ab93a4c52d12b9c6c2a7176..fa42cefa07a7fd6f049dfeed4db59aa8c7c0500c 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -44,7 +44,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
       if (info.alignment > 16) {
         info.alignment = 16;
       }
-      if (info.scope.rank == 2) {
+      if (info.scope.rank == runtime::StorageRank::kLocal) {
         // const int local_address_space = 5;
         // TODO(tqchen): for higher version of LLVM, local address space can be set.
         llvm::AllocaInst* alloca = builder_->CreateAlloca(
@@ -54,7 +54,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
         }
         buf = alloca;
       } else {
-        CHECK_EQ(info.scope.rank, 1)
+        CHECK(info.scope.rank == runtime::StorageRank::kShared)
             << "Can only allocate shared or local memory inside kernel";
         // Shared memory: address space == 3
         const unsigned shared_address_space = 3;
diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc
index c0002873d5fc3a0447c0201fea6fc1697824cc96..d354e3b9eaf01da03637c1596c21edc0f59f5ed0 100644
--- a/src/codegen/llvm/codegen_nvptx.cc
+++ b/src/codegen/llvm/codegen_nvptx.cc
@@ -47,7 +47,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
       if (info.alignment > 16) {
         info.alignment = 16;
       }
-      if (info.scope.rank == 2) {
+      if (info.scope.rank == runtime::StorageRank::kLocal) {
         // const int local_address_space = 5;
         // TODO(tqchen): for higher version of LLVM, local address space can be set.
         llvm::AllocaInst* alloca = builder_->CreateAlloca(
@@ -57,7 +57,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
         }
         buf = alloca;
       } else {
-        CHECK_EQ(info.scope.rank, 1)
+        CHECK(info.scope.rank == runtime::StorageRank::kShared)
             << "Can only allocate shared or local memory inside kernel";
         // Shared memory: address space == 3
         const unsigned shared_address_space = 3;
diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc
index 4d7a9b21ba5b5971dea9afb1c3d717df539eb172..d844a7b1139095b0e2f01f55351f8eaec04d0772 100644
--- a/src/codegen/spirv/codegen_spirv.cc
+++ b/src/codegen/spirv/codegen_spirv.cc
@@ -561,13 +561,13 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
   spirv::Value buf;
   StorageInfo& info = storage_info_[op->buffer_var.get()];
   spirv::SType etype = builder_->GetSType(op->type);
-  if (info.scope.rank == 2) {
+  if (info.scope.rank == runtime::StorageRank::kLocal) {
     buf = builder_->Allocate(
         etype, static_cast<uint32_t>(constant_size),
         spv::StorageClassFunction);
   } else {
     // shared memory
-    CHECK_EQ(info.scope.rank, 1)
+    CHECK(info.scope.rank == runtime::StorageRank::kShared)
         << "Can only allocate shared or local memory inside kernel";
     // Shared memory
     buf = builder_->Allocate(
diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc
index 9211f3f71de07a47fa5d357964c098ba224da700..09be1a53da42bd1398f70147acf0830c906f2ff6 100644
--- a/src/pass/storage_access.cc
+++ b/src/pass/storage_access.cc
@@ -210,7 +210,8 @@ void StorageAccessVisitor::Visit_(const Call* op) {
 
 StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
   auto it = storage_scope_.find(buf);
-  StorageScope s; s.rank = 0;
+  StorageScope s;
+  s.rank = StorageRank::kGlobal;
   if (it == storage_scope_.end()) return s;
   return it->second;
 }
diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h
index 7268bb6683421b1d3b5130321e1ad19b19b2d374..4f313f8e7c24d6ed95edb40451ade06b7866f8a2 100644
--- a/src/pass/storage_access.h
+++ b/src/pass/storage_access.h
@@ -17,6 +17,7 @@
 namespace tvm {
 namespace ir {
 using runtime::StorageScope;
+using runtime::StorageRank;
 /*!
  * \brief Base class of storage access analysis
  */
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 94332ff5cb7e0a27b86a2073090f5d0f655e7dca..f5cb98495ff9ce6b59b495bbc6986d418ce41197 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -23,6 +23,7 @@
 namespace tvm {
 namespace ir {
 using HalideIR::Internal::Region;
+using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
 using intrinsic::tvm_address_of;
@@ -141,7 +142,8 @@ class StorageFlattener : public IRMutator {
       const std::string& strkey = it->second;
       if (strkey.length() == 0) {
         if (curr_thread_scope_.size() != 0) {
-          skey.rank = curr_thread_scope_.back().rank + 1;
+          skey.rank = runtime::DefaultStorageRank(
+              curr_thread_scope_.back().rank);
         }
       } else {
         skey = StorageScope::make(strkey);
diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 0a8782366193c714543c0b3e1c92ff4661693891..998df034e5a1522422dc3def723557324a9da996 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -19,6 +19,7 @@
 namespace tvm {
 namespace ir {
 
+using runtime::StorageRank;
 using runtime::StorageScope;
 
 // Find a linear pattern of storage acess
@@ -794,7 +795,7 @@ class StoragePlanRewriter : public IRMutator {
       // disable reuse of small arrays, they will be lowered to registers in LLVM
       // This rules only apply if we are using non special memory
       if (scope.tag.length() == 0) {
-        if (scope.rank > 1 || op->type.is_handle()) {
+        if (scope.rank >= StorageRank::kWarp || op->type.is_handle()) {
           return NewAlloc(op, attach_scope, scope, const_nbits);
         }
         if (const_nbits > 0 && const_nbits <= 32) {
@@ -853,7 +854,8 @@ class StoragePlanRewriter : public IRMutator {
       // This rules only apply if we are using non special memory
       if (e->scope.tag.length() == 0) {
         // Disable sharing of local memory.
-        if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
+        if (e->scope.rank >= StorageRank::kWarp ||
+            e->allocs[0]->type.is_handle()) return;
         // disable reuse of small arrays
         if (e->const_nbits > 0 && e->const_nbits <= 32) return;
       }
diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc
index af3dc1f128e5e49c8fe704ce65ee091505b5d035..6e2d1020a6b539a54da0316dfae6ea86b40dc883 100644
--- a/src/pass/storage_sync.cc
+++ b/src/pass/storage_sync.cc
@@ -189,7 +189,7 @@ class ThreadSyncInserter : public IRMutator {
     if (syncs_.size() == 0) return stmt;
     if (syncs_.count(stmt.get())) {
       Stmt barrier;
-      if (sync_scope_.rank == 0) {
+      if (sync_scope_.rank == StorageRank::kGlobal) {
         barrier = MakeGlobalBarrier();
       } else {
         barrier = Evaluate::make(
@@ -206,15 +206,15 @@ class ThreadSyncInserter : public IRMutator {
     return stmt;
   }
   Expr Mutate_(const Load* op, const Expr& e) final {
-    if (sync_scope_.rank == 0 &&
-        GetScope(op->buffer_var.get()).rank == 0) {
+    if (sync_scope_.rank == StorageRank::kGlobal &&
+        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].read_count;
     }
     return IRMutator::Mutate_(op, e);
   }
   Stmt Mutate_(const Store* op, const Stmt& s) final {
-    if (sync_scope_.rank == 0 &&
-        GetScope(op->buffer_var.get()).rank == 0) {
+    if (sync_scope_.rank == StorageRank::kGlobal &&
+        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].write_count;
     }
     return IRMutator::Mutate_(op, s);
@@ -228,7 +228,7 @@ class ThreadSyncInserter : public IRMutator {
     thread_extents_.pop_back();
     std::swap(temp, in_thread_env_);
     // first thread scope.
-    if (!in_thread_env_ && sync_scope_.rank == 0) {
+    if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
       ret = InitGlobalBarrier(ret.as<AttrStmt>());
       num_blocks_ = Expr();
       is_lead_ = Expr();
@@ -253,7 +253,8 @@ class ThreadSyncInserter : public IRMutator {
   // Get current storage scope.
   StorageScope GetScope(const Variable* buf) const {
     auto it = storage_scope_.find(buf);
-    StorageScope s; s.rank = 0;
+    StorageScope s;
+    s.rank = StorageRank::kGlobal;
     if (it == storage_scope_.end()) return s;
     return it->second;
   }
@@ -279,7 +280,7 @@ class ThreadSyncInserter : public IRMutator {
     return Block::make(prep, body);
   }
   Stmt MakeGlobalBarrier() {
-    CHECK_EQ(sync_scope_.rank, 0);
+    CHECK(sync_scope_.rank == StorageRank::kGlobal);
     if (!num_blocks_.defined()) {
       CHECK(!is_lead_.defined());
       num_work_dim_ = thread_extents_.size();
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 48b5e8f1ef16f5fe9c8d194c258a256baf68e5b3..647bbb82ea345d704eb4bbb8bb8a0690d444fb80 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -13,10 +13,47 @@
 namespace tvm {
 namespace runtime {
 
+/*!
+ * \brief Memory hierachy rank in the storage system
+ * \note The global rank and shared rank have one to one
+ *       correspondence to the thread rank.
+ */
+enum class StorageRank {
+  /*! \brief global memory */
+  kGlobal = 0,
+  /*! \brief shared memory among thread group */
+  kShared = 1,
+  /*!
+   * \brief reserved for warp memory.
+   *  This is only used by programming model.
+   *  There is no such memory usually in GPU.
+   *  Instead, we can simulate it by registers and shuffle.
+   */
+  kWarp = 2,
+  /*! \brief thread local memory */
+  kLocal = 3
+};
+
+/*!
+ * \param thread_scope_rank The thread scope rank
+ * \return default storage rank given the thread scope
+ */
+inline StorageRank DefaultStorageRank(int thread_scope_rank) {
+  switch (thread_scope_rank) {
+    case -1: return StorageRank::kGlobal;
+    case 0: return StorageRank::kShared;
+    case 1: return StorageRank::kLocal;
+    default: {
+      LOG(FATAL) << "unknown rank";
+      return StorageRank::kGlobal;
+    }
+  }
+}
+
 /*! \brief class to represent storage scope */
 struct StorageScope {
   /*! \brief The rank of the storage */
-  int rank{0};
+  StorageRank rank{StorageRank::kGlobal};
   /*! \brief tag for special purpose memory. */
   std::string tag;
   // comparator
@@ -29,9 +66,10 @@ struct StorageScope {
   inline std::string to_string() const {
     std::string ret;
     switch (rank) {
-      case 0: return "global" + tag;
-      case 1: return "shared" + tag;
-      case 2: return "local" + tag;
+      case StorageRank::kGlobal: return "global" + tag;
+      case StorageRank::kShared: return "shared" + tag;
+      case StorageRank::kWarp: return "warp" + tag;
+      case StorageRank::kLocal: return "local" + tag;
       default: LOG(FATAL) << "unknown storage scope"; return "";
     }
   }
@@ -43,13 +81,16 @@ struct StorageScope {
   static StorageScope make(const std::string& s) {
     StorageScope r;
     if (s.compare(0, 6, "global") == 0) {
-      r.rank = 0;
+      r.rank = StorageRank::kGlobal;
       r.tag = s.substr(6, std::string::npos);
     } else if (s.compare(0, 6, "shared") == 0) {
-      r.rank = 1;
+      r.rank = StorageRank::kShared;
       r.tag = s.substr(6, std::string::npos);
+    } else if (s.compare(0, 4, "warp") == 0) {
+      r.rank = StorageRank::kWarp;
+      r.tag = s.substr(4, std::string::npos);
     } else if (s.compare(0, 5, "local") == 0) {
-      r.rank = 2;
+      r.rank = StorageRank::kLocal;
       r.tag = s.substr(5, std::string::npos);
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 1a06970e52e4ae67699d0af1762ba9a1fe5fc515..908b579ec9a41d7d54cbeca748ce50b5aa9a168f 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -16,8 +16,9 @@
 namespace tvm {
 namespace schedule {
 
-using runtime::ThreadScope;
+using runtime::StorageRank;
 using runtime::StorageScope;
+using runtime::ThreadScope;
 
 /*! \brief The graph context used during bound inference. */
 struct GraphContext {
@@ -41,7 +42,7 @@ bool NeedRelax(const IterVar& iv,
   if (tag.length() == 0 || tag == "pipeline") {
     return !found_attach;
   }
-  return scope.rank <= ThreadScope::make(tag).rank;
+  return static_cast<int>(scope.rank) <= ThreadScope::make(tag).rank;
 }
 
 // infer storage scope, if not given
@@ -50,16 +51,17 @@ StorageScope InferStorageScope(
   if (stage->scope.length() != 0) {
     return StorageScope::make(stage->scope);
   }
-  int max_rank = 0;
+  int max_rank = -1;
   for (IterVar iv : ctx.attach_path.at(stage->op)) {
     auto it = ctx.bind_map.find(iv);
     const std::string& tag = (
         it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
     if (tag != "pipeline" && tag.length() != 0) {
-      max_rank = std::max(max_rank, ThreadScope::make(tag).rank + 1);
+      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
     }
   }
-  StorageScope s; s.rank = max_rank;
+  StorageScope s;
+  s.rank = runtime::DefaultStorageRank(max_rank);
   return s;
 }
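The patch above replaces raw integer ranks with the scoped StorageRank enum, adds the "warp" scope string, and introduces the DefaultStorageRank helper. As a quick illustration (not part of the patch), the sketch below exercises the new API the way the touched passes use it. It assumes compilation inside the TVM source tree so that src/runtime/thread_storage_scope.h and dmlc logging are available; the ".mytag" suffix is a made-up example tag.

// Illustrative sketch only, not part of this change.
#include <cassert>
#include <string>
#include "runtime/thread_storage_scope.h"  // adjust the include path to your build setup

int main() {
  using tvm::runtime::DefaultStorageRank;
  using tvm::runtime::StorageRank;
  using tvm::runtime::StorageScope;

  // "warp" now parses to its own rank, ordered between shared and local.
  StorageScope warp = StorageScope::make("warp");
  assert(warp.rank == StorageRank::kWarp);
  assert(warp.to_string() == "warp");

  // Special-memory tags are still carried as a suffix (".mytag" is hypothetical).
  StorageScope tagged = StorageScope::make("shared.mytag");
  assert(tagged.rank == StorageRank::kShared);
  assert(tagged.tag == ".mytag");

  // Thread-scope rank to default storage rank, as used by storage_flatten.cc
  // and bound.cc: -1 (no thread scope) -> global, 0 -> shared, 1 -> local.
  assert(DefaultStorageRank(-1) == StorageRank::kGlobal);
  assert(DefaultStorageRank(0) == StorageRank::kShared);
  assert(DefaultStorageRank(1) == StorageRank::kLocal);
  return 0;
}

Because StorageRank is a scoped enum it no longer converts implicitly to int, which is why NeedRelax in bound.cc now compares via static_cast<int>(scope.rank) against ThreadScope::make(tag).rank.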