From 38274115a9b1f12a453b7c6dfcab1d0351328d3e Mon Sep 17 00:00:00 2001
From: libing4752 <libing475211023@sjtu.edu.cn>
Date: Thu, 8 Mar 2018 12:04:21 +0800
Subject: [PATCH] enhance access_ptr that args can support Expr (#970)

---
 include/tvm/buffer.h                      |  2 +-
 python/tvm/schedule.py                    | 29 +++++++++++++++++++++--
 src/lang/buffer.cc                        |  2 +-
 tests/python/unittest/test_lang_buffer.py |  9 +++++++
 4 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h
index d737341e1..8b04bf550 100644
--- a/include/tvm/buffer.h
+++ b/include/tvm/buffer.h
@@ -55,7 +55,7 @@ class Buffer : public NodeRef {
    * \param offset The offset of ptr.
    */
   TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
-                          int content_lanes = 1, int offset = 0) const;
+                          int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
   /*!
    * \brief Create an Expr that does a vector load at begin index.
    * \param begin The beginning index
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index b04945292..236570c24 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -2,12 +2,34 @@
 from __future__ import absolute_import as _abs
 from ._ffi.base import string_types
 from ._ffi.node import NodeBase, register_node
-from ._ffi.function import _init_api
+from ._ffi.node import convert_to_node as _convert_to_node
+from ._ffi.function import _init_api, Function
+from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
 from . import _api_internal
 from . import tensor as _tensor
 from . import expr as _expr
 from . import container as _container
 
+def convert(value):
+    """Convert value to TVM node or function.
+
+    Parameters
+    ----------
+    value : python value
+
+    Returns
+    -------
+    tvm_val : Node or Function
+        Converted value in TVM
+    """
+    if isinstance(value, (Function, NodeBase)):
+        return value
+
+    if callable(value):
+        return _convert_tvm_func(value)
+
+    return _convert_to_node(value)
+
 @register_node
 class Buffer(NodeBase):
     """Symbolic data buffer in TVM.
@@ -45,7 +67,7 @@ class Buffer(NodeBase):
             The number of lanes for the data type. This value
             is greater than one for vector types.
 
-        offset: int, optional
+        offset: Expr, optional
             The offset of pointer. We can use it to offset by
             the number of elements from the address of ptr.
 
@@ -60,6 +82,8 @@ class Buffer(NodeBase):
           buffer.access_ptr(Buffer.READ | Buffer.WRITE)
           # Get access ptr for read/write with str flag
           buffer.access_ptr("rw")
+          # Get access ptr for read with offset
+          buffer.access_ptr("r", offset = 100)
         """
         if isinstance(access_mask, string_types):
             mask = 0
@@ -71,6 +95,7 @@ class Buffer(NodeBase):
                 else:
                     raise ValueError("Unknown access_mask %s" % access_mask)
             access_mask = mask
+        offset = convert(offset)
         return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
                                               content_lanes, offset)
 
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index 07e455e25..39566df45 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
                           0);
 }
 
-Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const {
+Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
   const BufferNode* self = operator->();
   Expr e_dtype;
   Expr extent;
diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py
index fe0f1f0b7..a5a8f5d06 100644
--- a/tests/python/unittest/test_lang_buffer.py
+++ b/tests/python/unittest/test_lang_buffer.py
@@ -31,6 +31,15 @@ def test_buffer_access_ptr_offset():
     offset = tvm.ir_pass.Simplify(aptr.args[2])
     assert tvm.ir_pass.Equal(offset, 100)
     assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
+    v = tvm.var('int32')
+    aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
+    offset = tvm.ir_pass.Simplify(aptr.args[2])
+    assert tvm.ir_pass.Equal(offset, 200 + v)
+    assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
+    aptr = Ab.access_ptr("rw", offset=tvm.call_extern('int32', "test_call", 100 + 100 + v))
+    offset = tvm.ir_pass.Simplify(aptr.args[2])
+    assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
+    assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
 
 def test_buffer_index_merge_mult_mod():
     m = tvm.var('m')
-- 
GitLab