From a1dfb9ae1817e13c53cedae59d5128692e6b609f Mon Sep 17 00:00:00 2001 From: xqdan <danxiaoqiang@126.com> Date: Tue, 30 Oct 2018 02:19:37 +0800 Subject: [PATCH] [PASS]unroll loops with extent=1 (#2027) --- src/pass/unroll_loop.cc | 4 +++- tests/python/unittest/test_pass_unroll.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index 6c0ac5175..d4481e86c 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -76,7 +76,9 @@ class LoopUnroller : public IRMutator { normal_loop_depth_ += 1; } - if (auto_unroll && explicit_unroll_) { + if ((auto_unroll && explicit_unroll_) || + // unroll loops with extent = 1, no matter how many steps in body + (value <= auto_max_extent_ && auto_max_extent_ == 1)) { return Unroll(op); } else { if (auto_unroll) { diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py index dda3fdad1..68467b0c0 100644 --- a/tests/python/unittest/test_pass_unroll.py +++ b/tests/python/unittest/test_pass_unroll.py @@ -35,6 +35,23 @@ def test_unroll_loop(): assert isinstance(ret.rest, tvm.stmt.For) assert ret.rest.for_type != tvm.stmt.For.Unrolled +def test_unroll_fake_loop(): + ib = tvm.ir_builder.create() + dtype = 'int32' + n = tvm.var('n') + Ab = tvm.decl_buffer((n, ), dtype) + Aptr = ib.buffer_ptr(Ab) + # for i in 0 to n-1: + with ib.for_range(0, 1, name="i") as i: + Aptr[i*2] = 3 + with ib.for_range(0, 10, name="j") as j: + Aptr[j + 1] = Aptr[i] + 1 + + stmt = ib.get() + ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) + assert isinstance(ret.first, tvm.stmt.Store) + if __name__ == "__main__": test_unroll_loop() + test_unroll_fake_loop() \ No newline at end of file -- GitLab