From 160e41076c92ed3fe9af875c7c6eddba79672d02 Mon Sep 17 00:00:00 2001 From: Yizhi Liu <liuyizhi@apache.org> Date: Sun, 23 Sep 2018 17:23:52 -0700 Subject: [PATCH] fix buffer elem_offset calculation (#1762) --- src/lang/buffer.cc | 14 +++++--------- tests/python/unittest/test_lang_buffer.py | 9 +++++++++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index cb3194f8e..69967c55a 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -226,16 +226,12 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) { Expr base = n->elem_offset; if (n->strides.size() == 0) { CHECK_EQ(n->shape.size(), index.size()); - if (n->shape.size() != 0) { - if (is_zero(base)) { - base = index[0]; - } else { - base = base + index[0]; + if (index.size() > 0) { + Expr offset = index[0]; + for (size_t i = 1; i < index.size(); ++i) { + offset = MergeMulMod(offset * n->shape[i] + index[i]); } - } - base = MergeMulMod(base); - for (size_t i = 1; i < index.size(); ++i) { - base = MergeMulMod(base * n->shape[i] + index[i]); + base = base + offset; } } else { CHECK_EQ(n->strides.size(), index.size()); diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index a5a8f5d06..51f1e3abb 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -41,6 +41,14 @@ def test_buffer_access_ptr_offset(): assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v)) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE +def test_buffer_vload(): + m = tvm.var('m') + n = tvm.var('n') + Ab = tvm.decl_buffer((m, n), tvm.float32, elem_offset=100) + load = Ab.vload([2, 3]) + offset = tvm.ir_pass.Simplify(load.index) + assert tvm.ir_pass.Equal(offset, n * 2 + 103) + def test_buffer_index_merge_mult_mod(): m = tvm.var('m') n = tvm.var('n') @@ -76,4 +84,5 @@ if __name__ == "__main__": test_buffer() test_buffer_access_ptr() test_buffer_access_ptr_offset() + test_buffer_vload() test_buffer_index_merge_mult_mod() -- GitLab