diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 610532e261a32850e4ba57f00caa0d505abcbb69..ad4872b8e4e066906f9671214081ac7ff75a55ca 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -51,8 +51,10 @@ class Buffer : public NodeRef { * \brief Get access ptr to the entire buffer. * \param access_mask The access mask * \param ptr_type The type of the pointer. + * \param content_lanes The number of lanes for the (data) type. */ - TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle()) const; + TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(), + int content_lanes = 1) const; /* defaulting content_lanes to 1 keeps existing one- and two-argument access_ptr call sites valid */ /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 26be2de1a69a3dd20826d24c38bf56984ba5c7f4..6abe4aae2f6fc1471c88a2b9ed24bad63a3ece44 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -25,7 +25,7 @@ class Buffer(NodeBase): READ = 1 WRITE = 2 - def access_ptr(self, access_mask, ptr_type="handle"): + def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1): """Get an access pointer to the head of buffer. This is the recommended method to get buffer data @@ -41,6 +41,10 @@ class Buffer(NodeBase): The data type of the result pointer. Do not specify unless we want to cast pointer to specific type. + content_lanes: int, optional + The number of lanes for the data type. This value + is greater than one for vector types. + Examples -------- .. code-block:: python @@ -63,7 +67,8 @@ class Buffer(NodeBase): else: raise ValueError("Unknown access_mask %s" % access_mask) access_mask = mask - return _api_internal._BufferAccessPtr(self, access_mask, ptr_type) + return _api_internal._BufferAccessPtr(self, access_mask, ptr_type, + content_lanes) def vload(self, begin, dtype=None): """Generate an Expr that loads dtype from begin index. 
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 85b7d92c6a25efd508dea211f739c3b1a5b2071b..94075b6ec059675a3fe2183350c919f42837abbe 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer") TVM_REGISTER_API("_BufferAccessPtr") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Buffer() - .access_ptr(args[1], args[2]); + .access_ptr(args[1], args[2], args[3]); /* args[3] is forwarded as content_lanes. NOTE(review): it is read unconditionally, so every caller of _BufferAccessPtr must now supply four arguments. */ }); TVM_REGISTER_API("_BufferVLoad") diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index a4bb815b1b93b06f63af2fb86b0714f0ea766815..d274af73ed82919a012abcc72ec8b4a29c82cf15 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -509,6 +509,18 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr( return builder_->CreateInBoundsGEP(buffer, index); } +llvm::Value* CodeGenLLVM::CreateBufferVecPtr( + Type t, llvm::Value* buffer, llvm::Value* index) { /* Address of entry index of buffer, with buffer viewed as an array of LLVMType(t); t must have more than one lane. */ + CHECK_GT(t.lanes(), 1); + llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType()); + CHECK(btype != nullptr); /* buffer must be a pointer-typed value */ + llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace()); /* pointer to LLVMType(t) in the address space of buffer */ + if (btype != ptype) { + buffer = builder_->CreatePointerCast(buffer, ptype); + } + return builder_->CreateInBoundsGEP(buffer, index); /* buffer is of type ptype here, so index advances in units of LLVMType(t) */ +} + llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const { auto it = var_map_.find(v); CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint; @@ -572,10 +584,21 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load *l = op->args[0].as<Load>(); CHECK(op->args.size() == 1 && l); - llvm::Value* ptr = CreateBufferPtr( - l->type, MakeValue(l->buffer_var), MakeValue(l->index)); - unsigned addrspace = llvm::dyn_cast<llvm::PointerType>( - ptr->getType())->getAddressSpace(); + const Ramp *r = l->index.as<Ramp>(); /* tested below: null when the load index is not a Ramp */ + llvm::Value* ptr; + unsigned addrspace; + if (!r) { /* non-Ramp index: same CreateBufferPtr path as before this change */ + ptr = 
CreateBufferPtr( + l->type, MakeValue(l->buffer_var), MakeValue(l->index)); + addrspace = llvm::dyn_cast<llvm::PointerType>( + ptr->getType())->getAddressSpace(); + } else { /* Ramp index: address the buffer in units of the vector type l->type */ + Expr index = r->base / make_const(Int(32), r->lanes); /* turn the Ramp base into an index counted in groups of r->lanes. NOTE(review): r->stride is never inspected and the divisor is built as Int(32); presumably the ramp has unit stride and a base that is a multiple of r->lanes; confirm. */ + ptr = CreateBufferVecPtr( + l->type, MakeValue(l->buffer_var), MakeValue(index)); + addrspace = llvm::dyn_cast<llvm::PointerType>( + ptr->getType())->getAddressSpace(); + } return builder_->CreatePointerCast(ptr, t_void_->getPointerTo(addrspace)); } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) { return llvm::Constant::getNullValue(t_void_p_); diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index e4a0b24d381ad76371b5a7a1c5a98ed7842b3e67..fbc74f092825a187a8e0e82bfed292b1d166d748 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -191,6 +191,7 @@ class CodeGenLLVM : llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index); + llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index); /* counterpart of CreateBufferPtr for a type t with more than one lane */ // Vector concatenation. llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecFlip(llvm::Value* vec); diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 34abada14118f8ab9fad781e52f891f95e21c62a..5cf7ddef3018e377407bb524c9532b6f50a807b1 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -341,14 +341,23 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { 0); } -Expr Buffer::access_ptr(int access_mask, Type ptr_type) const { +Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const { const BufferNode* self = operator->(); - Expr e_dtype = make_zero(self->dtype); + Expr e_dtype; /* assigned below, depending on content_lanes */ Expr extent = (self->strides.size() == self->shape.size() ? 
arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]): arith::ComputeReduce<ir::Mul>(self->shape)); + Expr elem_offset = self->elem_offset; + if (content_lanes > 1) { /* vector view: dtype gets content_lanes lanes, extent and offset are divided by content_lanes */ + e_dtype = make_zero(self->dtype.with_lanes(content_lanes)); + extent = extent / make_const(self->elem_offset.type(), content_lanes); /* NOTE(review): the divisor is typed after elem_offset rather than extent; confirm both always share a type. */ + elem_offset = self->elem_offset / make_const(self->elem_offset.type(), + content_lanes); /* NOTE(review): plain division; presumably extent and elem_offset are multiples of content_lanes; confirm. */ + } else { /* content_lanes <= 1: dtype, offset and extent are passed exactly as before this change */ + e_dtype = make_zero(self->dtype); + } Array<Expr> acc_args{ - e_dtype, self->data, self->elem_offset, + e_dtype, self->data, elem_offset, extent, make_const(Int(32), access_mask)}; return ir::Call::make( ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic); diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index ae7a026c1ecb408ae1d8fc5a77fe28b3e5a2f62b..082d580a0e45cac81eb6b3b864fc0b5b65dccf35 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -102,6 +102,7 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) { inline Expr AddressOffset(Var handle, Type dtype, Expr offset) { if (dtype.lanes() != 1) { offset = offset * make_const(offset.type(), dtype.lanes()); + offset = Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes()); /* wrap the scaled offset in a Ramp with dtype.lanes() lanes; presumably the argument order is (base, stride, lanes), i.e. a unit-stride vector index; confirm */ } return Call::make( Handle(), intrinsic::tvm_address_of,