diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 72b5753adbcf8dda1bff729ae5fddef1e77847c7..07f59cb1b308d45a08f529caf5fc3235edafecc8 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -229,7 +229,8 @@ class VTInjector : public IRMutator { if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(s, true); } else if (!allow_share_ && !vt_loop_injected_ && - op->attr_key == attr::coproc_uop_scope) { + (op->attr_key == attr::coproc_uop_scope || + op->attr_key == attr::coproc_scope)) { return InjectVTLoop(s, true); } else { Stmt body = Mutate(op->body); diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index fdf692782523f1b44fd1fa6f377cb43940525b7d..a3a60aaac4d1484832d695f1fa4fc948ab8047aa 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -7,6 +7,7 @@ */ #include <tvm/ir_pass.h> #include <tvm/ir_mutator.h> +#include "./ir_util.h" namespace tvm { namespace ir { @@ -57,41 +58,16 @@ class AttrScopeLifter : public IRMutator { } Stmt Mutate_(const Block* op, const Stmt& s) final { - Stmt first = this->Mutate(op->first); - NodeRef first_node_; - Expr first_value_; - std::swap(first_node_, attr_node_); - std::swap(first_value_, attr_value_); - Stmt rest = this->Mutate(op->rest); - if (attr_node_.defined() && - attr_value_.defined() && - first_node_.defined() && - first_value_.defined() && - attr_node_.same_as(first_node_) && - attr_value_.same_as(first_value_)) { - if (first.same_as(op->first) && rest.same_as(op->rest)) { - return s; - } else { - return Block::make(first, rest); - } - } else { - if (first_node_.defined()) { - first = AttrStmt::make( - first_node_, attr_key_, first_value_, first); - } - if (attr_node_.defined()) { - rest = AttrStmt::make( - attr_node_, attr_key_, attr_value_, rest); - // undefine them - attr_node_ = NodeRef(); - attr_value_ = Expr(); - } - if (first.same_as(op->first) && rest.same_as(op->rest)) { - return s; - } else { - return Block::make(first, rest); - } + std::vector<Stmt> seq; + FlattenSeq(op->first, &seq); + FlattenSeq(op->rest, &seq); + seq = MutateSeq(seq); + if (seq.size() == 2 && + seq[0].same_as(op->first) && + seq[1].same_as(op->rest)) { + return s; } + return MergeSeq(seq); } Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { @@ -99,17 +75,17 @@ class AttrScopeLifter : public IRMutator { return IRMutator::Mutate_(op, s); } Stmt then_case = this->Mutate(op->then_case); - NodeRef first_node_; - Expr first_value_; - std::swap(first_node_, attr_node_); - std::swap(first_value_, attr_value_); + NodeRef first_node; + Expr first_value; + std::swap(first_node, attr_node_); + std::swap(first_value, attr_value_); Stmt else_case = this->Mutate(op->else_case); if (attr_node_.defined() && attr_value_.defined() && - first_node_.defined() && - first_value_.defined() && - attr_node_.same_as(first_node_) && - attr_value_.same_as(first_value_)) { + first_node.defined() && + first_value.defined() && + attr_node_.same_as(first_node) && + ValueSame(attr_value_, first_value)) { if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return s; @@ -117,9 +93,9 @@ class AttrScopeLifter : public IRMutator { return IfThenElse::make(op->condition, then_case, else_case); } } else { - if (first_node_.defined()) { + if (first_node.defined()) { then_case = AttrStmt::make( - first_node_, attr_key_, first_value_, then_case); + first_node, attr_key_, first_value, then_case); } if (attr_node_.defined()) { else_case = AttrStmt::make( @@ -138,6 +114,82 @@ class AttrScopeLifter : public IRMutator { } private: + void FlattenSeq(Stmt s, std::vector<Stmt>* res) { + if (const Block* op = s.as<Block>()) { + FlattenSeq(op->first, res); + FlattenSeq(op->rest, res); + } else if (const ProducerConsumer* op = s.as<ProducerConsumer>()) { + if (!op->is_producer) { + FlattenSeq(op->body, res); + } else { + res->emplace_back(s); + } + } else { + res->emplace_back(s); + } + } + + std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) { + std::vector<Stmt> res_seq; + NodeRef curr_node; + Expr curr_value; + Stmt curr_stmt; + for (const Stmt & stmt : seq) { + attr_node_ = NodeRef(); + attr_value_ = Expr(); + Stmt rest = this->Mutate(stmt); + if (attr_node_.defined() && + attr_value_.defined() && + curr_node.defined() && + curr_value.defined() && + attr_node_.same_as(curr_node) && + ValueSame(attr_value_, curr_value)) { + curr_stmt = Block::make(curr_stmt, rest); + } else { + if (curr_stmt.defined()) { + if (curr_node.defined()) { + curr_stmt = AttrStmt::make( + curr_node, attr_key_, curr_value, curr_stmt); + } + res_seq.push_back(curr_stmt); + } + curr_stmt = rest; + curr_node = attr_node_; + curr_value = attr_value_; + } + } + + if (curr_stmt.defined()) { + // keep attr_node_, attr_node_ + if (res_seq.size() == 0) { + return {curr_stmt}; + } + if (curr_node.defined()) { + curr_stmt = AttrStmt::make( + curr_node, attr_key_, curr_value, curr_stmt); + } + res_seq.push_back(curr_stmt); + // reset + attr_node_ = NodeRef(); + attr_value_ = Expr(); + } + return res_seq; + } + + // value comparison that also compares content of int constant + static bool ValueSame(const Expr& a, const Expr& b) { + if (a.same_as(b)) return true; + if (a->type_key() != b->type_key()) return false; + if (a.type() != b.type()) return false; + if (const IntImm* op = a.as<IntImm>()) { + return op->value == b.as<IntImm>()->value; + } + if (const UIntImm* op = a.as<UIntImm>()) { + return op->value == b.as<UIntImm>()->value; + } + return false; + } + std::string attr_key_; NodeRef attr_node_; Expr attr_value_;