Commit 8ef26606 by Tianqi Chen, committed by GitHub

[SCHEDULE][PASS] support storage_align of certain axis (#400)

* [SCHEDULE][PASS] support storage_align of certain axis

* fix lint

parent b03c3243
@@ -183,6 +183,13 @@ constexpr const char* scan_update_scope = "scan_update_scope";
 /*! \brief Mark of scan init scope */
 constexpr const char* scan_init_scope = "scan_init_scope";
+/*!
+ * \brief Mark alignment of buffer dimension
+ *  stmt.node is Tensor
+ *  stmt.value is tvm_tuple(dim, align, offset)
+ *  This gives hint to require stride of dim to be k * align + offset.
+ */
+constexpr const char* buffer_dim_align = "buffer_dim_align";
 /*!
  * \brief Bind the buffer specification to the region of the op
  *  When this scope occurs, the stmt.node is a Array<NodeRef> = [buffer, tensor]
  *  stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
......
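In plain terms, the hint asks that the chosen dimension's stride be congruent to `offset` modulo `align`. A minimal sketch of the predicate (a hypothetical helper, not part of TVM):

```python
def satisfies_dim_align(stride, align, offset):
    """True iff stride == k * align + offset for some integer k >= 0."""
    return align > 0 and (stride - offset) % align == 0

# tvm_tuple(dim=0, align=2, offset=1) accepts odd strides only:
assert satisfies_dim_align(17, 2, 1) and not satisfies_dim_align(16, 2, 1)
```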
@@ -104,13 +104,13 @@ class OperationNode : public FunctionBaseNode {
   /*!
    * \brief Build the Realize statement that realizes
    *  the op's output tensors.
-   * \param self The reference to self.
+   * \param stage the op's stage.
    * \param realize_map The realization domain map of the operators.
    * \param body The body that is going to get
    * \return A realization statement that wraps body.
    */
   virtual Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const = 0;
   /*!
@@ -155,7 +155,7 @@ class PlaceholderOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
@@ -206,7 +206,7 @@ class ComputeOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
@@ -277,7 +277,7 @@ class ScanOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
@@ -340,7 +340,7 @@ class ExternOpNode : public OperationNode {
       const std::unordered_map<Tensor, TensorDom>& tensor_dom,
       std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(
-      const Operation& self,
+      const Stage& stage,
       const std::unordered_map<IterVar, Range>& realize_map,
       const Stmt& body) const final;
   Stmt BuildProvide(
......
@@ -198,6 +198,17 @@ class Stage : public NodeRef {
    */
  Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
  /*!
+  * \brief Set alignment requirement for a specific dimension,
+  *
+  *  such that stride[axis] == k * factor + offset for some k.
+  *
+  * \param axis The dimension to be aligned.
+  * \param factor The alignment factor.
+  * \param offset The required offset.
+  * \return reference to self.
+  */
+ Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
+ /*!
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
@@ -496,6 +507,10 @@ class IterVarAttrNode : public Node {
    *  when the axis is marked as Tensorized
    */
   TensorIntrin tensor_intrin;
+  /*! \brief Alignment factor of buffer dimension */
+  int dim_align_factor{0};
+  /*! \brief Alignment offset of buffer dimension */
+  int dim_align_offset{0};
   /*!
    * \brief Additional pragmas, array of StringImm
    */
@@ -507,6 +522,8 @@ class IterVarAttrNode : public Node {
     v->Visit("prefetch_data", &prefetch_data);
     v->Visit("prefetch_offset", &prefetch_offset);
     v->Visit("tensor_intrin", &tensor_intrin);
+    v->Visit("dim_align_factor", &dim_align_factor);
+    v->Visit("dim_align_offset", &dim_align_offset);
     v->Visit("pragmas", &pragmas);
   }
......
@@ -569,4 +569,24 @@ class Stage(NodeBase):
         """
         _api_internal._StagePrefetch(self, tensor, var, offset)
 
+    def storage_align(self, axis, factor, offset):
+        """Set alignment requirement for a specific axis.
+
+        This ensures that stride[axis] == k * factor + offset for some k.
+        This is useful to set the memory layout for a more friendly memory
+        access pattern. For example, we can set alignment to factor=2,
+        offset=1 to avoid bank conflicts for thread accesses along the
+        higher dimension in GPU shared memory.
+
+        Parameters
+        ----------
+        axis : IterVar
+            The axis dimension to be aligned.
+        factor : int
+            The factor in the alignment specification.
+        offset : int
+            The offset in the alignment specification.
+        """
+        _api_internal._StageStorageAlign(self, axis, factor, offset)
+
 _init_api("tvm.schedule")
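A minimal usage sketch of the new Python API (variable names mirror the unit test added later in this commit):

```python
import tvm

m, l = 8, 16
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

s = tvm.create_schedule(A2.op)
# Require stride[0] of A1's buffer to be k * 2 + 1, i.e. odd:
s[A1].storage_align(A1.op.axis[0], 2, 1)
```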
@@ -388,6 +388,12 @@ TVM_REGISTER_API("_StagePrefetch")
         .prefetch(args[1], args[2], args[3]);
   });
 
+TVM_REGISTER_API("_StageStorageAlign")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    args[0].operator Stage()
+        .storage_align(args[1], args[2], args[3]);
+  });
+
 TVM_REGISTER_API("_ScheduleNormalize")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Schedule()
......
@@ -198,19 +198,35 @@ void ComputeOpNode::GatherBound(
 }
 
 Stmt ComputeOpNode::BuildRealize(
-    const Operation& self,
+    const Stage& stage,
     const std::unordered_map<IterVar, Range>& realize_map,
     const Stmt& realize_body) const {
-  CHECK_EQ(self.operator->(), this);
+  CHECK_EQ(stage->op.get(), this);
   Halide::Internal::Region bounds;
   for (IterVar iv : this->axis) {
     bounds.push_back(realize_map.at(iv));
   }
   Stmt realize = realize_body;
-  for (int i = self->num_outputs(); i > 0; --i) {
-    Tensor t = self.output(i-1);
+  for (int i = this->num_outputs(); i > 0; --i) {
+    Tensor t = stage->op.output(i-1);
     realize = ir::Realize::make(t->op, t->value_index,
         t->dtype, bounds, const_true(), realize);
+    // alignment requirement, only useful for compute
+    for (size_t i = 0; i < this->axis.size(); ++i) {
+      auto it = stage->iter_var_attrs.find(this->axis[i]);
+      if (it != stage->iter_var_attrs.end()) {
+        IterVarAttr attr = (*it).second;
+        if (attr->dim_align_factor != 0) {
+          Array<Expr> tuple = {static_cast<int>(i),
+                               attr->dim_align_factor,
+                               attr->dim_align_offset};
+          realize = ir::AttrStmt::make(
+              t, ir::attr::buffer_dim_align,
+              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+              realize);
+        }
+      }
+    }
   }
   return realize;
 }
@@ -304,7 +320,7 @@ enum class ComputeType {
 };
 
 ComputeType DetectComputeType(const ComputeOpNode* self,
                               const Stage& stage) {
   // Verify correctness of leaf nest.
   int normal_red = 0, thread_red = 0, tensorize = 0;
......
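To see the `AttrStmt` that `BuildRealize` now emits, one can print the statement produced by `ScheduleOps` before flattening. A sketch (the exact printed form of the IR may differ by version):

```python
import tvm

A = tvm.placeholder((8, 16), name='A')
A1 = tvm.compute((8, 16), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((8, 16), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
s[A1].storage_align(A1.op.axis[0], 2, 1)

bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
# Expect the realize of A1 to be wrapped in something like:
#   // attr [A1] buffer_dim_align = tvm_tuple(0, 2, 1)
print(stmt)
```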
@@ -106,13 +106,13 @@ void ExternOpNode::GatherBound(
 }
 
 Stmt ExternOpNode::BuildRealize(
-    const Operation& self,
+    const Stage& stage,
     const std::unordered_map<IterVar, Range>& realize_map,
     const Stmt& body) const {
-  CHECK_EQ(self.operator->(), this);
+  CHECK_EQ(stage->op.get(), this);
   Stmt realize_body = body;
   for (int k = 0; k < num_outputs(); ++k) {
-    Tensor t = self.output(k);
+    Tensor t = stage->op.output(k);
     Halide::Internal::Region bounds;
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(
......
/*!
 * Copyright (c) 2017 by Contributors
 * \brief Utility to make loop nest.
 * \file op_util.cc
 */
......
@@ -52,6 +52,7 @@ MakeBoundCheck(const Stage& stage,
     bool skip_ivar_domain,
     const std::unordered_set<IterVar>& skip_iter,
     const std::unordered_map<IterVar, Expr>& value_map);
+
 /*!
  * \brief Create a nest of if checking the predicates.
  *
......
@@ -70,7 +70,7 @@ void PlaceholderOpNode::GatherBound(
 }
 
 Stmt PlaceholderOpNode::BuildRealize(
-    const Operation& self,
+    const Stage& stage,
     const std::unordered_map<IterVar, Range>& realize_map,
     const Stmt& body) const {
   return body;
......
@@ -226,17 +226,17 @@ void ScanOpNode::GatherBound(
 }
 
 Stmt ScanOpNode::BuildRealize(
-    const Operation& self,
+    const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map,
     const Stmt& body) const {
-  CHECK_EQ(self.operator->(), this);
+  CHECK_EQ(stage->op.get(), this);
   Range sdom = dom_map.at(this->scan_axis);
   Range tdom = Range::make_by_min_extent(
       0, ir::Simplify(sdom->extent + sdom->min));
   Stmt ret = body;
   size_t sp_idx = 0;
   for (size_t i = 0; i < update.size(); ++i) {
-    Tensor t = self.output(i);
+    Tensor t = stage->op.output(i);
     CHECK_EQ(static_cast<size_t>(t->value_index), i);
     Halide::Internal::Region bounds;
     bounds.push_back(tdom);
......
@@ -62,6 +62,19 @@ class StorageFlattener : public IRMutator {
       return stmt;
     } else if (op->attr_key == attr::buffer_bind_scope) {
       return HandleBufferBindScope(op);
+    } else if (op->attr_key == attr::buffer_dim_align) {
+      Tensor tensor(op->node.node_);
+      const Call* tuple = op->value.as<Call>();
+      CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
+      TensorKey key{tensor->op, tensor->value_index};
+      auto& vinfo = dim_align_[key];
+      int dim = tuple->args[0].as<IntImm>()->value;
+      if (static_cast<size_t>(dim) >= vinfo.size()) {
+        vinfo.resize(dim + 1);
+      }
+      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
+      return this->Mutate(op->body);
     }
     return IRMutator::Mutate_(op, s);
   }
@@ -116,20 +129,45 @@ class StorageFlattener : public IRMutator {
         align = (info->max_simd_bits + op->type.bits() - 1) / op->type.bits();
       }
     }
+    Array<Expr> strides;
+    if (dim_align_.count(key) != 0) {
+      std::vector<Expr> rstrides;
+      const std::vector<DimAlignInfo>& avec = dim_align_[key];
+      Expr stride = make_const(shape[0].type(), 1);
+      for (size_t i = shape.size(); i != 0; --i) {
+        size_t dim = i - 1;
+        if (dim < avec.size() && avec[dim].align_factor != 0) {
+          Expr factor = make_const(stride.type(), avec[dim].align_factor);
+          Expr offset = make_const(stride.type(), avec[dim].align_offset);
+          stride = stride + (factor + offset - stride % factor) % factor;
+          stride = ir::Simplify(stride);
+        }
+        rstrides.push_back(stride);
+        stride = arith::ComputeExpr<Mul>(stride, shape[dim]);
+      }
+      strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
+    }
     e.buffer = BufferNode::make(
         Var(key.GetName(), Handle()),
-        op->type, shape,
-        Array<Expr>(), Expr(),
+        op->type, shape, strides, Expr(),
         key.GetName(), skey.to_string(),
         align, 0);
 
     buf_map_[key] = e;
     Stmt body = this->Mutate(op->body);
     buf_map_[key].released = true;
-
-    Stmt ret = Allocate::make(
-        e.buffer->data, e.buffer->dtype, e.buffer->shape,
-        make_const(Bool(e.buffer->dtype.lanes()), true), body);
+    Stmt ret;
+    if (strides.size() != 0) {
+      ret = Allocate::make(
+          e.buffer->data, e.buffer->dtype,
+          {arith::ComputeExpr<Mul>(e.buffer->strides[0], e.buffer->shape[0])},
+          make_const(Bool(e.buffer->dtype.lanes()), true), body);
+    } else {
+      ret = Allocate::make(
+          e.buffer->data, e.buffer->dtype, e.buffer->shape,
+          make_const(Bool(e.buffer->dtype.lanes()), true), body);
+    }
     ret = AttrStmt::make(
         e.buffer->data, attr::storage_scope,
         StringImm::make(e.buffer->scope), ret);
@@ -283,7 +321,11 @@ class StorageFlattener : public IRMutator {
     }
     return body;
   }
+  // Dimension alignment info
+  struct DimAlignInfo {
+    int align_factor{0};
+    int align_offset{0};
+  };
   // The buffer entry in the flatten map
   struct BufferEntry {
     // the buffer of storage
@@ -294,7 +336,6 @@ class StorageFlattener : public IRMutator {
     bool external{false};
     // Whether we are out of allocation bounds and buffer get released.
     bool released{false};
-    // TODO(tqchen) allow permutation and inference of index dimension.
     // relative index
     inline Array<Expr> RelIndex(Array<Expr> args) const {
       if (bounds.size() != 0) {
@@ -314,6 +355,9 @@ class StorageFlattener : public IRMutator {
   std::unordered_map<const Variable*, Expr> var_remap_;
   // Buffer map
   std::unordered_map<TensorKey, BufferEntry> buf_map_;
+  // Dimension alignment
+  std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
+  // Storage scope
   std::unordered_map<const Node*, std::string> storage_scope_;
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
......
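The stride-padding loop above walks dimensions from innermost to outermost and, for each aligned dimension, bumps the running stride up to the next value congruent to `offset` modulo `factor`. A hypothetical pure-Python rendition of the same arithmetic (not TVM code):

```python
def padded_strides(shape, align):
    """Row-major strides with padding; align maps dim -> (factor, offset)."""
    strides = [0] * len(shape)
    stride = 1
    for dim in reversed(range(len(shape))):
        factor, offset = align.get(dim, (0, 0))
        if factor != 0:
            # same formula as in the pass: advance stride to k * factor + offset
            stride += (factor + offset - stride % factor) % factor
        strides[dim] = stride
        stride *= shape[dim]
    return strides

# shape (8, 16) with axis 0 aligned to factor=2, offset=1:
# the inner row length 16 is even, so stride[0] is padded to 17.
assert padded_strides([8, 16], {0: (2, 1)}) == [17, 1]
```

The strided `Allocate` branch then sizes the flat buffer as `strides[0] * shape[0]` (17 * 8 == 136 elements here), which is exactly what the new unit test asserts.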
@@ -296,10 +296,15 @@ Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
 }
 
 template<typename FUpdate>
-inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate) {
-  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
-  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
-  FindLeafVar(all_vars, leaf_vars, var);
+inline void UpdateIterVarAttr(StageNode* self,
+                              IterVar var,
+                              FUpdate fupdate,
+                              bool need_leaf = true) {
+  if (need_leaf) {
+    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+    FindLeafVar(all_vars, leaf_vars, var);
+  }
   auto it = self->iter_var_attrs.find(var);
   std::shared_ptr<IterVarAttrNode> n;
   if (it != self->iter_var_attrs.end()) {
@@ -371,6 +376,15 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
   return *this;
 }
 
+Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
+  StageNode *self = operator->();
+  UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) {
+      n->dim_align_factor = factor;
+      n->dim_align_offset = offset;
+    }, false);
+  return *this;
+}
+
 Stage CopyStage(const Stage& s) {
   std::shared_ptr<StageNode> n =
       std::make_shared<StageNode>(*s.operator->());
......
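The new `need_leaf = false` path matters because `storage_align` names a buffer dimension (one of the op's original axes), which need not remain a leaf iteration variable once the stage is scheduled further. An illustrative sketch:

```python
import tvm

A = tvm.placeholder((8, 16), name='A')
A1 = tvm.compute((8, 16), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((8, 16), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)

# After the split, axis[0] is no longer in leaf_iter_vars...
xo, xi = s[A1].split(A1.op.axis[0], factor=4)
# ...but storage_align still accepts it, since it skips the leaf-var check.
s[A1].storage_align(A1.op.axis[0], 2, 1)
```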
@@ -33,7 +33,7 @@ Stmt MakePipeline(const Stage& s,
     consumer = ProducerConsumer::make(s->op, false, consumer);
     pipeline = Block::make(producer, consumer);
   }
-  pipeline = s->op->BuildRealize(s->op, dom_map, pipeline);
+  pipeline = s->op->BuildRealize(s, dom_map, pipeline);
   // use attribute to mark scope of the operation.
   pipeline = AttrStmt::make(
       s->op, ir::attr::realize_scope,
@@ -194,6 +194,18 @@ class SchedulePostProc : public IRMutator {
           return this->Mutate(op->body);
         }
       }
+    } else if (op->attr_key == ir::attr::buffer_dim_align) {
+      Tensor tensor(op->node.node_);
+      auto it = replace_op_.find(tensor->op.get());
+      if (it != replace_op_.end()) {
+        if (it->second.defined()) {
+          return AttrStmt::make(
+              it->second.output(tensor->value_index),
+              op->attr_key, op->value, Mutate(op->body));
+        } else {
+          return this->Mutate(op->body);
+        }
+      }
     }
     return IRMutator::Mutate_(op, s);
   }
......
@@ -32,6 +32,27 @@ def test_flatten_prefetch():
     assert isinstance(stmt.body, tvm.stmt.For)
     assert stmt.body.extent.value == 2
 
+def test_flatten_storage_align():
+    m = 8
+    l = 16
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+
+    s = tvm.create_schedule(A2.op)
+    s[A1].storage_align(A1.op.axis[0], 2, 1)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    # A1's axis-0 stride is padded from 16 to 17, so 17 * 8 elements are allocated
+    assert(stmt.body.extents[0].value == 17 * 8)
+
 if __name__ == "__main__":
+    test_flatten_storage_align()
     test_flatten2()
     test_flatten_prefetch()
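For orientation, the flattened statement this test inspects looks roughly like the following (a hand-written approximation, not verbatim TVM output):

```
// attr [A1] storage_scope = "global"
allocate A1[float32 * 136]   // 136 == 17 * 8: axis-0 stride padded from 16 to 17
produce A1 {
  for (i, 0, 8) for (j, 0, 16) A1[(i*17) + j] = A[(i*16) + j]
}
produce A2 {
  for (i, 0, 8) for (j, 0, 16) A2[(i*16) + j] = A1[(i*17) + j] + 3f
}
```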
@@ -28,14 +28,13 @@ from tvm.contrib import rpc, util
 # local machine, we need build runtime on remote device.
 #
 # To get started, clone tvm repo from github. It is important to clone
 # the submodules along, with --recursive option (Assuming you are in
 # your home directory):
 #
 # .. code-block:: bash
 #
 #   git clone --recursive https://github.com/dmlc/tvm
 #
-######################################################################
 # .. note::
 #
 #   Usually device has limited resources and we only need to build
@@ -51,14 +50,13 @@ from tvm.contrib import rpc, util
 #
 # Also make sure that you have set :code:`USE_RPC=1` in your
 # :code:`config.mk`. We don't need LLVM when building runtime, so
-# :code:`LLVM_CONFIG = llvm-config` in :code:`config.mk`is commented
+# :code:`LLVM_CONFIG = llvm-config` in :code:`config.mk` is commented
 # out by default. After that, build runtime!
 #
 # .. code-block:: bash
 #
 #   make runtime
 #
-######################################################################
 # After success of buildind runtime, we need set environment varibles
 # in :code:`~/.bashrc` file of yourself account or :code:`/etc/profile`
 # of system enviroment variables. Assuming your TVM directory is in
@@ -95,7 +93,7 @@ from tvm.contrib import rpc, util
 # successful to start RPC server on your device.
 #
 # .. code-block:: bash
 #
 #   Loading runtime library /home/YOURNAME/code/tvm/lib/libtvm_runtime.so... exec only
 #   INFO:root:RPCServer: bind to 0.0.0.0:9090
 #
......