Commit 8ef26606 by Tianqi Chen, committed by GitHub

[SCHEDULE][PASS] support storage_align of certain axis (#400)

* [SCHEDULE][PASS] support storage_align of certain axis

* fix lint
parent b03c3243
......@@ -183,6 +183,13 @@ constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
/*!
* \brief Mark the alignment requirement of a buffer dimension.
* stmt.node is the Tensor
* stmt.value is tvm_tuple(dim, align, offset)
* This hints that the stride of dim should be k * align + offset for some k.
*/
constexpr const char* buffer_dim_align = "buffer_dim_align";
/*!
* \brief Bind the buffer specification to the region of the op
* When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
* stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
......
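To make the hint concrete, here is a minimal sketch (plain Python, independent of TVM's own code) of the predicate the attribute encodes: a stride satisfies the requirement exactly when it is congruent to `offset` modulo `align`.

```python
def satisfies_dim_align(stride, align, offset):
    # buffer_dim_align hints that stride == k * align + offset for some integer k
    return stride % align == offset % align

# align=2, offset=1 admits odd strides such as 17, and rejects even ones
assert satisfies_dim_align(17, align=2, offset=1)
assert not satisfies_dim_align(16, align=2, offset=1)
```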
......@@ -104,13 +104,13 @@ class OperationNode : public FunctionBaseNode {
/*!
* \brief Build the Realize statement that realizes
* the op's output tensors.
* \param self The reference to self.
* \param stage The op's stage.
* \param realize_map The realization domain map of the operators.
* \param body The body that is going to get wrapped.
* \return A realization statement that wraps body.
*/
virtual Stmt BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const = 0;
/*!
......@@ -155,7 +155,7 @@ class PlaceholderOpNode : public OperationNode {
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
......@@ -206,7 +206,7 @@ class ComputeOpNode : public OperationNode {
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
......@@ -277,7 +277,7 @@ class ScanOpNode : public OperationNode {
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
......@@ -340,7 +340,7 @@ class ExternOpNode : public OperationNode {
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
......
......@@ -198,6 +198,17 @@ class Stage : public NodeRef {
*/
Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
/*!
* \brief Set alignment requirement for a specific dimension.
*
* Such that stride[axis] == k * factor + offset for some k.
*
* \param axis The dimension to be aligned.
* \param factor The alignment factor of the stride.
* \param offset The required offset of the stride modulo factor.
* \return reference to self
*/
Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
......@@ -496,6 +507,10 @@ class IterVarAttrNode : public Node {
* when the axis is marked as Tensorized
*/
TensorIntrin tensor_intrin;
/*! \brief Alignment factor of buffer dimension */
int dim_align_factor{0};
/*! \brief Alignment offset of buffer dimension */
int dim_align_offset{0};
/*!
* \brief Additional pragmas, array of StringImm
*/
......@@ -507,6 +522,8 @@ class IterVarAttrNode : public Node {
v->Visit("prefetch_data", &prefetch_data);
v->Visit("prefetch_offset", &prefetch_offset);
v->Visit("tensor_intrin", &tensor_intrin);
v->Visit("dim_align_factor", &dim_align_factor);
v->Visit("dim_align_offset", &dim_align_offset);
v->Visit("pragmas", &pragmas);
}
......
......@@ -569,4 +569,24 @@ class Stage(NodeBase):
"""
_api_internal._StagePrefetch(self, tensor, var, offset)
def storage_align(self, axis, factor, offset):
"""Set alignment requirement for specific axis
This ensures that stride[axis] == k * factor + offset for some k.
This is useful to set memory layout to for more friendly memory
access pattern. For example, we can set alignment to be
factor=2, offset=1 to avoid bank conflict for thread access on
higher dimension in GPU shared memory.
Parameters
----------
axis : IterVar
The axis dimension to be aligned.
factor : int
The factor in alignment specification.
offset : int
The offset in the alignment specification.
"""
_api_internal._StageStorageAlign(self, axis, factor, offset)
_init_api("tvm.schedule")
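A short usage sketch of the new Python API (mirroring the unit test added below; the tensor names are just for illustration):

```python
import tvm

m, l = 8, 16
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
# require stride[axis 0] of A1's buffer to be k * 2 + 1, i.e. an odd row stride
s[A1].storage_align(A1.op.axis[0], 2, 1)
```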
......@@ -388,6 +388,12 @@ TVM_REGISTER_API("_StagePrefetch")
.prefetch(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StageStorageAlign")
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Stage()
.storage_align(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_ScheduleNormalize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Schedule()
......
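Why an odd stride helps on GPUs, per the bank-conflict remark in the storage_align docstring above (a worked example; the bank count of 32 is a typical value, assumed here rather than taken from this patch): when the row stride is a multiple of the bank count, every row of a column maps to the same shared-memory bank, while an odd stride spreads the rows across banks.

```python
BANKS = 32  # typical number of 4-byte shared-memory banks; an assumption, not from the patch

def bank(flat_index):
    # bank that a 4-byte element at this flat index maps to
    return flat_index % BANKS

# threads reading column 0 across 8 rows: element (i, 0) sits at i * stride
same_bank = {bank(i * 32) for i in range(8)}  # stride 32: every row hits bank 0
spread = {bank(i * 33) for i in range(8)}     # stride 33 (odd): 8 distinct banks
assert len(same_bank) == 1 and len(spread) == 8
```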
......@@ -198,19 +198,35 @@ void ComputeOpNode::GatherBound(
}
Stmt ComputeOpNode::BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& realize_body) const {
CHECK_EQ(self.operator->(), this);
CHECK_EQ(stage->op.get(), this);
Halide::Internal::Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
Stmt realize = realize_body;
for (int i = self->num_outputs(); i > 0; --i) {
Tensor t = self.output(i-1);
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i-1);
realize = ir::Realize::make(t->op, t->value_index,
t->dtype, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < this->axis.size(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
if (it != stage->iter_var_attrs.end()) {
IterVarAttr attr = (*it).second;
if (attr->dim_align_factor != 0) {
Array<Expr> tuple = {static_cast<int>(i),
attr->dim_align_factor,
attr->dim_align_offset};
realize = ir::AttrStmt::make(
t, ir::attr::buffer_dim_align,
Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
realize);
}
}
}
}
return realize;
}
......@@ -304,7 +320,7 @@ enum class ComputeType {
};
ComputeType DetectComputeType(const ComputeOpNode* self,
const Stage& stage) {
const Stage& stage) {
// Verify correctness of leaf nest.
int normal_red = 0, thread_red = 0, tensorize = 0;
......
......@@ -106,13 +106,13 @@ void ExternOpNode::GatherBound(
}
Stmt ExternOpNode::BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(self.operator->(), this);
CHECK_EQ(stage->op.get(), this);
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = self.output(k);
Tensor t = stage->op.output(k);
Halide::Internal::Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
......
/*!
* Copyright (c) 2017 by Contributors
* \brief Utility to make loop nest.
* \file op_util.cc
*/
......
......@@ -52,6 +52,7 @@ MakeBoundCheck(const Stage& stage,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter,
const std::unordered_map<IterVar, Expr>& value_map);
/*!
* \brief Create a nest of if checking the predicates.
*
......
......@@ -70,7 +70,7 @@ void PlaceholderOpNode::GatherBound(
}
Stmt PlaceholderOpNode::BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
return body;
......
......@@ -226,17 +226,17 @@ void ScanOpNode::GatherBound(
}
Stmt ScanOpNode::BuildRealize(
const Operation& self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
const Stmt& body) const {
CHECK_EQ(self.operator->(), this);
CHECK_EQ(stage->op.get(), this);
Range sdom = dom_map.at(this->scan_axis);
Range tdom = Range::make_by_min_extent(
0, ir::Simplify(sdom->extent + sdom->min));
Stmt ret = body;
size_t sp_idx = 0;
for (size_t i = 0; i < update.size(); ++i) {
Tensor t = self.output(i);
Tensor t = stage->op.output(i);
CHECK_EQ(static_cast<size_t>(t->value_index), i);
Halide::Internal::Region bounds;
bounds.push_back(tdom);
......
......@@ -62,6 +62,19 @@ class StorageFlattener : public IRMutator {
return stmt;
} else if (op->attr_key == attr::buffer_bind_scope) {
return HandleBufferBindScope(op);
} else if (op->attr_key == attr::buffer_dim_align) {
Tensor tensor(op->node.node_);
const Call* tuple = op->value.as<Call>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index};
auto& vinfo = dim_align_[key];
int dim = tuple->args[0].as<IntImm>()->value;
if (static_cast<size_t>(dim) >= vinfo.size()) {
vinfo.resize(dim + 1);
}
vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
return this->Mutate(op->body);
}
return IRMutator::Mutate_(op, s);
}
......@@ -116,20 +129,45 @@ class StorageFlattener : public IRMutator {
align = (info->max_simd_bits + op->type.bits() - 1) / op->type.bits();
}
}
Array<Expr> strides;
if (dim_align_.count(key) != 0) {
std::vector<Expr> rstrides;
const std::vector<DimAlignInfo>& avec = dim_align_[key];
Expr stride = make_const(shape[0].type(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
if (dim < avec.size() && avec[dim].align_factor != 0) {
Expr factor = make_const(stride.type(), avec[dim].align_factor);
Expr offset = make_const(stride.type(), avec[dim].align_offset);
stride = stride + (factor + offset - stride % factor) % factor;
stride = ir::Simplify(stride);
}
rstrides.push_back(stride);
stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
}
strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
}
e.buffer = BufferNode::make(
Var(key.GetName(), Handle()),
op->type, shape,
Array<Expr>(), Expr(),
op->type, shape, strides, Expr(),
key.GetName(), skey.to_string(),
align, 0);
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
buf_map_[key].released = true;
Stmt ret;
Stmt ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
if (strides.size() != 0) {
ret = Allocate::make(
e.buffer->data, e.buffer->dtype,
{arith::ComputeExpr<Mul>(e.buffer->strides[0], e.buffer->shape[0])},
make_const(Bool(e.buffer->dtype.lanes()), true), body);
} else {
ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
}
ret = AttrStmt::make(
e.buffer->data, attr::storage_scope,
StringImm::make(e.buffer->scope), ret);
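The stride construction above walks the dimensions from innermost to outermost, padding each aligned stride up to the next value congruent to offset modulo factor, and, when strides are present, allocates a single flat extent of strides[0] * shape[0]. A small Python model of that loop (a sketch of the same arithmetic, not the pass itself):

```python
def compute_strides(shape, align):
    # align maps dim -> (factor, offset), mirroring dim_align_ in the pass
    stride, rstrides = 1, []
    for dim in reversed(range(len(shape))):
        if dim in align:
            factor, offset = align[dim]
            # pad stride up to the next value with stride % factor == offset
            stride += (factor + offset - stride % factor) % factor
        rstrides.append(stride)
        stride *= shape[dim]
    return list(reversed(rstrides))

strides = compute_strides([8, 16], {0: (2, 1)})
assert strides == [17, 1]        # inner stride 16 is padded to 17
assert strides[0] * 8 == 17 * 8  # the flat Allocate extent checked by the test below
```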
......@@ -283,7 +321,11 @@ class StorageFlattener : public IRMutator {
}
return body;
}
// The buffer entry in the flatten map
struct DimAlignInfo {
int align_factor{0};
int align_offset{0};
};
// The buffer entry in the flatten map
struct BufferEntry {
// the buffer of storage
......@@ -294,7 +336,6 @@ class StorageFlattener : public IRMutator {
bool external{false};
// Whether we are out of allocation bounds and buffer get released.
bool released{false};
// TODO(tqchen) allow permutation and inference of index dimension.
// relative index
inline Array<Expr> RelIndex(Array<Expr> args) const {
if (bounds.size() != 0) {
......@@ -314,6 +355,9 @@ class StorageFlattener : public IRMutator {
std::unordered_map<const Variable*, Expr> var_remap_;
// Buffer map
std::unordered_map<TensorKey, BufferEntry> buf_map_;
// Dimension alignment
std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
// Storage scope
std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
......
......@@ -296,10 +296,15 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
}
template<typename FUpdate>
inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
inline void UpdateIterVarAttr(StageNode* self,
IterVar var,
FUpdate fupdate,
bool need_leaf = true) {
if (need_leaf) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
}
auto it = self->iter_var_attrs.find(var);
std::shared_ptr<IterVarAttrNode> n;
if (it != self->iter_var_attrs.end()) {
......@@ -371,6 +376,15 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
return *this;
}
Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
StageNode *self = operator->();
UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
n->dim_align_factor = factor;
n->dim_align_offset = offset;
}, false);
return *this;
}
Stage CopyStage(const Stage& s) {
std::shared_ptr<StageNode> n =
std::make_shared<StageNode>(*s.operator->());
......
......@@ -33,7 +33,7 @@ Stmt MakePipeline(const Stage& s,
consumer = ProducerConsumer::make(s->op, false, consumer);
pipeline = Block::make(producer, consumer);
}
pipeline = s->op->BuildRealize(s->op, dom_map, pipeline);
pipeline = s->op->BuildRealize(s, dom_map, pipeline);
// use attribute to mark scope of the operation.
pipeline = AttrStmt::make(
s->op, ir::attr::realize_scope,
......@@ -194,6 +194,18 @@ class SchedulePostProc : public IRMutator {
return this->Mutate(op->body);
}
}
} else if (op->attr_key == ir::attr::buffer_dim_align) {
Tensor tensor(op->node.node_);
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
return AttrStmt::make(
it->second.output(tensor->value_index),
op->attr_key, op->value, Mutate(op->body));
} else {
return this->Mutate(op->body);
}
}
}
return IRMutator::Mutate_(op, s);
}
......
......@@ -32,6 +32,27 @@ def test_flatten_prefetch():
assert isinstance(stmt.body, tvm.stmt.For)
assert stmt.body.extent.value == 2
def test_flatten_storage_align():
m = 8
l = 16
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
s[A1].storage_align(A1.op.axis[0], 2, 1)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
stmt = tvm.ir_pass.Simplify(stmt)
# the outer stride is padded from 16 to 17, so the flat allocation is 17 * 8
assert stmt.body.extents[0].value == 17 * 8
if __name__ == "__main__":
test_flatten_storage_align()
test_flatten2()
test_flatten_prefetch()
......@@ -28,14 +28,13 @@ from tvm.contrib import rpc, util
# local machine, we need build runtime on remote device.
#
# To get started, clone tvm repo from github. It is important to clone
# the submodules along, with --recursive option (Assuming you are in
# your home directory):
#
# .. code-block:: bash
#
#
# git clone --recursive https://github.com/dmlc/tvm
#
######################################################################
#
# .. note::
#
# Usually device has limited resources and we only need to build
......@@ -51,14 +50,13 @@ from tvm.contrib import rpc, util
#
# Also make sure that you have set :code:`USE_RPC=1` in your
# :code:`config.mk`. We don't need LLVM when building runtime, so
# :code:`LLVM_CONFIG = llvm-config` in :code:`config.mk`is commented
# :code:`LLVM_CONFIG = llvm-config` in :code:`config.mk` is commented
# out by default. After that, build runtime!
#
# .. code-block:: bash
#
# make runtime
#
######################################################################
# After successfully building the runtime, we need to set environment
# variables in the :code:`~/.bashrc` file of your own account or in
# :code:`/etc/profile` for system-wide settings. Assuming your TVM
# directory is in
......@@ -95,7 +93,7 @@ from tvm.contrib import rpc, util
# successful to start RPC server on your device.
#
# .. code-block:: bash
#
#
# Loading runtime library /home/YOURNAME/code/tvm/lib/libtvm_runtime.so... exec only
# INFO:root:RPCServer: bind to 0.0.0.0:9090
#
......