diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 5c22fe27bb2cf71952412096789ca6c362358333..d6a258053e11b299361b4c7d84e87241d38f529a 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -49,6 +49,30 @@ struct Reduce : public ExprNode<Reduce> {
   static constexpr const char* Min = "Min";
 };
 
+/*! \brief namespace of possible attributes in AttrStmt.type_key */
+namespace attr {
+/*!
+ * \brief Mark scope of iteration variable, used by Schedule.
+ */
+constexpr const char* scope = "scope";
+/*!
+ * \brief Mark launching extent of thread, used by device API.
+ */
+constexpr const char* thread_extent = "thread_extent";
+/*!
+ * \brief Mark launching of a virtual thread.
+ */
+constexpr const char* virtual_thread = "virtual_thread";
+/*!
+ * \brief Mark storage scope of buffers.
+ */
+constexpr const char* storage_scope = "storage_scope";
+/*!
+ * \brief Mark storage scope of realizations.
+ */
+constexpr const char* realize_scope = "realize_scope";
+}  // namespace attr
+
 /*! \brief namespace of TVM Intrinsic functions */
 namespace intrinsic {
 // Most of the intrinsics is to enab
diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h
index 649ea0239ef148182a21e8947ecbcbf5727a3a06..eea6a3343f37bba5061d9f7f0aba0ea6cfefb45a 100644
--- a/include/tvm/ir_mutator.h
+++ b/include/tvm/ir_mutator.h
@@ -63,6 +63,7 @@ class IRMutator {
   virtual Stmt Mutate_(const Store* op, const Stmt& s);
   virtual Stmt Mutate_(const Free* op, const Stmt& s);
   virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
+  virtual Stmt Mutate_(const Block* op, const Stmt& s);
   virtual Expr Mutate_(const Call* op, const Expr& e);
   virtual Expr Mutate_(const Load* op, const Expr& s);
   virtual Expr Mutate_(const Variable* op, const Expr& e);
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 4ce90e3b7739ae722e8c8021dca1cfa89b6763be..ce67c9cd3a94c292f21754cabcba10192a6669ef 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -100,6 +100,7 @@ Stmt Inline(Stmt stmt,
  * \param stmt The stmt to be trasnformed.
  * \param extern_buffer Map specifies external
  *  buffer assignment of input and outputs.
+ * \return Transformed stmt.
  */
 Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer);
@@ -108,15 +109,34 @@ Stmt StorageFlatten(Stmt stmt,
  * \brief unroll the constant loops
  * \param stmt The statment to be unrolled.
  * \param max_auto_step The maximum step to stop performing automatic unrolling.
+ * \return Transformed stmt.
  */
 Stmt UnrollLoop(Stmt stmt, int max_auto_step);
 
 /*!
  * \brief vectorize the constant loops
  * \param stmt The statment to be vectorized.
+ * \return Transformed stmt.
  */
 Stmt VectorizeLoop(Stmt stmt);
 
+/*!
+ * \brief Inject virtual thread loops into stmt.
+ * \param stmt The statement to be transformed.
+ * \return Transformed stmt.
+ */
+Stmt InjectVirtualThread(Stmt stmt);
+
+/*!
+ * \brief Lift storage allocation to the relevant outermost location.
+ *
+ *  Only do this after vectorization and virtual thread injection complete.
+ *
+ * \param stmt The stmt to be transformed.
+ * \return Transformed stmt.
+ */
+Stmt LiftAllocate(Stmt stmt);
+
 /*!
  * \brief Make an user callable API LoweredFunc.
  *
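The attr:: namespace above centralizes the AttrStmt type keys that were previously written as raw string literals in each pass. A standalone C++ sketch (illustrative only, not TVM code) of the pattern and what it buys:

    // Comparing against a shared constexpr key instead of retyping the
    // literal: a misspelled constant name fails to compile, while a
    // misspelled literal would just make the branch silently dead.
    #include <iostream>
    #include <string>

    namespace attr {
    constexpr const char* thread_extent = "thread_extent";
    constexpr const char* virtual_thread = "virtual_thread";
    }  // namespace attr

    // Stand-in for a pass inspecting AttrStmt::type_key.
    void VisitAttr(const std::string& type_key) {
      if (type_key == attr::thread_extent ||
          type_key == attr::virtual_thread) {
        std::cout << "thread-like scope: " << type_key << "\n";
      }
    }

    int main() {
      VisitAttr("virtual_thread");  // matches attr::virtual_thread
      return 0;
    }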
diff --git a/python/tvm/build.py b/python/tvm/build.py
index 7ddabaf7631d0c8e4539428374b994cb4fdde958..4704efe76face8266e8d86e4ae9d0b6820f76210 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -70,6 +70,8 @@ def build(sch,
     stmt = ir_pass.StorageFlatten(stmt, binds)
     stmt = ir_pass.CanonicalSimplify(stmt)
     stmt = ir_pass.VectorizeLoop(stmt)
+    stmt = ir_pass.InjectVirtualThread(stmt)
+    stmt = ir_pass.LiftAllocate(stmt)
     stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
     stmt = ir_pass.Simplify(stmt)
     fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index b8f3cbc3bd9e643641706dda73cde271e08adc52..1192dc25dd76c52a9171c9c703b0896dc54bfda0 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -67,6 +67,8 @@ REGISTER_PASS2(UnrollLoop);
 REGISTER_PASS2(StorageSync);
 REGISTER_PASS4(MakeAPI);
 REGISTER_PASS1(SplitHostDevice);
+REGISTER_PASS1(LiftAllocate);
+REGISTER_PASS1(InjectVirtualThread);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 1bc28374418fb1ef603bafda4193b3c4fc6dfcb1..8ae8ed47e0d511f2309a82cf6ffbb6df15b2ff66 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -288,7 +288,8 @@ class Canonical::Internal : public IRMutator {
   }
   // AttrStmt
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
-    if (op->type_key == "thread_extent") {
+    if (op->type_key == attr::thread_extent ||
+        op->type_key == attr::virtual_thread) {
       ++level_counter_;
       IterVar iv(op->node.node_);
       CHECK_NE(iv->thread_tag.length(), 0U);
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index 52acf8ead92576b152c2afdf4ae44de7a2420118..c89faeb6c209202e237bf8aa94e333c85eaef950 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -743,7 +743,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
 }
 
 void CodeGenC::PrintStmt(const AttrStmt* op) {
-  if (op->type_key == "scope") {
+  if (op->type_key == ir::attr::thread_extent) {
     IterVar iv(op->node.node_);
     if (iv->thread_tag.length() != 0) {
       if (!var_idmap_.count(iv->var.get())) {
@@ -756,7 +756,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
         stream << ";\n";
       }
     }
-  } else if (op->type_key == "storage_scope") {
+  } else if (op->type_key == ir::attr::storage_scope) {
     const Variable* v = op->node.as<Variable>();
     CHECK(v);
     alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index deee2aa2fa3632a2385cd2e355478ee85c26d204..c4c5d99f35ad3e52d9fcbb576c2a62c62438432a 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -9,6 +9,7 @@
 #include <string>
 #include "./codegen_cuda.h"
 #include "./codegen_stack_vm.h"
+#include "../arithmetic/compute_expr.h"
 #include "../runtime/cuda/cuda_common.h"
 #include "../runtime/cuda/cuda_module.h"
 
@@ -22,6 +23,17 @@ std::string CodeGenCUDA::Compile(
   return CodeGenC::Compile(f, output_ssa);
 }
 
+void CodeGenCUDA::PrintStmt(const ir::For* op) {
+  int ext;
+  CHECK(is_zero(op->min));
+  if (arith::GetConstInt(op->extent, &ext) &&
+      ext <= max_auto_unroll_) {
+    PrintIndent();
+    stream << "#pragma unroll\n";
+  }
+  CodeGenC::PrintStmt(op);
+}
+
 void CodeGenCUDA::PrintType(Type t, std::ostream& os) const {  // NOLINT(*)
   int lanes = t.lanes();
   if (t.is_handle()) {
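CodeGenCUDA::PrintStmt(const ir::For*) above adds an unroll hint for small constant-extent loops, which is exactly the shape the injected vthread loops take. A minimal sketch of the decision rule (stand-in names only; arith::GetConstInt is modeled here with std::optional):

    #include <iostream>
    #include <optional>

    // Emit "#pragma unroll" only when the extent is a compile-time
    // constant no larger than the threshold (max_auto_unroll_ is 8).
    void EmitFor(std::optional<int> const_extent, int max_auto_unroll = 8) {
      if (const_extent && *const_extent <= max_auto_unroll) {
        std::cout << "#pragma unroll\n";
      }
      std::cout << "for (...) { ... }\n";
    }

    int main() {
      EmitFor(4);             // hinted: constant extent <= 8
      EmitFor(32);            // not hinted: constant but too large
      EmitFor(std::nullopt);  // not hinted: extent not a constant
    }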
diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h
index 478faf76b74cc094b7e307cdf765c3a610e5f924..428f9ffddd2e4ed29722a602da2bb39b70f566b4 100644
--- a/src/codegen/codegen_cuda.h
+++ b/src/codegen/codegen_cuda.h
@@ -27,6 +27,7 @@ class CodeGenCUDA : public CodeGenC {
       bool output_ssa);
   // override behavior
+  void PrintStmt(const ir::For* op) final;
   void PrintStorageSync(const std::string& sync) final;
   void PrintStorageScope(
       const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintVecBinaryOp(
@@ -37,6 +38,11 @@ class CodeGenCUDA : public CodeGenC {
       const std::string& vec, Type t, int i, std::ostream& os) final;  // NOLINT(*)
   void PrintVecElemStore(
       const std::string& vec, Type t, int i, const std::string& value) final;
+
+ private:
+  // Maximum constant loop extent for which #pragma unroll is emitted;
+  // used to generate code that is compact but still unrolls.
+  int max_auto_unroll_{8};
 };
 
 }  // namespace codegen
diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0a9f5b38ff62ee40f588151abea327f2ecf05d92
--- /dev/null
+++ b/src/pass/inject_virtual_thread.cc
@@ -0,0 +1,419 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file inject_virtual_thread.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "../arithmetic/compute_expr.h"
+
+namespace tvm {
+namespace ir {
+
+// Whether an expression is touched by var.
+class ExprTouched : public IRVisitor {
+ public:
+  explicit ExprTouched(const std::unordered_set<const Variable*> &touched)
+      : touched_var_(touched) {}
+  void Visit(const NodeRef& n) final {
+    // early stopping
+    if (expr_touched_) return;
+    IRVisitor::Visit(n);
+  }
+  void Visit_(const Load *op) final {
+    HandleUseVar(op->buffer_var.get());
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const Variable *op) final {
+    HandleUseVar(op);
+  }
+  void HandleUseVar(const Variable* var) {
+    auto it = touched_var_.find(var);
+    if (it != touched_var_.end()) {
+      expr_touched_ = true;
+    }
+    // remember the used vars
+    // in case the var gets touched later in a loop.
+    if (!expr_touched_) {
+      used_vars_.push_back(var);
+    }
+  }
+  // the fields.
+  bool expr_touched_{false};
+  std::vector<const Variable*> used_vars_;
+  const std::unordered_set<const Variable*>& touched_var_;
+};
+
+// Analyze whether the buffers are invariant to the value of var.
+class VarTouchedAnalysis : public IRVisitor {
+ public:
+  void Visit_(const LetStmt *op) {
+    ExprTouched tc(touched_var_);
+    tc.Visit(op->value);
+    Record(op->var.get(), tc);
+    this->Visit(op->body);
+  }
+  void Visit_(const Store *op) {
+    ExprTouched tc(touched_var_);
+    tc.Visit(op->value);
+    tc.Visit(op->index);
+    Record(op->buffer_var.get(), tc);
+  }
+  void Visit_(const For *op) {
+    ExprTouched tc(touched_var_);
+    tc.Visit(op->min);
+    tc.Visit(op->extent);
+    Record(op->loop_var.get(), tc);
+    this->Visit(op->body);
+  }
+  void Visit_(const Allocate *op) {
+    ExprTouched tc(touched_var_);
+    for (size_t i = 0; i < op->extents.size(); ++i) {
+      tc.Visit(op->extents[i]);
+    }
+    tc.Visit(op->condition);
+    if (op->new_expr.defined()) {
+      tc.Visit(op->new_expr);
+    }
+    Record(op->buffer_var.get(), tc);
+    this->Visit(op->body);
+  }
+  void Record(const Variable* var,
+              const ExprTouched& tc) {
+    if (touched_var_.count(var)) return;
+    if (tc.expr_touched_) {
+      touched_var_.insert(var);
+    } else {
+      for (const Variable* r : tc.used_vars_) {
+        affect_[r].push_back(var);
+      }
+    }
+  }
+
+  std::unordered_set<const Variable*>
+  TouchedVar(const Stmt& stmt,
+             const Variable* var) {
+    touched_var_.insert(var);
+    this->Visit(stmt);
+    // do a DFS to push the affect along the dependencies.
+    std::vector<const Variable*> pending(
+        touched_var_.begin(), touched_var_.end());
+    while (!pending.empty()) {
+      const Variable* v = pending.back();
+      pending.pop_back();
+      for (const Variable* r : affect_[v]) {
+        if (!touched_var_.count(r)) {
+          touched_var_.insert(r);
+          pending.push_back(r);
+        }
+      }
+    }
+    return std::move(touched_var_);
+  }
+
+ private:
+  // Whether variable is touched by the thread variable.
+  std::unordered_set<const Variable*> touched_var_;
+  // x -> all the variables that are defined using x.
+  std::unordered_map<const Variable*,
+                     std::vector<const Variable*> > affect_;
+};
+
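VarTouchedAnalysis records, for each variable x, which variables were defined from expressions reading x (affect_), then TouchedVar closes the set over those edges with a worklist walk. The same closure in miniature, with strings standing in for Variable pointers:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    int main() {
      // affect[x] = variables whose defining expression reads x.
      std::unordered_map<std::string, std::vector<std::string>> affect = {
        {"vx", {"idx"}},   // idx = f(vx)
        {"idx", {"buf"}},  // buf[...] = g(idx)
      };
      std::unordered_set<std::string> touched = {"vx"};  // the vthread var
      std::vector<std::string> pending(touched.begin(), touched.end());
      while (!pending.empty()) {
        std::string v = pending.back();
        pending.pop_back();
        for (const auto& r : affect[v]) {
          if (touched.insert(r).second) pending.push_back(r);
        }
      }
      for (const auto& v : touched) {
        std::cout << v << " depends on the virtual thread index\n";
      }
    }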
+
+// Inject virtual thread loop,
+// rewrite the buffer access pattern when necessary.
+class VTInjector : public IRMutator {
+ public:
+  using IRMutator::Mutate;
+  // constructor
+  VTInjector(Var var,
+             int num_threads,
+             std::unordered_set<const Variable*> touched_var)
+      : var_(var), num_threads_(num_threads), touched_var_(touched_var) {
+  }
+  // Inject VTLoop when needed.
+  Stmt Mutate(Stmt stmt) final {
+    CHECK(!visit_touched_var_)
+        << stmt->type_key() << stmt;
+    stmt = IRMutator::Mutate(stmt);
+    if (visit_touched_var_) {
+      if (!vt_loop_injected_) return InjectVTLoop(stmt, false);
+      visit_touched_var_ = false;
+    }
+    return stmt;
+  }
+  // Variable
+  Expr Mutate_(const Variable *op, const Expr& e) final {
+    if (touched_var_.count(op)) {
+      visit_touched_var_ = true;
+    }
+    return e;
+  }
+  Expr RewriteIndex(Expr index, Expr alloc_extent) const {
+    if (index_rewrite_strategy_ == 0) {
+      return index * num_threads_ + var_;
+    } else {
+      return index + var_ * alloc_extent;
+    }
+  }
+  // Load
+  Expr Mutate_(const Load* op, const Expr& e) final {
+    Expr expr = IRMutator::Mutate_(op, e);
+    op = expr.as<Load>();
+    if (touched_var_.count(op->buffer_var.get())) {
+      visit_touched_var_ = true;
+    }
+    auto it = touched_alloc_.find(op->buffer_var.get());
+    if (it != touched_alloc_.end()) {
+      return Load::make(op->type, op->buffer_var,
+                        RewriteIndex(op->index, it->second));
+    } else {
+      return expr;
+    }
+  }
+  // Store
+  Stmt Mutate_(const Store* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Store>();
+    if (touched_var_.count(op->buffer_var.get())) {
+      visit_touched_var_ = true;
+    }
+    auto it = touched_alloc_.find(op->buffer_var.get());
+    if (it != touched_alloc_.end()) {
+      return Store::make(op->buffer_var,
+                         op->value,
+                         RewriteIndex(op->index, it->second));
+    } else {
+      return stmt;
+    }
+  }
+  // Attribute
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->type_key == attr::scope) {
+      return Mutate(op->body);
+    } else {
+      Expr value = Mutate(op->value);
+      if (visit_touched_var_) {
+        return InjectVTLoop(s, true);
+      } else {
+        Stmt body = Mutate(op->body);
+        if (value.same_as(op->value) &&
+            body.same_as(op->body)) {
+          return s;
+        } else {
+          return AttrStmt::make(op->node, op->type_key, value, body);
+        }
+      }
+    }
+  }
+  // LetStmt
+  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
+    Expr value = this->Mutate(op->value);
+    if (visit_touched_var_ && !vt_loop_injected_) {
+      return InjectVTLoop(s, true);
+    }
+    visit_touched_var_ = false;
+    Stmt body = Mutate(op->body);
+    if (value.same_as(op->value) &&
+        body.same_as(op->body)) {
+      return s;
+    } else {
+      return LetStmt::make(op->var, value, body);
+    }
+  }
+  // For
+  Stmt Mutate_(const For* op, const Stmt& s) final {
+    CHECK(is_zero(op->min));
+    Expr extent = Mutate(op->extent);
+    if (visit_touched_var_ && !vt_loop_injected_) {
+      Stmt stmt = InjectVTLoop(s, true);
+      ++max_loop_depth_;
+      return stmt;
+    }
+    visit_touched_var_ = false;
+    Stmt body = Mutate(op->body);
+    ++max_loop_depth_;
+    if (extent.same_as(op->extent) &&
+        body.same_as(op->body)) {
+      return s;
+    } else {
+      return For::make(
+          op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+    }
+  }
+  // IfThenElse
+  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
+    Expr condition = this->Mutate(op->condition);
+    if (visit_touched_var_ && !vt_loop_injected_) {
+      return InjectVTLoop(s, true);
+    }
+    visit_touched_var_ = false;
+    CHECK_EQ(max_loop_depth_, 0);
+    Stmt then_case = this->Mutate(op->then_case);
+    Stmt else_case;
+    if (op->else_case.defined()) {
+      int temp = max_loop_depth_;
+      max_loop_depth_ = 0;
+      else_case = this->Mutate(op->else_case);
+      max_loop_depth_ = std::max(temp, max_loop_depth_);
+    }
+    if (condition.same_as(op->condition) &&
+        then_case.same_as(op->then_case) &&
+        else_case.same_as(op->else_case)) {
+      return s;
+    } else {
+      return IfThenElse::make(condition, then_case, else_case);
+    }
+  }
+  // Block
+  Stmt Mutate_(const Block* op, const Stmt& s) final {
+    CHECK_EQ(max_loop_depth_, 0);
+    Stmt first = this->Mutate(op->first);
+    int temp = max_loop_depth_;
+    max_loop_depth_ = 0;
+    Stmt rest = this->Mutate(op->rest);
+    max_loop_depth_ = std::max(max_loop_depth_, temp);
+    if (first.same_as(op->first) &&
+        rest.same_as(op->rest)) {
+      return s;
+    } else {
+      return Block::make(first, rest);
+    }
+  }
+  // Allocate
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    if (op->new_expr.defined() && !vt_loop_injected_) {
+      return InjectVTLoop(s, true);
+    }
+    Expr condition = Mutate(op->condition);
+    if (visit_touched_var_ && !vt_loop_injected_) {
+      return InjectVTLoop(s, true);
+    }
+
+    bool changed = false;
+    Array<Expr> extents;
+    for (size_t i = 0; i < op->extents.size(); i++) {
+      Expr new_ext = Mutate(op->extents[i]);
+      if (visit_touched_var_ && !vt_loop_injected_) {
+        return InjectVTLoop(s, true);
+      }
+      if (!new_ext.same_as(op->extents[i])) changed = true;
+      extents.push_back(new_ext);
+    }
+    visit_touched_var_ = false;
+
+    Stmt body;
+    if (touched_var_.count(op->buffer_var.get())) {
+      // place the vthread on the highest dimension.
+      Expr stride = extents[0];
+      for (size_t i = 1; i < extents.size(); ++i) {
+        stride = arith::ComputeExpr<Mul>(stride, extents[i]);
+      }
+      Array<Expr> other;
+      other.push_back(num_threads_);
+      for (Expr e : extents) {
+        other.push_back(e);
+      }
+      extents = other;
+      changed = true;
+      // mark this buffer as touched.
+      touched_alloc_[op->buffer_var.get()] = stride;
+      // Mutate the body.
+      body = Mutate(op->body);
+    } else {
+      // Mutate the body.
+      body = Mutate(op->body);
+    }
+    if (!changed &&
+        body.same_as(op->body) &&
+        condition.same_as(op->condition)) {
+      return s;
+    } else {
+      return Allocate::make(
+          op->buffer_var, op->type,
+          extents, condition, body,
+          op->new_expr, op->free_function);
+    }
+  }
+
+  // inject vthread loop
+  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
+    CHECK(!vt_loop_injected_);
+    // reset the flags
+    visit_touched_var_ = false;
+    vt_loop_injected_ = true;
+    if (before_mutation) {
+      stmt = this->Mutate(stmt);
+    }
+    // reset the flags after processing.
+    vt_loop_injected_ = false;
+    visit_touched_var_ = false;
+    if (max_loop_depth_ == 0) {
+      // unroll if the vthread loop sits at the innermost level.
+      Stmt blk = Substitute(stmt, {{var_, make_zero(var_.type())}});
+      for (int i = 1; i < num_threads_; ++i) {
+        blk = Block::make(
+            blk, Substitute(stmt, {{var_, make_const(var_.type(), i)}}));
+      }
+      return blk;
+    } else {
+      // insert a serial for loop
+      Var idx(var_->name_hint + ".s", var_->type);
+      stmt = Substitute(stmt, {{var_, idx}});
+      return For::make(idx, 0, num_threads_,
+                       ForType::Serial, DeviceAPI::None, stmt);
+    }
+  }
+
+ private:
+  // vthread variable
+  Var var_;
+  // the number of virtual threads/lanes
+  int num_threads_;
+  // Index rewriting strategy
+  int index_rewrite_strategy_{1};
+  // whether the loop is already injected.
+  bool vt_loop_injected_{false};
+  // whether the current expression gets touched.
+  bool visit_touched_var_{false};
+  // the number of loops encountered after mutation.
+  int max_loop_depth_{0};
+  // The variables that get touched.
+  std::unordered_set<const Variable*> touched_var_;
+  // The allocations that get touched -> extent
+  std::unordered_map<const Variable*, Expr> touched_alloc_;
+};
+
+
+class VirtualThreadInjector : public IRMutator {
+ public:
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<AttrStmt>();
+    if (op->type_key == attr::virtual_thread) {
+      IterVar iv(op->node.node_);
+      int nthread = static_cast<int>(op->value.as<IntImm>()->value);
+      VarTouchedAnalysis vs;
+      auto touched = vs.TouchedVar(op->body, iv->var.get());
+      VTInjector injector(iv->var, nthread, touched);
+      return injector.Mutate(op->body);
+    } else {
+      return stmt;
+    }
+  }
+
+  Stmt Mutate_(const Provide* op, const Stmt& s) final {
+    LOG(FATAL) << "Need to call StorageFlatten first";
+    return s;
+  }
+};
+
+Stmt InjectVirtualThread(Stmt stmt) {
+  stmt = VirtualThreadInjector().Mutate(stmt);
+  return ConvertSSA(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
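Taken together, VTInjector's Allocate and InjectVTLoop rules turn one vthread body into num_threads_ copies over a widened buffer. A hand-worked illustration in plain C++ (not pass output) for two virtual threads and a touched allocation of extent 8, under the default index_rewrite_strategy_ == 1:

    #include <cassert>

    int main() {
      // Before injection (one vthread's view): float A[8]; A[i] = i + vt.
      // After: the extent gains a leading vthread dimension and each
      // index i is rewritten to i + vt * 8, so every virtual thread
      // owns a disjoint slice of the lifted buffer.
      float A[2 * 8];
      for (int vt = 0; vt < 2; ++vt) {    // the injected serial loop
        for (int i = 0; i < 8; ++i) {
          A[i + vt * 8] = static_cast<float>(i + vt);
        }
      }
      assert(A[3] == 3.0f && A[8 + 3] == 4.0f);
      // Strategy 0 would interleave the threads instead: A[i * 2 + vt].
      return 0;
    }

When the vthread body contains no inner loop (max_loop_depth_ == 0), the same copies are emitted as an unrolled Block chain rather than as the serial for loop shown here.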
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index 3c39b6c50afcd937a57e671d59d5735a660fb9b9..f10c9c089f1d9e0a104bc4f63168f790e7e91e65 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -77,6 +77,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
 .DISPATCH_TO_MUTATE_STMT(IfThenElse)
 .DISPATCH_TO_MUTATE_STMT(For)
 .DISPATCH_TO_MUTATE_STMT(Allocate)
+.DISPATCH_TO_MUTATE_STMT(Block)
 .DISPATCH_TO_MUTATE_STMT(Free);
 
 Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
@@ -212,6 +213,17 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
   }
 }
 
+Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
+  Stmt first = this->Mutate(op->first);
+  Stmt rest = this->Mutate(op->rest);
+  if (first.same_as(op->first) &&
+      rest.same_as(op->rest)) {
+    return s;
+  } else {
+    return Block::make(first, rest);
+  }
+}
+
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .DISPATCH_TO_MUTATE_EXPR(Call)
 .DISPATCH_TO_MUTATE_EXPR(Let)
@@ -370,16 +382,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
     return ProducerConsumer::make(op->func, op->is_producer, body);
   }
 })
-.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
-    Stmt first = m->Mutate(op->first);
-    Stmt rest = m->Mutate(op->rest);
-    if (first.same_as(op->first) &&
-        rest.same_as(op->rest)) {
-      return s;
-    } else {
-      return Block::make(first, rest);
-    }
-  })
 .set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
     Expr v = m->Mutate(op->value);
     if (v.same_as(op->value)) {
diff --git a/src/pass/lift_allocate.cc b/src/pass/lift_allocate.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2aa08607221e91676813222245129dcb572b3435
--- /dev/null
+++ b/src/pass/lift_allocate.cc
@@ -0,0 +1,96 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file lift_allocate.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_map>
+#include "./ir_util.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+using runtime::StorageScope;
+using runtime::ThreadScope;
+
+class AllocateLifter : public IRMutator {
+ public:
+  Stmt Lift(Stmt stmt) {
+    stmt = this->Mutate(stmt);
+    StorageScope key; key.rank = 0;
+    stmt = MergeNest(allocs_[key], stmt);
+    return stmt;
+  }
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    CHECK(op->type_key != attr::virtual_thread)
+        << "Run InjectVirtualThread before LiftAllocate";
+    if (op->type_key == attr::storage_scope) {
+      StorageScope sc = StorageScope::make(op->value.as<StringImm>()->value);
+      allocs_[sc].emplace_back(
+          AttrStmt::make(
+              op->node, attr::storage_scope,
+              op->value, Evaluate::make(0)));
+      storage_scope_[op->node.get()] = sc;
+      return this->Mutate(op->body);
+    } else if (op->type_key == attr::thread_extent) {
+      IterVar iv(op->node.node_);
+      ThreadScope ts = ThreadScope::make(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = IRMutator::Mutate_(op, s);
+      curr_thread_scope_.pop_back();
+      op = stmt.as<AttrStmt>();
+
+      bool first_scope = true;
+      for (const ThreadScope& t : curr_thread_scope_) {
+        if (t.rank == ts.rank) first_scope = false;
+      }
+      if (first_scope) {
+        StorageScope key;
+        key.rank = ts.rank + 1;
+        std::vector<Stmt>& vec = allocs_[key];
+        if (vec.size() != 0) {
+          Stmt body = MergeNest(vec, op->body);
+          vec.clear();
+          return AttrStmt::make(
+              op->node, op->type_key, op->value, body);
+        }
+      }
+      return stmt;
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Stmt Mutate_(const For* op, const Stmt& s) final {
+    CHECK(op->for_type != ForType::Vectorized)
+        << "Run VectorizeLoop before LiftAllocate";
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    auto it = storage_scope_.find(op->buffer_var.get());
+    CHECK(it != storage_scope_.end());
+    allocs_[it->second].emplace_back(
+        Allocate::make(
+            op->buffer_var, op->type, op->extents, op->condition,
+            Evaluate::make(0)));
+    return this->Mutate(op->body);
+  }
+
+ private:
+  // storage scope of internal allocation.
+  std::unordered_map<const Node*, StorageScope> storage_scope_;
+  // The current thread scope.
+  std::vector<ThreadScope> curr_thread_scope_;
+  // The allocations by rank
+  std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
+};
+
+Stmt LiftAllocate(Stmt stmt) {
+  return AllocateLifter().Lift(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
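AllocateLifter strips each Allocate from its original position, records it under its StorageScope, and re-emits the collected nest at the outermost statement of the matching thread rank (via MergeNest from ir_util.h). A toy sketch of that wrap-at-the-boundary step, with strings standing in for IR nodes and an assumed "first element outermost" ordering for MergeNest:

    #include <iostream>
    #include <string>
    #include <vector>

    // Wraps body with each collected statement, first element outermost.
    std::string MergeNest(const std::vector<std::string>& nest,
                          std::string body) {
      for (auto it = nest.rbegin(); it != nest.rend(); ++it) {
        body = *it + " { " + body + " }";
      }
      return body;
    }

    int main() {
      // Collected while mutating the body: scope attr, then the allocate.
      std::vector<std::string> allocs = {
        "attr A.shared storage_scope = shared",
        "allocate A.shared[64]",
      };
      // The allocation originally sat inside the loop; after lifting it
      // wraps the whole body once, at the thread-scope boundary.
      std::cout << MergeNest(allocs, "for (i) { A.shared[i] = ... }")
                << "\n";
    }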
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 5cfb8f9c5c2a2b5aa0f18d4f34b6755e48ee3f0f..944a8c0a496dbe3b188e405b711f642eb943fcc9 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -6,7 +6,6 @@
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
 #include <unordered_map>
-#include "./ir_util.h"
 #include "../runtime/thread_storage_scope.h"
 
 namespace tvm {
@@ -61,46 +60,17 @@ class StorageFlattener : public IRMutator {
     }
   }
 
-  Stmt Flatten(Stmt stmt) {
-    stmt = this->Mutate(stmt);
-    StorageScope key; key.rank = 0;
-    if (move_alloc_out_) {
-      StorageScope key; key.rank = 0;
-      stmt = MergeNest(allocs_[key], stmt);
-    }
-    return stmt;
-  }
-
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
-    if (op->type_key == "realize_scope") {
+    if (op->type_key == attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
       return this->Mutate(op->body);
-    } else if (op->type_key == "scope") {
+    } else if (op->type_key == attr::thread_extent) {
       IterVar iv(op->node.node_);
-      if (iv->thread_tag.length() != 0) {
-        ThreadScope ts = ThreadScope::make(iv->thread_tag);
-        curr_thread_scope_.push_back(ts);
-        Stmt stmt = IRMutator::Mutate_(op, s);
-        curr_thread_scope_.pop_back();
-        op = stmt.as<AttrStmt>();
-
-        bool first_scope = true;
-        for (const ThreadScope& t : curr_thread_scope_) {
-          if (t.rank == ts.rank) first_scope = false;
-        }
-        if (first_scope && move_alloc_out_) {
-          StorageScope key;
-          key.rank = ts.rank + 1;
-          std::vector<Stmt>& vec = allocs_[key];
-          if (vec.size() != 0) {
-            Stmt body = MergeNest(vec, op->body);
-            vec.clear();
-            return AttrStmt::make(
-                op->node, op->type_key, op->value, body);
-          }
-        }
-        return stmt;
-      }
+      ThreadScope ts = ThreadScope::make(iv->thread_tag);
+      curr_thread_scope_.push_back(ts);
+      Stmt stmt = IRMutator::Mutate_(op, s);
+      curr_thread_scope_.pop_back();
+      return stmt;
     }
     return IRMutator::Mutate_(op, s);
   }
@@ -140,37 +110,22 @@ class StorageFlattener : public IRMutator {
       // deduce current storage scope.
       auto it = storage_scope_.find(op->func.get());
       CHECK(it != storage_scope_.end());
-      StorageScope key; key.rank = 0;
-      const std::string& skey = it->second;
-      if (skey.length() == 0) {
+      StorageScope skey;
+      const std::string& strkey = it->second;
+      if (strkey.length() == 0) {
         if (curr_thread_scope_.size() != 0) {
-          key.rank = curr_thread_scope_.back().rank + 1;
+          skey.rank = curr_thread_scope_.back().rank + 1;
         }
       } else {
-        key = StorageScope::make(skey);
-      }
-
-      if (move_alloc_out_) {
-        allocs_[key].push_back(
-            AttrStmt::make(
-                e.buffer->data, "storage_scope",
-                StringImm::make(key.to_string()),
-                Evaluate::make(0)));
-        allocs_[key].push_back(
-            Allocate::make(
-                e.buffer->data, e.buffer->dtype, e.buffer->shape,
-                make_const(Bool(e.buffer->dtype.lanes()), true),
-                Evaluate::make(0)));
-        return body;
-      } else {
-        Stmt ret = Allocate::make(
-            e.buffer->data, e.buffer->dtype, e.buffer->shape,
-            make_const(Bool(e.buffer->dtype.lanes()), true), body);
-        ret = AttrStmt::make(
-            e.buffer->data, "storage_scope",
-            StringImm::make(key.to_string()), ret);
-        return ret;
+        skey = StorageScope::make(strkey);
       }
+      Stmt ret = Allocate::make(
+          e.buffer->data, e.buffer->dtype, e.buffer->shape,
+          make_const(Bool(e.buffer->dtype.lanes()), true), body);
+      ret = AttrStmt::make(
+          e.buffer->data, attr::storage_scope,
+          StringImm::make(skey.to_string()), ret);
+      return ret;
     }
   }
@@ -217,20 +172,16 @@ class StorageFlattener : public IRMutator {
       }
     }
   };
-  // whether move allocation to the outmost scope as possible.
-  bool move_alloc_out_{true};
   // The buffer assignment map
   std::unordered_map<TensorKey, BufferEntry> buf_map_;
   std::unordered_map<const Node*, std::string> storage_scope_;
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
-  // The allocations by rank
-  std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
 };
 
 Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer) {
-  stmt = StorageFlattener(extern_buffer).Flatten(stmt);
+  stmt = StorageFlattener(extern_buffer).Mutate(stmt);
   return stmt;
 }
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 436fe015ad81a4bcdf771da11d1a386d82fcbe15..da623567bbbc0fc587cd411dad4e2875c46a84a8 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -62,7 +62,11 @@ struct ThreadScope {
    */
   static ThreadScope make(const std::string& s) {
     ThreadScope r;
-    if (s.compare(0, 9, "blockIdx.") == 0) {
+    if (s == "vthread") {
+      // virtual thread at the same level as local
+      r.rank = 1;
+      r.dim_index = -1;
+    } else if (s.compare(0, 9, "blockIdx.") == 0) {
       r.rank = 0;
       r.dim_index = static_cast<int>(s[9] - 'x');
     } else if (s.compare(0, 10, "threadIdx.") == 0) {
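ThreadScope::make above gives "vthread" the same rank as threadIdx.* (rank 1) with a sentinel dim_index of -1, so a buffer private to a virtual thread is lifted to the same level as thread-local storage. A stripped-down, self-contained version of the parser for reference (simplified; no error handling):

    #include <iostream>
    #include <string>

    struct ThreadScope {
      int rank{0};
      int dim_index{0};
    };

    ThreadScope MakeThreadScope(const std::string& s) {
      ThreadScope r;
      if (s == "vthread") {
        r.rank = 1;        // same level as threadIdx.*
        r.dim_index = -1;  // no concrete hardware dimension
      } else if (s.compare(0, 9, "blockIdx.") == 0) {
        r.rank = 0;
        r.dim_index = s[9] - 'x';
      } else if (s.compare(0, 10, "threadIdx.") == 0) {
        r.rank = 1;
        r.dim_index = s[10] - 'x';
      }
      return r;
    }

    int main() {
      std::cout << MakeThreadScope("vthread").rank << "\n";           // 1
      std::cout << MakeThreadScope("threadIdx.y").dim_index << "\n";  // 1
    }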
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index 05e6e009e50a7a91a9765b3d8340fbb1c9910c35..ed4ad7011a4e32d84a3fd003bc168b18ba999670 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -203,18 +203,27 @@ MakeLoopNest(const Stage& sch,
         nest[i + 1].emplace_back(
             LetStmt::make(var, new_value, no_op));
       }
+    } else if (iv->thread_tag == "vthread") {
+      // virtual thread
+      // Always restrict threaded IterVar to start from 0.
+      CHECK(is_zero(dom->min));
+      CHECK(is_positive_const(dom->extent));
+      // annotate the extent of the IterVar
+      nest[i + 1].emplace_back(
+          AttrStmt::make(iv, ir::attr::virtual_thread, dom->extent, no_op));
+      value_map[iv] = var;
     } else {
       // Always restrict threaded IterVar to starts from 0.
       CHECK(is_zero(dom->min));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
+          AttrStmt::make(iv, ir::attr::thread_extent, dom->extent, no_op));
       value_map[iv] = var;
     }
     if (!reduce_init_loop) {
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, "scope", iv->var, no_op));
+          AttrStmt::make(iv, ir::attr::scope, iv->var, no_op));
     }
   }
   // message passing to get offset of root iter vars.
diff --git a/tests/python/unittest/test_pass_virtual_thread.py b/tests/python/unittest/test_pass_virtual_thread.py
new file mode 100644
index 0000000000000000000000000000000000000000..abda94ac720b3aecd950a141103a2426f3314a89
--- /dev/null
+++ b/tests/python/unittest/test_pass_virtual_thread.py
@@ -0,0 +1,28 @@
+import tvm
+
+def test_virtual_thread():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m, ), name='A')
+    A1 = tvm.compute((m,), lambda i: A[i], name='A1')
+    A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')
+
+    s = tvm.Schedule(A2.op)
+
+    vx = tvm.IterVar((0, 2), "vx", thread_tag="vthread")
+    xo, xi = s[A2].split(A2.op.axis[0], outer=vx)
+    xo, xi = s[A2].split(xi, 8)
+    s[A1].compute_at(s[A2], xo)
+
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.collections.Map)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
+    A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
+    stmt = tvm.ir_pass.Simplify(stmt)
+    stmt = tvm.ir_pass.InjectVirtualThread(stmt)
+    print(stmt)
+
+if __name__ == "__main__":
+    test_virtual_thread()