Commit 0a19b16a by Tianqi Chen, committed by GitHub

[CODEGEN/PASS] add restricted, alignment option (#221)

* [CODEGEN/PASS] add restricted, alignment option

* fix lint

* Fix the alloca
parent 00506a62
@@ -233,6 +233,9 @@ Stmt LoopPartition(Stmt stmt);
 * \param api_args Arguments to the function, can be either Var, or Buffer
 * \param num_unpacked_args Number of arguments that
 *        are processed in plain form instead of packed form.
 * \param is_restricted Whether the caller can guarantee that the buffer arguments do not overlap.
 *        It is recommended to set this to true for optimized code if such an invariant holds.
 *
 * \return a LoweredFunc with the specified signature.
 *
 * \note
@@ -254,7 +257,8 @@ Stmt LoopPartition(Stmt stmt);
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_unpacked_args,
                    bool is_restricted);
/*!
 * \brief Find undefined vars in the statement.
......
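For context, a minimal sketch of how the updated pass is invoked from the Python side (this mirrors the test updates later in this commit; stmt and arg_list are assumed to come from the usual flatten/simplify lowering steps):

    # The new trailing argument is is_restricted: passing True asserts that
    # the buffer arguments handed to the generated function never overlap.
    fapi = tvm.ir_pass.MakeAPI(stmt, "myfunc", arg_list, 0, True)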
@@ -86,6 +86,11 @@ class LoweredFuncNode : public FunctionBaseNode {
  LoweredFuncType func_type{kMixedFunc};
  /*! \brief Whether this function is a packed function */
  bool is_packed_func{true};
  /*!
   * \brief Whether the function ensures that argument pointers do not alias.
   *  This corresponds to the restrict keyword in C.
   */
  bool is_restricted{true};
  /*! \brief The body statement of the function */
  Stmt body;
  /*! \return name of the operation */
@@ -104,6 +109,7 @@ class LoweredFuncNode : public FunctionBaseNode {
    v->Visit("handle_data_type", &handle_data_type);
    v->Visit("func_type", &func_type);
    v->Visit("is_packed_func", &is_packed_func);
    v->Visit("is_restricted", &is_restricted);
    v->Visit("body", &body);
  }
......
@@ -24,6 +24,9 @@ enum DeviceAttrKind : int {
/*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */
constexpr int kTempAllocaAlignment = 64;
/*!
 * \brief TVM Runtime Device API, abstracts the device
 *  specific interface for memory management.
......
@@ -27,7 +27,9 @@ class BuildConfig(object):
        'auto_unroll_min_depth': 1,
        'unroll_explicit': True,
        'detect_global_barrier': False,
        'offset_factor': 0,
        'data_alignment': 0,
        'restricted_func': True
    }
    def __init__(self, **kwargs):
        self._old_scope = None
@@ -77,10 +79,20 @@ def build_config(**kwargs):
    detect_global_barrier: bool, default=False
        Whether to detect global barrier.

    data_alignment: int, optional
        The alignment of data pointer in bytes.
        If 0 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, default=0
        The factor used in default buffer declaration.
        If specified as 0, the offset field is not used.

    restricted_func: bool, default=True
        Whether to build restricted functions.
        That is, each buffer argument to the function is guaranteed
        not to overlap with the others. This enables more optimization.
        Corresponds to the restrict keyword in C99.

    Returns
    -------
    config: BuildConfig
@@ -110,12 +122,15 @@ def get_binds(args, binds=None):
        The list of symbolic buffers of arguments.
    """
    binds = {} if binds is None else binds.copy()
    cfg = BuildConfig.current
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            buf = api.decl_buffer(x.shape,
                                  dtype=x.dtype,
                                  name=x.name,
                                  data_alignment=cfg.data_alignment,
                                  offset_factor=cfg.offset_factor)
            assert x not in binds
            binds[x] = buf
            arg_list.append(buf)
@@ -181,7 +196,7 @@ def lower(sch,
    stmt = ir_pass.Simplify(stmt)
    if simple_mode:
        return stmt
    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)

def build(sch,
......
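A hedged usage sketch of the two new options (the schedule s and tensors A, B, C are assumed):

    # Buffers may alias in this workload, so opt out of the restrict
    # guarantee; also request 64-byte alignment for data pointers.
    with tvm.build_config(restricted_func=False, data_alignment=64):
        f = tvm.build(s, [A, B, C], "llvm")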
@@ -68,6 +68,12 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
    *ret = PassName(args[0], args[1], args[2], args[3]); \
  }) \

#define REGISTER_PASS5(PassName) \
  TVM_REGISTER_API("ir_pass."#PassName) \
  .set_body([](TVMArgs args, TVMRetValue *ret) { \
      *ret = PassName(args[0], args[1], args[2], args[3], args[4]); \
    }) \

REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
@@ -76,7 +82,7 @@ REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS5(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(InjectVirtualThread);
......
@@ -46,6 +46,9 @@ void CodeGenC::AddFunction(LoweredFunc f) {
    if (handle_data_type_.count(v.get())) {
      PrintType(handle_data_type_.at(v.get()), stream);
      stream << "*";
      if (f->is_restricted && restrict_keyword_.length() != 0) {
        stream << ' ' << restrict_keyword_;
      }
    } else {
      PrintType(v.type(), stream);
    }
......
@@ -167,6 +167,8 @@ class CodeGenC :
  // override
  void PrintSSAAssign(
      const std::string& target, const std::string& src, Type t) final;
  /*! \brief restrict keyword */
  std::string restrict_keyword_{""};
  /*! \brief the storage scope of allocation */
  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
  /*! \brief the data type of allocated buffers */
......
@@ -14,6 +14,10 @@
namespace tvm {
namespace codegen {

CodeGenCUDA::CodeGenCUDA() {
  restrict_keyword_ = "__restrict__";
}

void CodeGenCUDA::Init(bool output_ssa) {
  CodeGenC::Init(output_ssa);
  vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state);
......
@@ -16,6 +16,7 @@ namespace codegen {
class CodeGenCUDA final : public CodeGenC {
 public:
  CodeGenCUDA();
  void Init(bool output_ssa);
  void AddFunction(LoweredFunc f);
  // override behavior
......
@@ -12,6 +12,10 @@
namespace tvm {
namespace codegen {

CodeGenOpenCL::CodeGenOpenCL() {
  restrict_keyword_ = "restrict";
}

void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
  CodeGenC::InitFuncState(f);
  for (Var arg : f->args) {
......
@@ -16,6 +16,7 @@ namespace codegen {
class CodeGenOpenCL final : public CodeGenC {
 public:
  CodeGenOpenCL();
  void AddFunction(LoweredFunc f);
  // override print thread tag.
  void InitFuncState(LoweredFunc f) final;
......
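To observe the keyword in the emitted device code, one can dump the generated source; a sketch assuming a CUDA build and the same s, A, B, C as above (CodeGenCUDA emits __restrict__, CodeGenOpenCL the plain C99 restrict):

    fadd = tvm.build(s, [A, B, C], "cuda")
    # Pointer parameters of a restricted function print as, e.g.,
    #   float* __restrict__ A
    print(fadd.imported_modules[0].get_source())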
@@ -5,6 +5,7 @@
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_llvm.h"
#include "../../pass/ir_util.h"
@@ -65,6 +66,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
    md_very_likely_branch_ =
        md_builder_->createBranchWeights(1 << 30, 0);
    md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
    md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode(
        "alias_set", md_tbaa_root_);
  }
  ctx_ = ctx;
  // initialize modules
@@ -131,10 +134,12 @@ void CodeGenLLVM::InitFuncState() {
  var_map_.clear();
  align_map_.clear();
  alloc_storage_info_.clear();
  alias_var_set_.clear();
}

void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
  this->InitFuncState();
  is_restricted_ = f->is_restricted;
  CHECK(!module_->getFunction(f->name))
      << "Function " << f->name << " already exists in module";
  std::vector<llvm::Type*> arg_type;
@@ -143,6 +148,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
    if (t.is_handle() && f->handle_data_type.count(arg)) {
      arg_type.push_back(
          LLVMType(f->handle_data_type[arg].type())->getPointerTo());
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
    } else {
      arg_type.push_back(LLVMType(t));
    }
@@ -265,6 +273,14 @@ llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
void CodeGenLLVM::AddAliasInfo(
    llvm::Instruction* inst, const Variable* buffer, Expr index, Type t) {
  if (alias_var_set_.count(buffer) != 0) {
    // Mark all possibly aliased pointers as the same type.
    llvm::MDNode* meta = md_tbaa_alias_set_;
    inst->setMetadata(
        "tbaa",
        md_builder_->createTBAAStructTagNode(meta, meta, 0));
    return;
  }
  int base = 0, width = 0;
  // create meta-data for alias analysis
  // Use a group of binary tree ranges.
@@ -1324,10 +1340,10 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
        LLVMType(op->type), ConstInt32(constant_size));
    buf = alloca;
    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
    // Align stack to be TempAllocaAlignment.
    // TODO(tqchen) have pass to detect vector access and pre-set alignment
    if (constant_size % 4 == 0 && info.alignment == 0) {
      info.alignment = GetTempAllocaAlignment(op->type, constant_size);
    }
    if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
      alloca->setAlignment(info.alignment);
@@ -1408,6 +1424,11 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
  llvm::Value* v = MakeValue(op->value);
  CHECK(!var_map_.count(op->var.get()));
  CHECK(!align_map_.count(op->var.get()));
  if (op->var.type().is_handle()) {
    if (!is_restricted_) {
      alias_var_set_.insert(op->var.get());
    }
  }
  var_map_[op->var.get()] = v;
  align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
  this->VisitStmt(op->body);
......
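The effect of the alias-set tag, in user terms: when a function is built with restricted_func=False, every load and store through a possibly-aliased handle carries the same TBAA tag, so LLVM preserves their ordering instead of optimizing as if they were independent. A sketch of the kind of aliasing this protects, with hypothetical shapes (the madd test later in this commit exercises the same overlapping-view case):

    import numpy as np
    buf = np.zeros(8, dtype="float32")
    a, c = buf[0:6], buf[2:8]   # two overlapping views of one allocation
    # Feeding a and c to a function built with restricted_func=False is
    # safe; under the default restricted_func=True it would be undefined.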
@@ -174,6 +174,7 @@ class CodeGenLLVM :
  // branch
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
  // TVM related data types
  llvm::Type* t_tvm_shape_index_{nullptr};
  llvm::Type* t_tvm_func_handle_{nullptr};
@@ -234,6 +235,10 @@ class CodeGenLLVM :
  std::unordered_map<std::string, llvm::Constant*> str_map_;
  // The alignment information
  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
  // Whether the current function is restricted
  bool is_restricted_{true};
  // Set of vars that are not restricted (can alias)
  std::unordered_set<const Variable*> alias_var_set_;
  // The local module_context
  llvm::GlobalVariable* gv_mod_ctx_{nullptr};
  // global to packed function handle
......
@@ -100,7 +100,7 @@ Buffer Buffer::MakeStrideView() const {
Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
  const BufferNode* n = operator->();
  Expr elem_offset = ir::Simplify(ElemOffset(n, begins));
  Array<Expr> strides = n->strides;
  if (strides.size() == 0) {
    bool can_relax = true;
@@ -146,6 +146,9 @@ Buffer BufferNode::make(Var data,
  n->shape = std::move(shape);
  n->strides = std::move(strides);
  n->name = std::move(name);
  if (scope.length() == 0) {
    scope = "global";
  }
  n->scope = std::move(scope);
  if (!elem_offset.defined()) {
    elem_offset = make_const(n->shape[0].type(), 0);
@@ -156,7 +159,7 @@ Buffer BufferNode::make(Var data,
  if (offset_factor == 0) {
    offset_factor = 1;
  }
  n->elem_offset = std::move(elem_offset);
  n->data_alignment = data_alignment;
  n->offset_factor = offset_factor;
  return Buffer(n);
......
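The new buffer fields can also be set explicitly when declaring a buffer, as get_binds above now does; a sketch with hypothetical values:

    Ab = tvm.decl_buffer((n,), "float32", name="A",
                         data_alignment=64, offset_factor=16)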
@@ -80,41 +80,64 @@ void ArgBinder::BindBuffer(const Buffer& arg,
  CHECK_EQ(arg->scope, value->scope)
      << "Argument " << arg_name
      << " Buffer bind scope mismatch";
  CHECK_EQ(arg->dtype, value->dtype)
      << "Argument " << arg_name
      << " Buffer bind data type mismatch";
  if (value->data_alignment % arg->data_alignment != 0) {
    LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
                 << " required_alignment=" << arg->data_alignment
                 << ", provided_alignment=" << value->data_alignment;
  }
  // bind pointer and offset.
  if (is_zero(arg->elem_offset) && !is_zero(value->elem_offset)) {
    // If target buffer only accepts pointer, try to fold offset into pointer.
    Expr addr = AddressOffset(value->data, value->dtype, value->elem_offset);
    if (Bind_(arg->data, addr, arg_name + ".data", true)) {
      int offset_factor = arg->data_alignment * 8 / (arg->dtype.bits() * arg->dtype.lanes());
      if (offset_factor > 1) {
        Expr offset = value->elem_offset;
        Expr factor = make_const(offset.type(), offset_factor);
        Expr zero = make_zero(offset.type());
        BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
      }
    }
  } else {
    this->Bind(arg->data, value->data, arg_name + ".data");
    if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
      if (arg->offset_factor > 1) {
        Expr offset = value->elem_offset;
        Expr factor = make_const(offset.type(), arg->offset_factor);
        Expr zero = make_zero(offset.type());
        BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
      }
    }
  }

  if (arg->shape.size() < value->shape.size()) {
    CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
    size_t diff = value->shape.size() - arg->shape.size();
    for (size_t i = 0; i < diff; ++i) {
      CHECK(is_one(value->shape[i]))
          << "Argument " << arg_name << " shape mismatch"
          << arg->shape << " vs " << value->shape;
    }
    for (size_t i = 0; i < arg->shape.size(); ++i) {
      std::ostringstream os;
      os << arg_name << ".shape[" << i << "]";
      this->Bind(arg->shape[i], value->shape[i + diff], os.str());
    }
    if (value->strides.size() != 0) {
      CHECK_EQ(arg->strides.size(), arg->shape.size());
      CHECK_EQ(value->strides.size(), value->shape.size());
      for (size_t i = 0; i < arg->strides.size(); ++i) {
        std::ostringstream os;
        os << arg_name << ".strides[" << i << "]";
        this->Bind(arg->strides[i], value->strides[i + diff], os.str());
      }
    }
  } else {
    this->BindArray(arg->shape, value->shape, arg_name + ".shape");
    this->BindArray(arg->strides, value->strides, arg_name + ".strides");
  }
}

inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
......
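A worked example of the offset-folding rule above, with hypothetical numbers:

    # A float32 target buffer (32 bits, 1 lane) declared with
    # data_alignment = 64 bytes yields
    #     offset_factor = 64 * 8 // (32 * 1) == 16
    # so before folding elem_offset into the data pointer, the binder
    # asserts value.elem_offset % 16 == 0.
    data_alignment, bits, lanes = 64, 32, 1
    assert data_alignment * 8 // (bits * lanes) == 16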
@@ -7,6 +7,7 @@
#define TVM_PASS_IR_UTIL_H_

#include <tvm/ir.h>
#include <tvm/runtime/device_api.h>
#include <vector>

namespace tvm {
@@ -93,6 +94,20 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) {
}

/*!
 * \brief Address of handle + offset
 * \param handle the array handle.
 * \param dtype The data type.
 * \param offset the offset index.
 */
inline Expr AddressOffset(Var handle, Type dtype, Expr offset) {
  return Call::make(
      Handle(), intrinsic::tvm_address_of,
      {Load::make(dtype, handle, offset * make_const(offset.type(), dtype.lanes()),
                  const_true(dtype.lanes()))},
      Call::PureIntrinsic);
}

/*!
 * \brief Set value into struct.
 * \param handle the struct handle.
 * \param index the offset index.
@@ -125,6 +140,23 @@ inline Type APIType(Type t) {
  CHECK(t.is_float());
  return Float(64);
}

/*!
 * \brief Rule to get allocation alignment requirement for a given const array.
 * \param type The type of allocation.
 * \param const_size The constant size of the array.
 * \return the alignment
 */
inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
  int align = runtime::kTempAllocaAlignment;
  if (const_size > 0) {
    const_size = const_size * type.bits() * type.lanes() / 8;
    while (align > const_size) {
      align = align / 2;
    }
  }
  return align;
}

}  // namespace ir
}  // namespace tvm
#endif  // TVM_PASS_IR_UTIL_H_
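A small Python mirror of GetTempAllocaAlignment (an illustration only; the C++ above is authoritative): the default alignment is halved until it no longer exceeds the array's size in bytes.

    def temp_alloca_alignment(type_bits, lanes, const_size, default=64):
        align = default  # runtime::kTempAllocaAlignment
        if const_size > 0:
            nbytes = const_size * type_bits * lanes // 8
            while align > nbytes:
                align //= 2
        return align

    assert temp_alloca_alignment(32, 1, 3) == 8     # 12-byte array -> 8
    assert temp_alloca_alignment(32, 1, 100) == 64  # large array keeps 64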
@@ -25,7 +25,8 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_unpacked_args,
                    bool is_restricted) {
  const Stmt nop = Evaluate::make(0);
  int num_args = static_cast<int>(api_args.size());
  CHECK_LE(num_unpacked_args, num_args);
@@ -132,6 +133,7 @@ LoweredFunc MakeAPI(Stmt body,
  n->args = args;
  n->handle_data_type = binder.def_handle_dtype();
  n->is_packed_func = num_unpacked_args == 0;
  n->is_restricted = is_restricted;
  // Set device context
  if (vmap.count(device_id.get())) {
......
@@ -7,6 +7,7 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "./ir_util.h"
#include "./arg_binder.h"
@@ -88,11 +89,6 @@ class StorageFlattener : public IRMutator {
      for (auto r : e.bounds) {
        shape.push_back(r->extent);
      }
      // deduce current storage scope.
      auto it = storage_scope_.find(op->func.get());
      CHECK(it != storage_scope_.end())
@@ -107,12 +103,26 @@ class StorageFlattener : public IRMutator {
      } else {
        skey = StorageScope::make(strkey);
      }
      // use small alignment for small arrays
      int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
      int align = GetTempAllocaAlignment(op->type, const_size);
      e.buffer = BufferNode::make(
          Var(key.GetName(), Handle()),
          op->type, shape,
          Array<Expr>(), Expr(),
          key.GetName(), skey.to_string(),
          align, 0);
      buf_map_[key] = e;
      Stmt body = this->Mutate(op->body);
      buf_map_[key].released = true;
      Stmt ret = Allocate::make(
          e.buffer->data, e.buffer->dtype, e.buffer->shape,
          make_const(Bool(e.buffer->dtype.lanes()), true), body);
      ret = AttrStmt::make(
          e.buffer->data, attr::storage_scope,
          StringImm::make(e.buffer->scope), ret);
      return ret;
    }
  }
......
@@ -93,6 +93,14 @@ class StorageAccessPatternFinder final : public IRVisitor {
          AccessEntry(buf, op->index, kRead, GetScope(buf)));
    }
  }
  void Visit_(const Call* op) final {
    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
      const Load* l = op->args[0].as<Load>();
      this->Visit(l->index);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const Variable* buf) final {
    // Directly referencing the variable counts as a read.
    auto it = alloc_scope_level_.find(buf);
......
...@@ -16,7 +16,7 @@ def lower(s, args, name="mydot"): ...@@ -16,7 +16,7 @@ def lower(s, args, name="mydot"):
stmt = tvm.ir_pass.StorageFlatten(stmt, binds) stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0) fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi) fapi = tvm.ir_pass.LowerPackedCall(fapi)
return fapi return fapi
......
@@ -25,7 +25,7 @@ def test_add_pipeline():
    stmt = tvm.ir_pass.LoopPartition(stmt)
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb})
    stmt = tvm.ir_pass.Simplify(stmt)
    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
    fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
    fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
......
@@ -6,7 +6,8 @@ def test_llvm_add_pipeline():
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    T = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='T')
    C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    s[C].parallel(xo)
@@ -87,6 +88,7 @@ def test_llvm_madd_pipeline():
        c.asnumpy(), a.asnumpy()[base:] + 1)
    check_llvm(64, 0, 2)
    check_llvm(4, 0, 1)
    with tvm.build_config(restricted_func=False):
        check_llvm(4, 0, 3)

def test_llvm_temp_space():
......
@@ -19,7 +19,7 @@ def test_stack_vm_basic():
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n, ), tvm.float32)
    stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
    fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerPackedCall(fapi)
    run_jit(fapi, lambda f: f(a))
@@ -41,7 +41,7 @@ def test_stack_vm_loop():
        ib.emit(tvm.call_packed("tvm_stack_vm_print", i))
    stmt = ib.get()
    fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerPackedCall(fapi)
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    def check(f):
@@ -64,7 +64,7 @@ def test_stack_vm_cond():
            A[i + 1] = A[i] + 2
    stmt = ib.get()
    fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerPackedCall(fapi)
    def check(f):
        a = tvm.nd.array(np.zeros(10, dtype=dtype))
......
@@ -37,7 +37,7 @@ def test_dso_module_load():
        tvm.make.Store(Ab.data,
                       tvm.make.Load(dtype, Ab.data, i) + 1,
                       i + 1))
    fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerPackedCall(fapi)
    m = tvm.codegen.build_module(fapi, "llvm")
    for name in names:
......
@@ -19,7 +19,7 @@ def test_makeapi():
    num_unpacked_args = 2
    f = tvm.ir_pass.MakeAPI(
        stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True)
    assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
    assert(len(f.args) == 5)
    output_ssa = False
......
@@ -20,7 +20,7 @@ def test_storage_sync():
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
    f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
    flist = tvm.ir_pass.SplitHostDevice(f)
    f = flist[1]
    f = tvm.ir_pass.StorageSync(f, "shared")
......
@@ -24,7 +24,7 @@ def test_dltensor_compatible():
    with ib.for_range(0, n - 1, "i") as i:
        A[i + 1] = A[i] + 1
    stmt = ib.get()
    fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
    fapi = tvm.ir_pass.LowerPackedCall(fapi)
    f = tvm.codegen.build_module(fapi, "stackvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
......