Commit 54450614 by Tianqi Chen, committed by ziheng

[CODEGEN] More storage alignment info aware generation (#186)

* [CODEGEN] More storage alignment info aware generation

* fix

* fix

* fix warning
parent 3b8e70ae
@@ -146,6 +146,8 @@ constexpr const char* virtual_thread = "virtual_thread";
 constexpr const char* volatile_scope = "volatile_scope";
 /*! \brief Mark storage scope of buffers */
 constexpr const char* storage_scope = "storage_scope";
+/*! \brief Mark storage alignment requirement of buffers */
+constexpr const char* storage_alignment = "storage_alignment";
 /*! \brief Mark storage scope of realization */
 constexpr const char* realize_scope = "realize_scope";
 /*! \brief The allocation context for global malloc in host. */
......
@@ -20,6 +20,10 @@ enum DeviceAttrKind : int {
   kMaxThreadsPerBlock = 1,
   kWarpSize = 2
 };
+/*! \brief Number of bytes each allocation must align to */
+constexpr int kAllocAlignment = 64;
 /*!
  * \brief TVM Runtime Device API, abstracts the device
  * specific interface for memory management.
......
@@ -562,7 +562,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
         result = convert(result)
         id_elem = convert(id_elem)
         combiner = _make.CommReducer(lhs, rhs, result, id_elem)
-        axis = convert(axis if isinstance(axis, list) else [axis])
+        axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
         if where is None:
             where = convert(True)
         outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
@@ -570,7 +570,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
         return outputs[0] if size == 1 else outputs
     def reducer(expr, axis, where=None, *args):
-        if isinstance(axis, (_schedule.IterVar, list)):
+        if isinstance(axis, (_schedule.IterVar, list, tuple)):
             assert not args
             return _make_reduce(expr, axis, where)
         if where is None:
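Note: with this change, reduction axes may be passed as a tuple as well as a list. A minimal usage sketch in the old tvm 0.x-style API, mirroring the new test_tensor_reduce_multi_axis test added below:

```python
import tvm

# Reduce a 2-D placeholder over both axes; the axis argument may now be
# a tuple, not only a list or a single IterVar.
m = tvm.var("m")
n = tvm.var("n")
A = tvm.placeholder((m, n), name="A")
k1 = tvm.reduce_axis((0, m), name="k1")
k2 = tvm.reduce_axis((0, n), name="k2")
B = tvm.compute((1,), lambda _: tvm.sum(A[k1, k2], axis=(k1, k2)), name="B")
```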
......
@@ -91,6 +91,20 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
   module_->setTargetTriple(tm->getTargetTriple().str());
   module_->setDataLayout(tm->createDataLayout());
   data_layout_.reset(new llvm::DataLayout(module_.get()));
+  // initialize native vector bits
+  std::string target = tm->getTarget().getName();
+  if (target == "arm") {
+    native_vector_bits_ = 16 * 8;
+  } else if (target == "x86-64") {
+    // for avx512
+    native_vector_bits_ = 64 * 8;
+  } else if (target == "x86") {
+    native_vector_bits_ = 32 * 8;
+  } else {
+    native_vector_bits_ = 32 * 8;
+    LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8
+                 << " for target " << target;
+  }
 }

 void CodeGenLLVM::InitGlobalContext() {
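Note: an illustrative Python mirror of the target-to-vector-width lookup added above (the native_vector_bits helper name exists only in this sketch):

```python
def native_vector_bits(target_name):
    # arm: 128-bit NEON registers; x86-64: sized for AVX-512 (512 bits);
    # 32-bit x86 and unrecognized targets fall back to 256 bits.
    widths = {"arm": 16 * 8, "x86-64": 64 * 8, "x86": 32 * 8}
    return widths.get(target_name, 32 * 8)
```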
@@ -104,7 +118,7 @@ void CodeGenLLVM::InitGlobalContext() {
 void CodeGenLLVM::InitFuncState() {
   var_map_.clear();
   align_map_.clear();
-  alloc_storage_scope_.clear();
+  alloc_storage_info_.clear();
 }

 void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
@@ -750,7 +764,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
 int CodeGenLLVM::NativeVectorBits(const std::string& storage_scope) const {
   // By default, we ask the buffer to be aligned to 64 bytes
-  return 64 * 8;
+  return native_vector_bits_;
 }

 void CodeGenLLVM::GetAlignment(
@@ -759,17 +773,20 @@ void CodeGenLLVM::GetAlignment(
   int& alignment = *p_alignment;
   int& native_bits = *p_native_bits;
   // The storage scope.
-  std::string scope;
-  auto it = alloc_storage_scope_.find(buf_var);
-  if (it != alloc_storage_scope_.end()) {
-    scope = it->second;
+  StorageInfo info;
+  auto it = alloc_storage_info_.find(buf_var);
+  if (it != alloc_storage_info_.end()) {
+    info = it->second;
   }
   arith::ModularEntry m = EvalModular(index, align_map_);
-  native_bits = NativeVectorBits(scope);
+  native_bits = NativeVectorBits(info.scope);
   alignment = t.element_of().bits();
-  // find alignment
+  // find alignment, cannot exceed allocated alignment
+  int max_align_bits = std::min(
+      info.alignment * 8, alignment * t.lanes());
   while ((m.coeff & 1) == 0 &&
          (m.base & 1) == 0 &&
+         alignment < max_align_bits &&
          alignment < native_bits) {
     m.coeff /= 2;
     m.base /= 2;
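Note: the loop above keeps doubling the alignment while the buffer index stays provably even. A standalone Python sketch of that search; the trailing `alignment *= 2` step is assumed from the surrounding loop body, and the function name is hypothetical:

```python
def derive_alignment(coeff, base, elem_bits, lanes, alloc_align_bytes, native_bits):
    """Double the element alignment while the index pattern coeff*x + base
    stays even, capped by the allocated alignment and the native vector width."""
    alignment = elem_bits
    max_align_bits = min(alloc_align_bytes * 8, elem_bits * lanes)
    while (coeff % 2 == 0 and base % 2 == 0 and
           alignment < max_align_bits and alignment < native_bits):
        coeff //= 2
        base //= 2
        alignment *= 2
    return alignment  # in bits

# Index pattern 4*x + 8 over float32x4 lanes in a 64-byte aligned buffer:
print(derive_alignment(coeff=4, base=8, elem_bits=32, lanes=4,
                       alloc_align_bytes=64, native_bits=256))  # -> 128
```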
@@ -1291,8 +1308,19 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
     int32_t constant_size = op->constant_allocation_size();
     CHECK_GT(constant_size, 0)
         << "Can only handle constant size stack allocation for now";
-    buf = builder_->CreateAlloca(
+    llvm::AllocaInst* alloca = builder_->CreateAlloca(
         LLVMType(op->type), ConstInt32(constant_size));
+    buf = alloca;
+    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
+    // Align the stack buffer to four elements if its size allows it.
+    // TODO(tqchen) have pass to detect vector access and pre-set alignment
+    if (constant_size % 4 == 0 && info.alignment == 0) {
+      info.alignment = op->type.bytes() * 4;
+    }
+    if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
+      alloca->setAlignment(info.alignment);
+    }
+    info.alignment = alloca->getAlignment();
   }
   buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo());
   CHECK(!var_map_.count(op->buffer_var.get()));
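Note: a small sketch of the stack-allocation heuristic introduced above; if no storage_alignment attribute was seen and the constant element count is a multiple of 4, the alloca is aligned to four elements (helper name is illustrative only):

```python
def stack_alloca_alignment(constant_size, elem_bytes, requested_alignment=0):
    # When nothing was requested and the element count divides evenly into
    # groups of four, align to four elements so vector loads stay aligned.
    alignment = requested_alignment
    if constant_size % 4 == 0 and alignment == 0:
        alignment = elem_bytes * 4
    return alignment

print(stack_alloca_alignment(16, 4))  # 16 float32 elements -> 16-byte alignment
```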
@@ -1304,7 +1332,13 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
   if (op->attr_key == ir::attr::storage_scope) {
     const Variable* v = op->node.as<Variable>();
     CHECK(v);
-    alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+    alloc_storage_info_[v].scope = op->value.as<StringImm>()->value;
+    this->VisitStmt(op->body);
+  } else if (op->attr_key == ir::attr::storage_alignment) {
+    const Variable* v = op->node.as<Variable>();
+    CHECK(v);
+    alloc_storage_info_[v].alignment =
+        static_cast<int>(op->value.as<IntImm>()->value);
     this->VisitStmt(op->body);
   } else {
     this->VisitStmt(op->body);
......
@@ -115,6 +115,13 @@ class CodeGenLLVM :
   virtual void Scalarize(const Expr& e,
                          std::function<void(int i, llvm::Value* v)> f);
  protected:
+  /*! \brief The storage information */
+  struct StorageInfo {
+    /*! \brief The storage scope */
+    std::string scope;
+    /*! \brief The alignment of allocation */
+    int alignment{0};
+  };
   /*!
    * \param t The original type.
    * \return LLVM type of t
@@ -174,8 +181,10 @@ class CodeGenLLVM :
   llvm::Function* f_tvm_parallel_for_{nullptr};
   // The acting body
   llvm::BasicBlock* block_{nullptr};
+  /*! \brief native vector bits of current target */
+  int native_vector_bits_{0};
   /*! \brief the storage scope of allocation */
-  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
+  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
  private:
   // comparison op
......
@@ -19,7 +19,6 @@ namespace codegen {
 class VPIDeviceAPI final : public runtime::DeviceAPI {
  public:
   VPIDeviceAPI() {
-    static const size_t kAllocAlign = 32U;
     const char* s_ram_size = getenv("TVM_VPI_RAM_SIZE_MB");
     // 16 MB ram.
     int ram_size = 32;
@@ -27,7 +26,7 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
       ram_size = atoi(s_ram_size);
     }
     ram_.resize(ram_size << 17);
-    ram_head_ = kAllocAlign;
+    ram_head_ = runtime::kAllocAlignment;
     ram_max_ = ram_.size() * sizeof(int64_t);
     LOG(INFO) << "Initialize VPI simulated ram " << ram_size << "MB ...";
   }
@@ -51,10 +50,9 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
     }
   }
   void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
-    static const size_t kAllocAlign = 32U;
     // always align to 32 bytes at least.
-    CHECK_LE(alignment, kAllocAlign);
-    alignment = kAllocAlign;
+    CHECK_LE(alignment, runtime::kAllocAlignment);
+    alignment = runtime::kAllocAlignment;
     // always allocate block with aligned size.
     size += alignment - (size % alignment);
     // This is not thread safe, but fine for simulation.
@@ -67,7 +65,7 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
       b.is_free = false;
       return reinterpret_cast<void*>(head);
     } else {
-      CHECK_EQ(ram_head_ % kAllocAlign, 0U);
+      CHECK_EQ(ram_head_ % runtime::kAllocAlignment, 0U);
       Block b;
       b.size = size;
       b.is_free = false;
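Note: the simulated VPI allocator pads every request so its allocation head stays on a kAllocAlignment boundary; a Python sketch of that size rounding (helper name is illustrative only):

```python
def padded_alloc_size(size, alignment=64):
    # Mirrors `size += alignment - (size % alignment)` above; a size that is
    # already on the boundary still grows by one full alignment step.
    return size + alignment - (size % alignment)

print(padded_alloc_size(100))  # -> 128
print(padded_alloc_size(128))  # -> 192
```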
......
@@ -6,7 +6,7 @@
 #include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
 #include <tvm/buffer.h>
+#include <tvm/runtime/device_api.h>
 #include <vector>
 #include <utility>
 #include <unordered_set>
@@ -180,6 +180,10 @@ LoweredFunc MakeAPI(Stmt body,
                     v_arg->name_hint + ".data")) {
       Var vptr(buf->data);
       handle_data_type.Set(vptr, make_const(buf->dtype, 0));
+      // mark storage alignment of external buffer arguments.
+      seq_init.emplace_back(AttrStmt::make(
+          vptr, ir::attr::storage_alignment,
+          IntImm::make(Int(32), runtime::kAllocAlignment), nop));
     }
     // shape field
     Var v_shape(v_arg->name_hint + ".shape", Handle());
......
@@ -136,7 +136,7 @@ inline size_t GetDataSize(TVMArray* arr) {
 inline size_t GetDataAlignment(TVMArray* arr) {
   size_t align = (arr->dtype.bits / 8) * arr->dtype.lanes;
-  if (align < 8) return 8;
+  if (align < kAllocAlignment) return kAllocAlignment;
   return align;
 }
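Note: with kAllocAlignment = 64, array data is now aligned to at least 64 bytes instead of the previous 8; a sketch of the resulting rule (hypothetical helper, dtype given as bits and lanes):

```python
def get_data_alignment(dtype_bits, lanes, k_alloc_alignment=64):
    # Alignment is the element size in bytes times the lane count,
    # but never less than the runtime allocation alignment.
    align = (dtype_bits // 8) * lanes
    return max(align, k_alloc_alignment)

print(get_data_alignment(32, 1))   # float32    -> 64
print(get_data_alignment(32, 32))  # float32x32 -> 128
```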
......
@@ -22,7 +22,7 @@ class CUDADeviceAPI final : public DeviceAPI {
     CUDA_CALL(cudaSetDevice(ctx.device_id));
   }
   void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
-    int value;
+    int value = 0;
     switch (kind) {
       case kExist:
         value = (
......
@@ -40,7 +40,7 @@ def test_llvm_add_pipeline():
         print("Skip because llvm is not enabled..")
         return
     temp = util.tempdir()
-    target = "llvm -target=arm-none-linux-gnueabihf"
+    target = "llvm -target=armv7-none-linux-gnueabihf"
     f = tvm.build(s, [A, B, C], target)
     path = temp.relpath("myadd.o")
     f.save(path)
......
@@ -33,6 +33,16 @@ def test_tensor_slice():
     B = tvm.compute((n,), lambda i: A[0][i] + A[0][i])
+
+def test_tensor_reduce_multi_axis():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m, n), name='A')
+    k1 = tvm.reduce_axis((0, n), "k")
+    k2 = tvm.reduce_axis((0, m), "k")
+    C = tvm.compute((1,), lambda _: tvm.sum(A[k1, k2], axis=(k1, k2)))
+    C = tvm.compute((1,), lambda _: tvm.sum(A[k1, k2], axis=[k1, k2]))
+
 def test_tensor_comm_reducer():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -157,6 +167,7 @@ def test_tuple_with_different_deps():
     assert stmt.node == C.op and len(ret) == 1

 if __name__ == "__main__":
+    test_tensor_reduce_multi_axis()
     test_conv1d()
     test_tensor_slice()
     test_tensor()
......