Commit cedd3900 by eqy, committed by Tianqi Chen

Support vector operations for AMD (llvm IR) (#623)

* Support vector operations for AMD (llvm IR)

* fix whitespace

* update comments, docstring
parent 25847a4f
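
For orientation, here is a minimal sketch of how the new `content_lanes` argument is meant to be used from Python. The example is illustrative only: it assumes the TVM Python API of this era (`tvm.decl_buffer`), and the buffer name, shape, and lane count are made up.

```python
import tvm

# Hypothetical 1-D buffer; shape and name are placeholders.
Ab = tvm.decl_buffer((64,), "float32", name="A")

# Previous behaviour: a scalar float32 pointer (content_lanes defaults to 1).
scalar_ptr = Ab.access_ptr("r")

# New behaviour: the generated tvm_access_ptr call is typed float32x4, and its
# extent and element offset are expressed in 4-lane units (see the
# Buffer::access_ptr change further down).
vector_ptr = Ab.access_ptr("r", content_lanes=4)
```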
@@ -51,8 +51,10 @@ class Buffer : public NodeRef {
    * \brief Get access ptr to the entire buffer.
    * \param access_mask The access mask
    * \param ptr_type The type of the pointer.
+   * \param content_lanes The number of lanes for the (data) type.
    */
-  TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle()) const;
+  TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
+                          int content_lanes = 1) const;
   /*!
    * \brief Create an Expr that does a vector load at begin index.
    * \param begin The beginning index
@@ -25,7 +25,7 @@ class Buffer(NodeBase):
     READ = 1
     WRITE = 2
 
-    def access_ptr(self, access_mask, ptr_type="handle"):
+    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1):
         """Get an access pointer to the head of buffer.
 
         This is the recommended method to get buffer data
@@ -41,6 +41,10 @@ class Buffer(NodeBase):
             The data type of the result pointer. Do not specify
            unless we want to cast pointer to specific type.
 
+        content_lanes: int, optional
+            The number of lanes for the data type. This value
+            is greater than one for vector types.
+
         Examples
         --------
         .. code-block:: python
@@ -63,7 +67,8 @@ class Buffer(NodeBase):
                 else:
                     raise ValueError("Unknown access_mask %s" % access_mask)
             access_mask = mask
-        return _api_internal._BufferAccessPtr(self, access_mask, ptr_type)
+        return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
+                                              content_lanes)
 
     def vload(self, begin, dtype=None):
         """Generate an Expr that loads dtype from begin index.
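
As a quick follow-up to the docstring above, both mask forms forward `content_lanes` unchanged to `_BufferAccessPtr`. A short sketch, reusing the hypothetical buffer `Ab` from the first example; the integer-mask form is assumed from the `Buffer.READ`/`Buffer.WRITE` constants shown above.

```python
# String mask.
vec_rw_ptr = Ab.access_ptr("rw", content_lanes=4)

# Equivalent integer mask built from the class constants.
vec_rw_ptr = Ab.access_ptr(Ab.READ | Ab.WRITE, content_lanes=4)
```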
@@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer")
 TVM_REGISTER_API("_BufferAccessPtr")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Buffer()
-        .access_ptr(args[1], args[2]);
+        .access_ptr(args[1], args[2], args[3]);
   });
 
 TVM_REGISTER_API("_BufferVLoad")
@@ -509,6 +509,18 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr(
   return builder_->CreateInBoundsGEP(buffer, index);
 }
 
+llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
+    Type t, llvm::Value* buffer, llvm::Value* index) {
+  CHECK_GT(t.lanes(), 1);
+  llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
+  CHECK(btype != nullptr);
+  llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
+  if (btype != ptype) {
+    buffer = builder_->CreatePointerCast(buffer, ptype);
+  }
+  return builder_->CreateInBoundsGEP(buffer, index);
+}
+
 llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
   auto it = var_map_.find(v);
   CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
@@ -572,10 +584,21 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
   } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const Load *l = op->args[0].as<Load>();
     CHECK(op->args.size() == 1 && l);
-    llvm::Value* ptr = CreateBufferPtr(
+    const Ramp *r = l->index.as<Ramp>();
+    llvm::Value* ptr;
+    unsigned addrspace;
+    if (!r) {
+      ptr = CreateBufferPtr(
         l->type, MakeValue(l->buffer_var), MakeValue(l->index));
-    unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
+      addrspace = llvm::dyn_cast<llvm::PointerType>(
         ptr->getType())->getAddressSpace();
+    } else {
+      Expr index = r->base / make_const(Int(32), r->lanes);
+      ptr = CreateBufferVecPtr(
+        l->type, MakeValue(l->buffer_var), MakeValue(index));
+      addrspace = llvm::dyn_cast<llvm::PointerType>(
+        ptr->getType())->getAddressSpace();
+    }
     return builder_->CreatePointerCast(ptr, t_void_->getPointerTo(addrspace));
   } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
     return llvm::Constant::getNullValue(t_void_p_);
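
The crux of the codegen change is how a vectorized load index is mapped to a vector-element index before the GEP. A worked sketch with assumed numbers (not taken from the commit):

```python
# A vectorized float32x4 load carries a Ramp index such as
# Ramp(base=12, stride=1, lanes=4), covering scalar elements 12..15.
base, lanes = 12, 4

# tvm_address_of divides the Ramp base by the lane count to get an index in
# units of whole vectors; CreateBufferVecPtr then GEPs a <4 x float>* with it.
vector_index = base // lanes   # 3, i.e. the fourth float32x4 element
```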
@@ -191,6 +191,7 @@ class CodeGenLLVM :
   llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
   llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
+  llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
   // Vector concatenation.
   llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
   llvm::Value* CreateVecFlip(llvm::Value* vec);
@@ -341,14 +341,23 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
                 0);
 }
 
-Expr Buffer::access_ptr(int access_mask, Type ptr_type) const {
+Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const {
   const BufferNode* self = operator->();
-  Expr e_dtype = make_zero(self->dtype);
+  Expr e_dtype;
   Expr extent = (self->strides.size() == self->shape.size() ?
                  arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]):
                  arith::ComputeReduce<ir::Mul>(self->shape));
+  Expr elem_offset = self->elem_offset;
+  if (content_lanes > 1) {
+    e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
+    extent = extent / make_const(self->elem_offset.type(), content_lanes);
+    elem_offset = self->elem_offset / make_const(self->elem_offset.type(),
+                                                 content_lanes);
+  } else {
+    e_dtype = make_zero(self->dtype);
+  }
   Array<Expr> acc_args{
-    e_dtype, self->data, self->elem_offset,
+    e_dtype, self->data, elem_offset,
     extent, make_const(Int(32), access_mask)};
   return ir::Call::make(
       ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
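
A worked example of the rescaling that `Buffer::access_ptr` now performs when `content_lanes > 1` (the numbers are assumed, purely illustrative):

```python
# Scalar view of a hypothetical buffer region.
dtype = "float32"
extent = 64          # scalar elements covered by the pointer
elem_offset = 8      # offset from the start of data, in scalar elements
content_lanes = 4

# With content_lanes = 4, the tvm_access_ptr call is typed float32x4 and both
# quantities are re-expressed in units of whole vectors.
vector_dtype = "%sx%d" % (dtype, content_lanes)     # "float32x4"
vector_extent = extent // content_lanes             # 16
vector_elem_offset = elem_offset // content_lanes   # 2
```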
@@ -102,6 +102,7 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) {
 inline Expr AddressOffset(Var handle, Type dtype, Expr offset) {
   if (dtype.lanes() != 1) {
     offset = offset * make_const(offset.type(), dtype.lanes());
+    offset = Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
   }
   return Call::make(
       Handle(), intrinsic::tvm_address_of,
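
`AddressOffset` performs the mirror-image transformation: for a vector type the offset is first scaled to scalar elements and then expanded into a Ramp, which the codegen change above maps back to a vector-element index. Again a sketch with assumed numbers:

```python
lanes = 4      # dtype is float32x4
offset = 3     # offset in units of whole vectors

base = offset * lanes                                   # 12
scalar_indices = [base + i * 1 for i in range(lanes)]   # [12, 13, 14, 15]
# i.e. the index handed to tvm_address_of becomes Ramp(12, 1, 4),
# and the LLVM codegen recovers vector element 12 // 4 == 3.
```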