diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 1ac3887a489c9dc2a96b10a0d7f372859d641e73..4b499465240564ec1360f040cdab725515369232 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -234,10 +234,10 @@ Stmt InjectPrefetch(Stmt stmt); /*! * \brief Inject double buffer into stmt. * \param stmt The statment to be transformed. - * \param split_loop Whether split the loop containing double buffering. + * \param split_loop Loop splitting factor (0 disables splitting). * \return Transformed stmt. */ -Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop); +Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); /*! * \brief Rewrite storage allocation pattern. diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 2dfaec00b7ba8b45f56d3ad3dcb507047c2deb46..d08aa8ad0e854f9e1825f84fccb56b32905048a5 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -33,7 +33,7 @@ class BuildConfig(object): "offset_factor": 0, "data_alignment": -1, "restricted_func": True, - "double_buffer_split_loop": True, + "double_buffer_split_loop": 1, "add_lower_pass": None } def __init__(self, **kwargs): @@ -99,9 +99,10 @@ def build_config(**kwargs): not to overlap. This enables more optimization. Corresponds to restricted keyword in C99 - double_buffer_split_loop: bool, default=True - Whether split the loop containing double buffer so - that the buffer fetching won't contain condition. + double_buffer_split_loop: int, default=1 + Factor used to split the loop. If it is zero, no splitting will happen. + If it is bigger than one, the logic will do a split with factor equal to the integer + and unroll the inner loop, so that the buffer fetching won't contain condition. add_lower_pass: list of tuiple (phase, function(Stmt->Stmt)), default=None phase contains an integer on which optimization pass we apply the pass. 
diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 55dea73e5d0a63fbac524384af11d58cecc53ea1..c660730ed8375f8c6771a5044193363efb746344 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -34,9 +34,21 @@ class DoubleBufferDetector : public IRVisitor { std::unordered_set<const Variable*> touched_; }; +// Drops double_buffer_write AttrStmt wrappers and keeps mutating their bodies. +class StripDoubleBufferWrite : public IRMutator { + public: + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + if (op->attr_key == attr::double_buffer_write) { + return Mutate(op->body); + } else { + return IRMutator::Mutate_(op, s); + } + } +}; + class DoubleBufferInjector : public IRMutator { public: - explicit DoubleBufferInjector(bool split_loop) + explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {} Stmt Inject(const Stmt& stmt) { @@ -97,17 +109,38 @@ class DoubleBufferInjector : public IRMutator { auto it = loop_pre_.find(op); if (it != loop_pre_.end()) { const For* old_loop = stmt.as<For>(); - if (split_loop_) { + if (split_loop_ != 0) { + // Split the loop by split_loop_ and explicitly unroll the inner iterations. + CHECK(split_loop_ > 0 && (split_loop_ % 2 == 0 || split_loop_ == 1)) + << "split_loop must be 1 or a positive multiple of 2"; + CHECK(is_zero(old_loop->min)); + Expr zero = old_loop->min; Expr new_ext = arith::ComputeExpr<Sub>( old_loop->extent, make_const(old_loop->loop_var.type(), 1)); - Stmt loop = For::make( - old_loop->loop_var, old_loop->min, new_ext, - old_loop->for_type, old_loop->device_api, - old_loop->body); + Expr factor = make_const(new_ext.type(), split_loop_); + Expr outer_ext = arith::ComputeExpr<Div>(new_ext, factor); + Expr tail_base = arith::ComputeExpr<Mul>(outer_ext, factor); + Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type()); std::unordered_map<const Variable*, Expr> vmap; - vmap[old_loop->loop_var.get()] = new_ext; - Stmt end = Substitute(old_loop->body, vmap); - stmt = Block::make(loop, end); + std::vector<Stmt> loop_seq; + for (int i = 0; i < split_loop_; ++i) { 
+ vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i); + loop_seq.emplace_back(Substitute(old_loop->body, vmap)); + } + Stmt loop = For::make( + outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, + MergeSeq(loop_seq)); + // Tail: remaining iterations, each guarded by idx < extent, with double_buffer_write stripped. + std::vector<Stmt> tail_seq; + Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body); + for (int i = 0; i < split_loop_; ++i) { + Expr idx = tail_base + make_const(tail_base.type(), i); + vmap[old_loop->loop_var.get()] = idx; + tail_seq.emplace_back( + IfThenElse::make(idx < old_loop->extent, + Substitute(tail_body, vmap))); + } + stmt = Block::make(loop, MergeSeq(tail_seq)); } stmt = Block::make(MergeSeq(it->second), stmt); } @@ -205,7 +238,7 @@ class DoubleBufferInjector : public IRMutator { std::string scope; }; // Whether split loop - bool split_loop_; + int split_loop_; // Whether we are inside double buffer scope. bool in_double_buffer_scope_{false}; // The current loop next @@ -219,7 +252,7 @@ class DoubleBufferInjector : public IRMutator { }; -Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop) { +Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { return DoubleBufferInjector(split_loop).Inject(stmt); } } // namespace ir diff --git a/tests/python/unittest/test_pass_inject_double_buffer.py b/tests/python/unittest/test_pass_inject_double_buffer.py index 133ba7f7e17e92caff150e032bbbdea2bd52df52..3136e33197ecad0ff32f8fdfb388b5e05924d735 100644 --- a/tests/python/unittest/test_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_pass_inject_double_buffer.py @@ -19,7 +19,7 @@ def test_double_buffer(): C[j] = B[j] + 1 stmt = ib.get() - stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, True) + stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2) stmt = tvm.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert stmt.body.body.extents[0].value == 2 @@ -30,7 +30,7 @@ def test_double_buffer(): if isinstance(op, tvm.expr.Call) and op.name == 
"tvm_storage_sync": count[0] += 1 tvm.ir_pass.PostOrderVisit(f.body, count_sync) - assert count[0] == 2 + assert count[0] == 4 if __name__ == "__main__":