From a1dfb9ae1817e13c53cedae59d5128692e6b609f Mon Sep 17 00:00:00 2001
From: xqdan <danxiaoqiang@126.com>
Date: Tue, 30 Oct 2018 02:19:37 +0800
Subject: [PATCH] [PASS]unroll loops with extent=1 (#2027)

---
 src/pass/unroll_loop.cc                   |  4 +++-
 tests/python/unittest/test_pass_unroll.py | 17 +++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc
index 6c0ac5175..d4481e86c 100644
--- a/src/pass/unroll_loop.cc
+++ b/src/pass/unroll_loop.cc
@@ -76,7 +76,9 @@ class LoopUnroller : public IRMutator {
       normal_loop_depth_ += 1;
     }
 
-    if (auto_unroll && explicit_unroll_) {
+    if ((auto_unroll && explicit_unroll_) ||
+        // unroll loops with extent = 1, no matter how many steps in body
+        (value <= auto_max_extent_ && auto_max_extent_ == 1)) {
       return Unroll(op);
     } else {
       if (auto_unroll) {
diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py
index dda3fdad1..68467b0c0 100644
--- a/tests/python/unittest/test_pass_unroll.py
+++ b/tests/python/unittest/test_pass_unroll.py
@@ -35,6 +35,23 @@ def test_unroll_loop():
     assert isinstance(ret.rest, tvm.stmt.For)
     assert ret.rest.for_type != tvm.stmt.For.Unrolled
 
+def test_unroll_fake_loop():
+    ib = tvm.ir_builder.create()
+    dtype = 'int32'
+    n = tvm.var('n')
+    Ab = tvm.decl_buffer((n, ), dtype)
+    Aptr = ib.buffer_ptr(Ab)
+    # for i in 0 to n-1:
+    with ib.for_range(0, 1, name="i") as i:
+        Aptr[i*2] = 3
+        with ib.for_range(0, 10, name="j") as j:
+            Aptr[j + 1] = Aptr[i] + 1
+
+    stmt = ib.get()
+    ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
+    assert isinstance(ret.first, tvm.stmt.Store)
+
 
 if __name__ == "__main__":
     test_unroll_loop()
+    test_unroll_fake_loop()
\ No newline at end of file
-- 
GitLab