From 160e41076c92ed3fe9af875c7c6eddba79672d02 Mon Sep 17 00:00:00 2001
From: Yizhi Liu <liuyizhi@apache.org>
Date: Sun, 23 Sep 2018 17:23:52 -0700
Subject: [PATCH] fix buffer elem_offset calculation (#1762)

---
 src/lang/buffer.cc                        | 14 +++++---------
 tests/python/unittest/test_lang_buffer.py |  9 +++++++++
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index cb3194f8e..69967c55a 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -226,16 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   Expr base = n->elem_offset;
   if (n->strides.size() == 0) {
     CHECK_EQ(n->shape.size(), index.size());
-    if (n->shape.size() != 0) {
-      if (is_zero(base)) {
-        base = index[0];
-      } else {
-        base = base + index[0];
+    if (index.size() > 0) {
+      Expr offset = index[0];
+      for (size_t i = 1; i < index.size(); ++i) {
+        offset = MergeMulMod(offset * n->shape[i] + index[i]);
       }
-    }
-    base = MergeMulMod(base);
-    for (size_t i = 1; i < index.size(); ++i) {
-      base = MergeMulMod(base * n->shape[i] + index[i]);
+      base = base + offset;
     }
   } else {
     CHECK_EQ(n->strides.size(), index.size());
diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py
index a5a8f5d06..51f1e3abb 100644
--- a/tests/python/unittest/test_lang_buffer.py
+++ b/tests/python/unittest/test_lang_buffer.py
@@ -41,6 +41,14 @@ def test_buffer_access_ptr_offset():
     assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
 
+def test_buffer_vload():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100)
+    load = Ab.vload([2, 3])
+    offset = tvm.ir_pass.Simplify(load.index)
+    assert tvm.ir_pass.Equal(offset, n * 2 + 103)
+
 def test_buffer_index_merge_mult_mod():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -76,4 +84,5 @@ if __name__ == "__main__":
     test_buffer()
     test_buffer_access_ptr()
     test_buffer_access_ptr_offset()
+    test_buffer_vload()
     test_buffer_index_merge_mult_mod()
-- 
GitLab