Commit 34d2aae3 by Tianqi Chen Committed by GitHub

[BUFFER/REFACTOR] Buffer byte_offset -> elem_offset, add buffer_bind_scope (#209)

parent b0e41b9a
Subproject commit e42653d7c3a604eb9f6ee1b5f989ddadd1cea69c -> 860199eea031a4ea694b8fce03ad0bf8127910ac
@@ -16,10 +16,11 @@ namespace tvm {
// Internal node container Buffer
class BufferNode;
/*!
 * \brief Buffer is a symbolic n-darray structure.
 *  It is a composition of primitive symbolic types,
 *  used to specify the memory layout of a Tensor used as program input.
 */
class Buffer : public NodeRef {
 public:
@@ -39,6 +40,21 @@ class Buffer : public NodeRef {
   */
  Stmt MakeStore(Array<Expr> index, Expr value) const;
  /*!
   * \brief Return a new buffer that is equivalent to the current one
   *  but always has the strides field defined.
   * \return The strided version of the buffer.
   */
  Buffer MakeStrideView() const;
  /*!
   * \brief Make a new symbolic buffer representing a slice of the buffer.
   * \param begins The beginning position of each dimension.
   * \param extents The extent of each dimension.
   * \note This function will make the target buffer as compact as possible.
   *  If strides are not needed for the slice, they will be omitted.
   * \return The result buffer.
   */
  Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
@@ -63,17 +79,14 @@ class BufferNode : public Node {
   *  This can be an empty array, indicating array is contiguous
   */
  Array<Expr> strides;
  /*! \brief The offset in terms of number of dtype elements (including lanes) */
  Expr elem_offset;
  // Meta data
  /*! \brief optional name of the buffer */
  std::string name;
  /*! \brief storage scope of the buffer, if other than global */
  std::string scope;
  /*! \brief Alignment multiple in terms of dtype elements (including lanes) */
  int offset_alignment;
  /*! \brief constructor */
  BufferNode() {}
@@ -83,7 +96,7 @@ class BufferNode : public Node {
    v->Visit("dtype", &dtype);
    v->Visit("shape", &shape);
    v->Visit("strides", &strides);
    v->Visit("elem_offset", &elem_offset);
    v->Visit("name", &name);
    v->Visit("scope", &scope);
    v->Visit("offset_alignment", &offset_alignment);
......
@@ -61,6 +61,14 @@ inline TVMType Type2TVMType(Type t) {
  return ret;
}
// Get number of bytes considering vector type.
inline int GetVectorBytes(Type dtype) {
  int data_bits = dtype.bits() * dtype.lanes();
  CHECK_EQ(data_bits % 8, 0U)
      << "Need to load/store by multiple of bytes";
  return data_bits / 8;
}
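For intuition, the same conversion expressed in Python (a minimal sketch; the helper name is hypothetical, not part of the TVM API):

```python
def get_vector_bytes(bits, lanes):
    """Bytes needed for a possibly vectorized dtype, e.g. float32x4 -> 16."""
    data_bits = bits * lanes
    assert data_bits % 8 == 0, "need to load/store by multiple of bytes"
    return data_bits // 8

assert get_vector_bytes(32, 1) == 4   # float32
assert get_vector_bytes(32, 4) == 16  # float32x4
```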
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
 public:
......
@@ -167,8 +167,16 @@ constexpr const char* prefetch_scope = "prefetch_scope";
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
/*!
 * \brief Bind the buffer specification to the region of the op.
 *  When this scope occurs, stmt.node is an Array<NodeRef> = [buffer, tensor]
 *  and stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
 *  The scope indicates that the storage region of tensor needs to be bound
 *  to buffer; during the storage-flattening phase, variables inside the
 *  scope that correspond to fields of buffer are replaced by the actual
 *  expressions of tensor.
 */
constexpr const char* buffer_bind_scope = "buffer_bind_scope";
// Pipeline related attributes
/*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope";
@@ -195,6 +203,14 @@ namespace intrinsic {
 */
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
 * \brief tvm_tuple is not an actual function and cannot be code-generated.
 *  It is used to represent a tuple structure in the value field of AttrStmt,
 *  as a hint to optimization passes.
 *
 *  Handle tvm_tuple(value0, value1, ..., value_n);
 */
constexpr const char* tvm_tuple = "tvm_tuple";
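Together with buffer_bind_scope above, the attribute carries a [buffer, tensor] pair in its node and a tvm_tuple call in its value. A hedged Python-level sketch of that structure (constructor spellings are illustrative; the extern-op lowering in src/op/extern_op.cc further down this diff is the authoritative producer):

```python
import tvm

A = tvm.placeholder((16, 16), name='A')
buf = tvm.decl_buffer(A.shape, A.dtype, name='Abuf')

# stmt.node : [buffer, tensor]; stmt.value : tvm_tuple(min0, extent0, ...)
bind_spec = tvm.convert([buf, A])
tuple_value = tvm.make.Call(
    "handle", "tvm_tuple",
    [tvm.const(0), buf.shape[0], tvm.const(0), buf.shape[1]],
    tvm.expr.Call.Intrinsic, None, 0)
body = tvm.make.Evaluate(tvm.const(0))
stmt = tvm.make.AttrStmt(bind_spec, "buffer_bind_scope", tuple_value, body)
```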
/*!
 * \brief See pseudo code
 *
 * Type tvm_struct_get(StructType* arr, int index, int field_id) {
@@ -250,14 +266,14 @@ constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
 *    Expr strides,
 *    Expr ndim,
 *    Expr dtype,
 *    Expr elem_offset) {
 *   ret = alloca stack DLTensor();
 *   ret->data = data;
 *   ret->shape = shape;
 *   ret->strides = strides != 0 ? strides : nullptr;
 *   ret->ndim = ndim;
 *   ret->dtype = dtype.type();
 *   ret->byte_offset = elem_offset * sizeof(dtype);
 *   return ret;
 * }
 */
......
@@ -62,7 +62,7 @@ class OperationNode : public FunctionBaseNode {
  virtual Array<Expr> output_shape(size_t i) const = 0;
  /*!
   * \brief List all the input Tensors.
   * \return List of input tensors.
   */
  virtual Array<Tensor> InputTensors() const = 0;
  /*!
......
@@ -287,6 +287,7 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
    res = [op.output(i) for i in range(len(update))]
    return res[0] if len(res) == 1 else res


def extern(shape, inputs, fcompute,
           name="extern", dtype=None):
    """Compute several tensors via an extern function.
@@ -374,7 +375,7 @@ def decl_buffer(shape,
                name="buffer",
                data=None,
                strides=None,
                elem_offset=None,
                scope="",
                offset_alignment=0):
"""Decleare a new symbolic buffer. """Decleare a new symbolic buffer.
@@ -401,8 +402,9 @@ def decl_buffer(shape,
    strides: array of Expr
        The stride of the buffer.

    elem_offset: Expr, optional
        The beginning offset of the array relative to data,
        in number of elements of dtype.

    scope: str, optional
        The storage scope of the buffer, if not global.
@@ -423,7 +425,7 @@ def decl_buffer(shape,
    to create a function that only handles a specific case of data structure
    and lets the compiled function benefit from it.
    If strides and elem_offset are passed as None
    when constructing the function, then the function will be specialized
    for DLTensors that are compact and aligned.
    If a fully generic symbolic array is passed for strides,
@@ -436,7 +438,7 @@ def decl_buffer(shape,
        data = var(name, "handle")
    return _api_internal._Buffer(
        data, dtype, shape, strides, elem_offset, name, scope, offset_alignment)
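For example, the new elem_offset parameter lets a kernel be specialized for tensors whose data starts at a symbolic element offset (this mirrors the updated LLVM test at the bottom of this diff):

```python
import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
# The offset is counted in elements of A.dtype, not in bytes.
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A',
                     elem_offset=tvm.var('Aoffset'))
```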
def _IterVar(dom, name, iter_type, thread_tag=''):
@@ -464,11 +466,11 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
    if dom is not None:
        if isinstance(dom, (list, tuple)):
            if len(dom) != 2:
                raise TypeError("need to be a list of ranges")
            dom = Range(dom[0], dom[1])
        if not isinstance(dom, _collections.Range):
            raise TypeError("dom needs to be a Range")
    name = name if name else 'iter'
    v = var(name)
    return _api_internal._IterVar(dom, v, iter_type, thread_tag)
......
@@ -26,8 +26,6 @@ class Array(NodeBase):
    def __len__(self):
        return _api_internal._ArraySize(self)


@register_node
class Map(NodeBase):
@@ -52,9 +50,6 @@ class Map(NodeBase):
    def __len__(self):
        return _api_internal._MapSize(self)


@register_node
class Range(NodeBase):
......
@@ -237,6 +237,10 @@ class Broadcast(Expr):
    pass


@register_node
class Shuffle(Expr):
    pass


@register_node
class Call(Expr):
    Extern = 0
    ExternCPlusPlus = 1
......
"""Intrinsics and math functions in TVM.""" """Expression Intrinsics and math functions in TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.function import register_func as _register_func from ._ffi.function import register_func as _register_func
@@ -20,7 +20,7 @@ def _pack_buffer(buf):
        strides,
        len(buf.shape),
        const(0, dtype=buf.dtype),
        buf.elem_offset]
    return _make.Call("handle", "tvm_stack_make_array",
                      pack_args, _Call.Intrinsic, None, 0)
......
@@ -73,3 +73,7 @@ class IfThenElse(Stmt):
@register_node
class Evaluate(Stmt):
    pass


@register_node
class Prefetch(Stmt):
    pass
@@ -118,6 +118,12 @@ class Operation(NodeBase):
        """Number of outputs of this op."""
        return _api_internal._OpNumOutputs(self)

    @property
    def input_tensors(self):
        """List of input tensors to this op."""
        return _api_internal._OpInputTensors(self)


@register_node
class PlaceholderOp(Operation):
    """Placeholder operation."""
......
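The new property simply forwards to the _OpInputTensors API registered below; usage matches the unit test added in this commit:

```python
import tvm

x = tvm.placeholder((1,), name='x')
y = tvm.compute(x.shape, lambda i: x[i] + x[i])
assert tuple(y.op.input_tensors) == (x,)
```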
@@ -218,6 +218,11 @@ TVM_REGISTER_API("_OpNumOutputs")
    *ret = args[0].operator Operation()->num_outputs();
  });

TVM_REGISTER_API("_OpInputTensors")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = args[0].operator Operation()->InputTensors();
  });

TVM_REGISTER_API("_IterVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = IterVarNode::make(
......
@@ -40,10 +40,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
        transa ? A->shape[1] : A->shape[0],
        transa ? B->shape[1] : B->shape[0],
        1.0f,
        reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset),
        B->shape[1],
        reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset),
        A->shape[1],
        0.0f,
        reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset),
        C->shape[1]);
  });
} // namespace contrib
} // namespace tvm
@@ -4,6 +4,7 @@
 */
#include <tvm/buffer.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>

namespace tvm {
@@ -28,27 +29,43 @@ Buffer decl_buffer(Array<Expr> shape,
                     name, "", 0);
}
// The buffer offset, in number of elements of the
// original data, ignoring the number of lanes.
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
  Expr base = n->elem_offset;
  if (n->strides.size() == 0) {
    CHECK_EQ(n->shape.size(), index.size());
    if (is_zero(base)) {
      base = index[0];
    } else {
      base = base + index[0];
    }
    for (size_t i = 1; i < index.size(); ++i) {
      base = base * n->shape[i] + index[i];
    }
  } else {
    CHECK_EQ(n->strides.size(), index.size());
    if (is_zero(base)) {
      base = index[0] * n->strides[0];
    } else {
      base = base + index[0] * n->strides[0];
    }
    for (size_t i = 1; i < index.size(); ++i) {
      base = base + index[i] * n->strides[i];
    }
  }
  return base;
}
// Buffer access offset.
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
  Expr offset = ElemOffset(n, index);
  if (n->dtype.lanes() != 1) {
    offset = offset * make_const(offset.type(), n->dtype.lanes());
  }
  return offset;
}
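A Python mirror of the two helpers above may make the arithmetic easier to follow (a sketch; the helper names are hypothetical):

```python
def elem_offset(shape, strides, index, base=0):
    """Offset in elements, mirroring ElemOffset: row-major fold when
    strides is empty, dot product with strides otherwise."""
    if not strides:
        off = base + index[0]
        for dim, idx in zip(shape[1:], index[1:]):
            off = off * dim + idx
        return off
    return base + sum(i * s for i, s in zip(index, strides))

def buffer_offset(shape, strides, index, lanes=1, base=0):
    """Load/store index, i.e. the element offset scaled by vector lanes."""
    return elem_offset(shape, strides, index, base) * lanes

assert elem_offset([4, 5], [], [2, 3]) == 13        # 2*5 + 3
assert elem_offset([4, 5], [5, 1], [2, 3]) == 13
assert buffer_offset([4, 5], [], [2, 3], lanes=4) == 52
```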
Expr Buffer::MakeLoad(Array<Expr> index) const {
  const BufferNode* n = operator->();
  return ir::Load::make(
@@ -63,11 +80,58 @@ Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
                        const_true(n->dtype.lanes()));
}
Buffer Buffer::MakeStrideView() const {
  if ((*this)->strides.size() != 0) return *this;
  std::vector<Expr> temp;
  auto n = std::make_shared<BufferNode>(*operator->());
  Expr acc = make_const(n->shape[0].type(), 1);
  for (size_t i = n->shape.size(); i != 0; --i) {
    temp.push_back(acc);
    acc = acc * n->shape[i - 1];
  }
  for (size_t i = temp.size(); i != 0; --i) {
    n->strides.push_back(temp[i - 1]);
  }
  return Buffer(n);
}
Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
  const BufferNode* n = operator->();
  Expr elem_offset = ElemOffset(n, begins);
  Array<Expr> strides = n->strides;
  if (strides.size() == 0) {
    bool can_relax = true;
    bool need_stride = false;
    // check if stride is needed.
    for (size_t i = 0; i < extents.size(); ++i) {
      if (!can_relax) {
        if (!is_zero(begins[i]) ||
            !is_zero(ir::Simplify(extents[i] - n->shape[i]))) {
          need_stride = true;
        }
      }
      if (!is_one(extents[i])) can_relax = false;
    }
    // make stride.
    if (need_stride) {
      return MakeStrideView().MakeSlice(begins, extents);
    }
  }
  return BufferNode::make(n->data,
                          n->dtype,
                          extents,
                          strides,
                          elem_offset,
                          n->name + "_slice",
                          n->scope,
                          0);
}
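The compactness check above decides whether the slice can keep an implicit layout; a small Python mirror of just that decision (hypothetical helper, constant shapes only):

```python
def slice_needs_strides(shape, begins, extents):
    """True if a slice of a compact buffer cannot stay compact."""
    can_relax, need_stride = True, False
    for b, e, s in zip(begins, extents, shape):
        if not can_relax and (b != 0 or e != s):
            need_stride = True
        if e != 1:
            can_relax = False
    return need_stride

assert slice_needs_strides([4, 5], [1, 2], [2, 3])      # inner dim is cut
assert not slice_needs_strides([4, 5], [1, 0], [2, 5])  # prefix slice stays compact
```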
Buffer BufferNode::make(Var data,
                        Type dtype,
                        Array<Expr> shape,
                        Array<Expr> strides,
                        Expr elem_offset,
                        std::string name,
                        std::string scope,
                        int offset_alignment) {
@@ -78,16 +142,13 @@ Buffer BufferNode::make(Var data,
  n->strides = std::move(strides);
  n->name = std::move(name);
  n->scope = std::move(scope);
  if (!elem_offset.defined()) {
    elem_offset = make_const(n->shape[0].type(), 0);
  }
  if (offset_alignment == 0) {
    offset_alignment = 1;
  }
  n->elem_offset = elem_offset;
  n->offset_alignment = offset_alignment;
  return Buffer(n);
}
......
@@ -42,6 +42,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
      << ", identity_element=" << op->identity_element
      << ")";
  });

} // namespace Internal
} // namespace Halide
......
@@ -30,7 +30,7 @@ Tensor TensorNode::make(Array<Expr> shape,
                        Operation op,
                        int value_index) {
  auto n = std::make_shared<TensorNode>();
  n->shape = std::move(shape);
  n->dtype = dtype;
  n->op = op;
  n->value_index = value_index;
......
@@ -251,7 +251,7 @@ Stmt Substitute(Stmt s,
  return ir::Substitute(s, temp);
}
// Cross Thread reduction
bool IsCrossThreadReduction(const ComputeOpNode* self,
                            const Stage& stage) {
  // Verify correctness of leaf nest.
@@ -360,6 +360,7 @@ Stmt MakeCrossThreadReduction(
  return MergeNest(nest, body);
}
// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
                 const Tensor& t) {
  Array<Expr> args;
@@ -369,60 +370,56 @@ Stmt MakeProvide(const ComputeOpNode* op,
  return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}
// loop nest structure for general compute
// This is the loop nest structure used in compute.
// Does not include the loop body.
struct ComputeLoopNest {
  // The common number of loops between init and main
  size_t num_common_loop;
  // predicates for the initialize loop
  std::vector<Expr> init_predicates;
  // Initialization nest involved.
  std::vector<std::vector<Stmt> > init_nest;
  // Value map for the init code
  std::unordered_map<IterVar, Expr> init_vmap;
  // Predicates for the main update loop
  std::vector<Expr> main_predicates;
  // The general loop nest
  std::vector<std::vector<Stmt> > main_nest;
  // Value map for the IterVar.
  std::unordered_map<IterVar, Expr> main_vmap;
};

ComputeLoopNest MakeComputeLoopNest(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) {
  CHECK_EQ(stage->op.operator->(), self);
  ComputeLoopNest ret;
  // make main loop nest
  ret.main_nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap);
  ret.main_predicates = op::MakeBoundCheck(stage, dom_map, false,
                                           std::unordered_set<IterVar>(), ret.main_vmap);
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
  if (stage->store_predicate.defined()) {
    ret.main_predicates.push_back(stage->store_predicate);
  }
  if (self->reduce_axis.size() != 0) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : self->reduce_axis) {
      update_state[iv] = 2;
    }
    for (IterVar iv : self->axis) {
      update_state[iv] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    schedule::PassDownBitMaskOr(stage, &update_state);
    auto leaf_iter_vars = stage->leaf_iter_vars;
    // find the first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
@@ -431,29 +428,69 @@ Stmt ComputeOpNode::BuildProvide(
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      ret.init_vmap[iv] = ret.main_vmap.at(iv);
    }
    ret.num_common_loop = begin_loop;
    // skip loops that do not relate to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if ((flag & 1) == 0) skip_iter.insert(kv.first);
    }
    ret.init_nest = op::MakeLoopNest(
        stage, dom_map, begin_loop, true,
        skip_iter, &(ret.init_vmap));
    ret.init_predicates = op::MakeBoundCheck(
        stage, dom_map, true, skip_iter, ret.init_vmap);
    for (auto& e : ret.init_predicates) {
      e = likely(e);
    }
  } else {
    ret.num_common_loop = ret.main_nest.size() - 1;
  }
  // copy elision here.
  return ret;
}

// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) const {
  CHECK_EQ(stage->op.operator->(), this);
  if (IsCrossThreadReduction(this, stage)) {
    // specially handle cross thread reduction.
    return MakeCrossThreadReduction(this, stage, dom_map);
  }
  // grab the nest structure
  ComputeLoopNest n = MakeComputeLoopNest(this, stage, dom_map);
  // Normal loop structure
  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
  if (this->reduce_axis.size() != 0) {
    // make reduction.
    Stmt init, provide;
    Array<Tensor> source;
    for (size_t i = 0; i < this->body.size(); ++i) {
      source.push_back(stage->op.output(i));
    }
    MakeReduction(this, source, &init, &provide);
    init = Substitute(init, n.init_vmap);
    init = MergeNest(n.init_nest, init);
    // common nest
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > reduce(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
    provide = Substitute(provide, n.main_vmap);
    provide = MergeNest(reduce, provide);
    return MergeNest(common, Block::make(init, provide));
  } else {
    std::vector<Stmt> provides;
    for (size_t i = 0; i < this->body.size(); ++i) {
      provides.emplace_back(MakeProvide(this, stage->op.output(i)));
    }
    Stmt provide = Substitute(Block::make(provides), n.main_vmap);
    return MergeNest(n.main_nest, provide);
  }
}
} // namespace tvm
@@ -128,8 +128,26 @@ Stmt ExternOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) const {
  CHECK_EQ(stage->op.operator->(), this);
  Stmt ret = this->body;
  auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
    Array<NodeRef> bind_spec;
    Array<Expr> tuple;
    bind_spec.push_back(buffer);
    bind_spec.push_back(tensor);
    for (size_t k = 0; k < buffer->shape.size(); ++k) {
      tuple.push_back(make_const(buffer->shape[k].type(), 0));
      tuple.push_back(buffer->shape[k]);
    }
    ret = AttrStmt::make(
        bind_spec, attr::buffer_bind_scope,
        Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
  };
  for (size_t i = output_placeholders.size(); i != 0; --i) {
    f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
  }
  for (size_t i = inputs.size(); i != 0; --i) {
    f_push_bind(input_placeholders[i - 1], inputs[i - 1]);
  }
  return ret;
}
} // namespace tvm
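At the Python level these bindings are generated whenever an extern op is lowered; a minimal usage sketch (the packed-function name is hypothetical):

```python
import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
C = tvm.extern(A.shape, [A],
               lambda ins, outs: tvm.call_packed(
                   'my_extern_impl', ins[0], outs[0]),
               name='C')
s = tvm.create_schedule(C.op)
# Lowering wraps the extern body in one buffer_bind_scope AttrStmt
# per input and output placeholder, as built by f_push_bind above.
```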
@@ -131,9 +131,15 @@ class PackedCallBuilder : public IRMutator {
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
                     make_const(UInt(16), dtype.lanes())));
    // set byte offset
    int data_bytes = GetVectorBytes(dtype);
    Expr byte_offset = op->args[5];
    if (!is_zero(byte_offset)) {
      byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes);
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
                     Convert(UInt(64), byte_offset)));
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    prep_seq_.emplace_back(
......
@@ -12,6 +12,7 @@
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {
@@ -222,9 +223,19 @@ LoweredFunc MakeAPI(Stmt body,
    }
  }
  // Byte_offset field.
  int data_bytes = GetVectorBytes(buf->dtype);
  int64_t const_offset;
  if (arith::GetConst(buf->elem_offset, &const_offset)) {
    f_push(make_const(buf->elem_offset.type(), const_offset * data_bytes),
           TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
           v_arg->name_hint + ".byte_offset");
  } else {
    f_push(buf->elem_offset,
           cast(buf->elem_offset.type(),
                (TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset) /
                 make_const(UInt(64), data_bytes))),
           v_arg->name_hint + ".elem_offset");
  }
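The two directions of the conversion are plain arithmetic; e.g. for a float32 buffer (4 bytes per element):

```python
data_bytes = 4                                   # GetVectorBytes(float32)
# Constant case: fold elem_offset into a byte constant at compile time.
assert 16 * data_bytes == 64
# Symbolic case: recover elem_offset from the runtime DLTensor byte_offset.
byte_offset = 64                                 # read from kArrByteOffset
assert byte_offset // data_bytes == 16
```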
  // device info.
  f_push(device_id,
         TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId),
......
@@ -7,8 +7,8 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/operation.h>
#include <unordered_map>
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
@@ -31,9 +31,12 @@ class StorageFlattener : public IRMutator {
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Store::make(buf_var, op->value, op->index, op->predicate);
    } else {
      return stmt;
    }
@@ -50,8 +53,8 @@ class StorageFlattener : public IRMutator {
      Stmt stmt = IRMutator::Mutate_(op, s);
      curr_thread_scope_.pop_back();
      return stmt;
    } else if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    }
    return IRMutator::Mutate_(op, s);
  }
@@ -115,17 +118,20 @@ class StorageFlattener : public IRMutator {
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = var_remap_.find(op->buffer_var.get());
    if (it != var_remap_.end() &&
        !it->second.same_as(op->buffer_var)) {
      CHECK(it->second.as<Variable>());
      VarExpr buf_var(it->second.node_);
      return Load::make(op->type, buf_var, op->index, op->predicate);
    } else {
      return expr;
    }
  }
  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return e;
@@ -150,35 +156,115 @@ class StorageFlattener : public IRMutator {
    }
  }

 private:
  // Bind the symbol sym to value if it is a Variable;
  // emit a sequence of asserts if it is a constant constraint.
  // hint_name: used for the error message
  // add_keys: a list of newly bound keys
  // add_asserts: a list of asserts during the bind
  void BindSymbol(Expr sym,
                  Expr value,
                  std::string hint_name,
                  std::vector<const Variable*>* add_keys,
                  std::vector<Stmt>* add_asserts) {
    if (const Variable* v = sym.as<Variable>()) {
      auto it = var_remap_.find(v);
      if (it == var_remap_.end()) {
        add_keys->push_back(v);
        var_remap_[v] = value;
        return;
      }
    }
    // add assertions
    std::ostringstream os;
    os << "BufferBind constraint fail " << hint_name;
    add_asserts->emplace_back(
        AssertStmt::make(sym == value, os.str()));
  }
  // Start bind
  Stmt HandleBufferBindScope(const AttrStmt* op) {
    Array<NodeRef> arr(op->node.node_);
    CHECK_EQ(arr.size(), 2U);
    const BufferNode* buffer = arr[0].as<BufferNode>();
    const TensorNode* tensor = arr[1].as<TensorNode>();
    const Call* tuple = op->value.as<Call>();
    CHECK(buffer && tensor);
    CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
    TensorKey key{tensor->op, tensor->value_index};
    CHECK(buf_map_.count(key));
    const BufferEntry& be = buf_map_.at(key);
    CHECK(!be.released);
    CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
    Array<Expr> begins, extents;
    if (be.bounds.size() != 0) {
      CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
      for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
        begins.push_back(
            arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
        extents.push_back(tuple->args[2 * i + 1]);
      }
    } else {
      for (size_t i = 0; i < tuple->args.size(); i += 2) {
        begins.push_back(tuple->args[i]);
        extents.push_back(tuple->args[i + 1]);
      }
    }
    Buffer slice = be.buffer.MakeSlice(begins, extents);
    if (buffer->strides.size() == 0) {
      CHECK_EQ(slice->strides.size(), 0U)
          << "Trying to bind compact buffer to strided one";
    } else {
      slice = slice.MakeStrideView();
    }
    CHECK_EQ(slice->strides.size(), buffer->strides.size());
    // start binding
    std::vector<const Variable*> keys;
    std::vector<Stmt> asserts;
    BindSymbol(buffer->data, slice->data,
               buffer->name + ".data",
               &keys, &asserts);
    for (size_t i = 0; i < buffer->shape.size(); ++i) {
      std::ostringstream field_name;
      field_name << buffer->name << ".shape[" << i << ']';
      BindSymbol(buffer->shape[i], slice->shape[i],
                 field_name.str(),
                 &keys, &asserts);
    }
    for (size_t i = 0; i < buffer->strides.size(); ++i) {
      std::ostringstream field_name;
      field_name << buffer->name << ".strides[" << i << ']';
      BindSymbol(buffer->strides[i], slice->strides[i],
                 field_name.str(),
                 &keys, &asserts);
    }
    BindSymbol(buffer->elem_offset, slice->elem_offset,
               buffer->name + ".elem_offset",
               &keys, &asserts);
    CHECK_EQ(buffer->scope, slice->scope)
        << "Buffer bind scope mismatch";
    // Apply the remaps
    Stmt body = this->Mutate(op->body);
    for (size_t i = 0; i < asserts.size(); ++i) {
      Stmt ret = Simplify(this->Mutate(asserts[i]));
      if (const AssertStmt* assert_op = ret.as<AssertStmt>()) {
        if (!is_zero(assert_op->condition)) {
          body = Block::make(ret, body);
        } else {
          LOG(FATAL) << "BindBuffer has an unmet assertion: " << ret;
        }
      }
    }
    // remove the binds
    for (const Variable* op : keys) {
      var_remap_.erase(op);
    }
    return body;
  }
  // The buffer entry in the flatten map
  struct BufferEntry {
    // the buffer of storage
    Buffer buffer;
    // the bounds of realization; can be null, which means the full region
    Region bounds;
    // Whether the buffer is external
    bool external{false};
@@ -200,7 +286,9 @@ class StorageFlattener : public IRMutator {
    }
  };
  // Variable remap
  std::unordered_map<const Variable*, Expr> var_remap_;
  // Buffer map
  std::unordered_map<TensorKey, BufferEntry> buf_map_;
  std::unordered_map<const Node*, std::string> storage_scope_;
  // The current thread scope.
......
@@ -14,8 +14,13 @@ def test_llvm_add_pipeline():
    def check_llvm():
        if not tvm.module.enabled("llvm"):
            return
        # Specifically allow an offset to exercise the codepath taken when
        # an offset is available.
        Ab = tvm.decl_buffer(
            A.shape, A.dtype, elem_offset=tvm.var('Aoffset'),
            name='A')
        binds = {A: Ab}
        # build and invoke the kernel.
        f = tvm.build(s, [Ab, B, C], "llvm", binds=binds)
        ctx = tvm.cpu(0)
        # launch the kernel.
        n = nn
@@ -25,6 +30,7 @@ def test_llvm_add_pipeline():
        f(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())
    check_llvm()
......
@@ -168,7 +168,14 @@ def test_tuple_with_different_deps():
    assert stmt.node == C.op and len(ret) == 1


def test_tensor_inputs():
    x = tvm.placeholder((1,), name='x')
    y = tvm.compute(x.shape, lambda i: x[i] + x[i])
    assert tuple(y.op.input_tensors) == (x,)


if __name__ == "__main__":
    test_tensor_inputs()
    test_tensor_reduce_multi_axis()
    test_conv1d()
    test_tensor_slice()
......
@@ -72,6 +72,17 @@ def test_auto_inline():
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)


def test_schedule_const_bound():
    n = 128
    A = tvm.placeholder((n,), name='A')
    A1 = tvm.compute((n,), lambda i: A[i] + 1, name='A1')
    s = tvm.create_schedule(A1.op)
    xo, xi = s[A1].split(A1.op.axis[0], 8)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.collections.Map)
    stmt = tvm.schedule.ScheduleOps(s, bounds)


def test_inline_mixed():
    n = tvm.var('n')
    A = tvm.placeholder((n, ), name='A')
@@ -150,6 +161,7 @@ def test_schedule_cache():
if __name__ == "__main__":
    test_schedule_const_bound()
    test_scan_inline1()
    test_scan_inline2()
    test_inline_mixed()
......