Commit e6d1c628 by Tianqi Chen Committed by GitHub

[CODEGEN] Fix alignment generation (#955)

parent b1ffac44
...@@ -339,6 +339,7 @@ void CodeGenLLVM::GetAlignment(Type t, ...@@ -339,6 +339,7 @@ void CodeGenLLVM::GetAlignment(Type t,
} }
arith::ModularEntry me = arith::EvalModular(index, align_map_); arith::ModularEntry me = arith::EvalModular(index, align_map_);
int align_bits = t.bits(); int align_bits = t.bits();
while (align_bits < max_align_bits && while (align_bits < max_align_bits &&
me.base % 2 == 0 && me.base % 2 == 0 &&
...@@ -814,13 +815,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) { ...@@ -814,13 +815,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
Type t = op->type; Type t = op->type;
int alignment, native_bits;
bool is_volatile = volatile_buf_.count(op->buffer_var.get()); bool is_volatile = volatile_buf_.count(op->buffer_var.get());
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* buffer = MakeValue(op->buffer_var); llvm::Value* buffer = MakeValue(op->buffer_var);
llvm::Value* index = MakeValue(op->index); llvm::Value* index = MakeValue(op->index);
if (t.lanes() == 1) { if (t.lanes() == 1) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* ptr = CreateBufferPtr(t, buffer, index); llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
AddAliasInfo(load, op->buffer_var.get(), op->index, t); AddAliasInfo(load, op->buffer_var.get(), op->index, t);
...@@ -831,6 +832,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { ...@@ -831,6 +832,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
buffer->getType())->getAddressSpace(); buffer->getType())->getAddressSpace();
if (const Ramp* ramp = op->index.as<Ramp>()) { if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) { if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
CHECK_EQ(ramp->lanes, t.lanes()); CHECK_EQ(ramp->lanes, t.lanes());
llvm::Value* ptr = CreateBufferPtr( llvm::Value* ptr = CreateBufferPtr(
t.element_of(), buffer, MakeValue(ramp->base)); t.element_of(), buffer, MakeValue(ramp->base));
...@@ -885,14 +888,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) { ...@@ -885,14 +888,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
void CodeGenLLVM::VisitStmt_(const Store* op) { void CodeGenLLVM::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate)); CHECK(is_one(op->predicate));
Type t = op->value.type(); Type t = op->value.type();
int alignment, native_bits;
bool is_volatile = volatile_buf_.count(op->buffer_var.get()); bool is_volatile = volatile_buf_.count(op->buffer_var.get());
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* buffer = MakeValue(op->buffer_var); llvm::Value* buffer = MakeValue(op->buffer_var);
llvm::Value* index = MakeValue(op->index); llvm::Value* index = MakeValue(op->index);
llvm::Value* value = MakeValue(op->value); llvm::Value* value = MakeValue(op->value);
if (t.lanes() == 1) { if (t.lanes() == 1) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* ptr = CreateBufferPtr(t, buffer, index); llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type()); AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
...@@ -903,6 +906,8 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { ...@@ -903,6 +906,8 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
buffer->getType())->getAddressSpace(); buffer->getType())->getAddressSpace();
if (const Ramp* ramp = op->index.as<Ramp>()) { if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) { if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
CHECK_EQ(ramp->lanes, t.lanes()); CHECK_EQ(ramp->lanes, t.lanes());
llvm::Value* ptr = CreateBufferPtr( llvm::Value* ptr = CreateBufferPtr(
t.element_of(), buffer, MakeValue(ramp->base)); t.element_of(), buffer, MakeValue(ramp->base));
......
...@@ -297,7 +297,22 @@ def test_rank_zero(): ...@@ -297,7 +297,22 @@ def test_rank_zero():
check_llvm(64) check_llvm(64)
def test_alignment():
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: A[i] * 3, name='B')
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=8)
s[B].vectorize(tx)
f = tvm.build(s, [A, B], "llvm")
for l in f.get_source().split("\n"):
if "align" in l and "4 x float" in l:
assert "align 32" in l
if __name__ == "__main__": if __name__ == "__main__":
test_alignment()
test_rank_zero() test_rank_zero()
test_llvm_bool() test_llvm_bool()
test_llvm_persist_parallel() test_llvm_persist_parallel()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment