From 38274115a9b1f12a453b7c6dfcab1d0351328d3e Mon Sep 17 00:00:00 2001 From: libing4752 <libing475211023@sjtu.edu.cn> Date: Thu, 8 Mar 2018 12:04:21 +0800 Subject: [PATCH] enhance access_ptr that args can support Expr (#970) --- include/tvm/buffer.h | 2 +- python/tvm/schedule.py | 29 +++++++++++++++++++++-- src/lang/buffer.cc | 2 +- tests/python/unittest/test_lang_buffer.py | 9 +++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index d737341e1..8b04bf550 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -55,7 +55,7 @@ class Buffer : public NodeRef { * \param offset The offset of ptr. */ TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(), - int content_lanes = 1, int offset = 0) const; + int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index b04945292..236570c24 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -2,12 +2,34 @@ from __future__ import absolute_import as _abs from ._ffi.base import string_types from ._ffi.node import NodeBase, register_node -from ._ffi.function import _init_api +from ._ffi.node import convert_to_node as _convert_to_node +from ._ffi.function import _init_api, Function +from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from . import _api_internal from . import tensor as _tensor from . import expr as _expr from . import container as _container +def convert(value): + """Convert value to TVM node or function. + + Parameters + ---------- + value : python value + + Returns + ------- + tvm_val : Node or Function + Converted value in TVM + """ + if isinstance(value, (Function, NodeBase)): + return value + + if callable(value): + return _convert_tvm_func(value) + + return _convert_to_node(value) + @register_node class Buffer(NodeBase): """Symbolic data buffer in TVM. @@ -45,7 +67,7 @@ class Buffer(NodeBase): The number of lanes for the data type. This value is greater than one for vector types. - offset: int, optional + offset: Expr, optional The offset of pointer. We can use it to offset by the number of elements from the address of ptr. @@ -60,6 +82,8 @@ class Buffer(NodeBase): buffer.access_ptr(Buffer.READ | Buffer.WRITE) # Get access ptr for read/write with str flag buffer.access_ptr("rw") + # Get access ptr for read with offset + buffer.access_ptr("r", offset = 100) """ if isinstance(access_mask, string_types): mask = 0 @@ -71,6 +95,7 @@ class Buffer(NodeBase): else: raise ValueError("Unknown access_mask %s" % access_mask) access_mask = mask + offset = convert(offset) return _api_internal._BufferAccessPtr(self, access_mask, ptr_type, content_lanes, offset) diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 07e455e25..39566df45 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { 0); } -Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const { +Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const { const BufferNode* self = operator->(); Expr e_dtype; Expr extent; diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index fe0f1f0b7..a5a8f5d06 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -31,6 +31,15 @@ def test_buffer_access_ptr_offset(): offset = tvm.ir_pass.Simplify(aptr.args[2]) assert tvm.ir_pass.Equal(offset, 100) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE + v = tvm.var('int32') + aptr = Ab.access_ptr("rw", offset=100 + 100 + v) + offset = tvm.ir_pass.Simplify(aptr.args[2]) + assert tvm.ir_pass.Equal(offset, 200 + v) + assert aptr.args[4].value == Buffer.READ | Buffer.WRITE + aptr = Ab.access_ptr("rw", offset=tvm.call_extern('int32', "test_call", 100 + 100 + v)) + offset = tvm.ir_pass.Simplify(aptr.args[2]) + assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v)) + assert aptr.args[4].value == Buffer.READ | Buffer.WRITE def test_buffer_index_merge_mult_mod(): m = tvm.var('m') -- GitLab