Commit cedd3900 by eqy Committed by Tianqi Chen

Support vector operations for AMD (llvm IR) (#623)

* Support vector operations for AMD (llvm IR)

* fix whitespace

* update comments, docstring
parent 25847a4f
...@@ -51,8 +51,10 @@ class Buffer : public NodeRef { ...@@ -51,8 +51,10 @@ class Buffer : public NodeRef {
* \brief Get access ptr to the entire buffer. * \brief Get access ptr to the entire buffer.
* \param access_mask The access mask * \param access_mask The access mask
* \param ptr_type The type of the pointer. * \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
*/ */
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle()) const; TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1) const;
/*! /*!
* \brief Create an Expr that does a vector load at begin index. * \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index * \param begin The beginning index
......
...@@ -25,7 +25,7 @@ class Buffer(NodeBase): ...@@ -25,7 +25,7 @@ class Buffer(NodeBase):
READ = 1 READ = 1
WRITE = 2 WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle"): def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1):
"""Get an access pointer to the head of buffer. """Get an access pointer to the head of buffer.
This is the recommended method to get buffer data This is the recommended method to get buffer data
...@@ -41,6 +41,10 @@ class Buffer(NodeBase): ...@@ -41,6 +41,10 @@ class Buffer(NodeBase):
The data type of the result pointer. Do not specify The data type of the result pointer. Do not specify
unless we want to cast pointer to specific type. unless we want to cast pointer to specific type.
content_lanes: int, optional
The number of lanes for the data type. This value
is greater than one for vector types.
Examples Examples
-------- --------
.. code-block:: python .. code-block:: python
...@@ -63,7 +67,8 @@ class Buffer(NodeBase): ...@@ -63,7 +67,8 @@ class Buffer(NodeBase):
else: else:
raise ValueError("Unknown access_mask %s" % access_mask) raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type) return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
content_lanes)
def vload(self, begin, dtype=None): def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index. """Generate an Expr that loads dtype from begin index.
......
...@@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer") ...@@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer")
TVM_REGISTER_API("_BufferAccessPtr") TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer() *ret = args[0].operator Buffer()
.access_ptr(args[1], args[2]); .access_ptr(args[1], args[2], args[3]);
}); });
TVM_REGISTER_API("_BufferVLoad") TVM_REGISTER_API("_BufferVLoad")
......
...@@ -509,6 +509,18 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr( ...@@ -509,6 +509,18 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr(
return builder_->CreateInBoundsGEP(buffer, index); return builder_->CreateInBoundsGEP(buffer, index);
} }
llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
Type t, llvm::Value* buffer, llvm::Value* index) {
CHECK_GT(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
if (btype != ptype) {
buffer = builder_->CreatePointerCast(buffer, ptype);
}
return builder_->CreateInBoundsGEP(buffer, index);
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const { llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v); auto it = var_map_.find(v);
CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint; CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
...@@ -572,10 +584,21 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -572,10 +584,21 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) { } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>(); const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l); CHECK(op->args.size() == 1 && l);
llvm::Value* ptr = CreateBufferPtr( const Ramp *r = l->index.as<Ramp>();
l->type, MakeValue(l->buffer_var), MakeValue(l->index)); llvm::Value* ptr;
unsigned addrspace = llvm::dyn_cast<llvm::PointerType>( unsigned addrspace;
ptr->getType())->getAddressSpace(); if (!r) {
ptr = CreateBufferPtr(
l->type, MakeValue(l->buffer_var), MakeValue(l->index));
addrspace = llvm::dyn_cast<llvm::PointerType>(
ptr->getType())->getAddressSpace();
} else {
Expr index = r->base / make_const(Int(32), r->lanes);
ptr = CreateBufferVecPtr(
l->type, MakeValue(l->buffer_var), MakeValue(index));
addrspace = llvm::dyn_cast<llvm::PointerType>(
ptr->getType())->getAddressSpace();
}
return builder_->CreatePointerCast(ptr, t_void_->getPointerTo(addrspace)); return builder_->CreatePointerCast(ptr, t_void_->getPointerTo(addrspace));
} else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) { } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_); return llvm::Constant::getNullValue(t_void_p_);
......
...@@ -191,6 +191,7 @@ class CodeGenLLVM : ...@@ -191,6 +191,7 @@ class CodeGenLLVM :
llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index); llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
// Vector concatenation. // Vector concatenation.
llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
llvm::Value* CreateVecFlip(llvm::Value* vec); llvm::Value* CreateVecFlip(llvm::Value* vec);
......
...@@ -341,14 +341,23 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { ...@@ -341,14 +341,23 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0); 0);
} }
Expr Buffer::access_ptr(int access_mask, Type ptr_type) const { Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const {
const BufferNode* self = operator->(); const BufferNode* self = operator->();
Expr e_dtype = make_zero(self->dtype); Expr e_dtype;
Expr extent = (self->strides.size() == self->shape.size() ? Expr extent = (self->strides.size() == self->shape.size() ?
arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]): arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]):
arith::ComputeReduce<ir::Mul>(self->shape)); arith::ComputeReduce<ir::Mul>(self->shape));
Expr elem_offset = self->elem_offset;
if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.type(), content_lanes);
elem_offset = self->elem_offset / make_const(self->elem_offset.type(),
content_lanes);
} else {
e_dtype = make_zero(self->dtype);
}
Array<Expr> acc_args{ Array<Expr> acc_args{
e_dtype, self->data, self->elem_offset, e_dtype, self->data, elem_offset,
extent, make_const(Int(32), access_mask)}; extent, make_const(Int(32), access_mask)};
return ir::Call::make( return ir::Call::make(
ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic); ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
......
...@@ -102,6 +102,7 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) { ...@@ -102,6 +102,7 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) {
inline Expr AddressOffset(Var handle, Type dtype, Expr offset) { inline Expr AddressOffset(Var handle, Type dtype, Expr offset) {
if (dtype.lanes() != 1) { if (dtype.lanes() != 1) {
offset = offset * make_const(offset.type(), dtype.lanes()); offset = offset * make_const(offset.type(), dtype.lanes());
offset = Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
} }
return Call::make( return Call::make(
Handle(), intrinsic::tvm_address_of, Handle(), intrinsic::tvm_address_of,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment