Commit 34d2aae3 by Tianqi Chen Committed by GitHub

[BUFFER/REFACTOR] Buffer byte_offset-> elem_offset, add buffer_bind_scope (#209)

parent b0e41b9a
Subproject commit e42653d7c3a604eb9f6ee1b5f989ddadd1cea69c
Subproject commit 860199eea031a4ea694b8fce03ad0bf8127910ac
......@@ -16,10 +16,11 @@ namespace tvm {
// Internal node container Buffer
class BufferNode;
/*!
* \brief Buffer is a symbolic n-dimensional array structure.
* It is a composition of primitive symbolic types,
* used to specify input/output strcuture of the program.
* used to specify the memory layout of a Tensor used as program input.
*/
class Buffer : public NodeRef {
public:
......@@ -39,6 +40,21 @@ class Buffer : public NodeRef {
*/
Stmt MakeStore(Array<Expr> index, Expr value) const;
/*!
* \brief Return a new buffer that is equivalent to the current one,
* but always has the strides field defined.
* \return The strided version of the buffer.
*/
Buffer MakeStrideView() const;
/*!
* \brief Make a new symbolic buffer representing a slice of the buffer.
* \param begins The beginning position of each dimension.
* \param extents The extent of each dimension.
* \note This function makes the target buffer as compact as possible:
* if strides are not needed for the slice, they will not be present.
* \return the result buffer.
*/
Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
......@@ -63,17 +79,14 @@ class BufferNode : public Node {
* This can be an empty array, indicating the array is contiguous
*/
Array<Expr> strides;
/*!
* \brief The offset in bytes to the beginning pointer to data
* Can be undefined, indicating this must be zero.
*/
Expr byte_offset;
/*! \brief The offset in terms of number of dtype elements (including lanes) */
Expr elem_offset;
// Meta data
/*! \brief optional name of the buffer */
std::string name;
/*! \brief storage scope of the buffer, if other than global */
std::string scope;
/*! \brief Alignment bytes size of byte_offset */
/*! \brief Alignment multiple in terms of dtype elements (including lanes) */
int offset_alignment;
/*! \brief constructor */
BufferNode() {}
......@@ -83,7 +96,7 @@ class BufferNode : public Node {
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("byte_offset", &byte_offset);
v->Visit("elem_offset", &elem_offset);
v->Visit("name", &name);
v->Visit("scope", &scope);
v->Visit("offset_alignment", &offset_alignment);
......
......@@ -61,6 +61,14 @@ inline TVMType Type2TVMType(Type t) {
return ret;
}
// Get number of bytes considering vector type.
inline int GetVectorBytes(Type dtype) {
int data_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
public:
......
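As a side note, GetVectorBytes above simply checks that the vector type is byte-addressable; a plain-Python illustration (not part of the patch):

def vector_bytes(bits, lanes):
    # Mirrors GetVectorBytes: the total bit width across lanes must be a whole number of bytes.
    data_bits = bits * lanes
    assert data_bits % 8 == 0, "Need to load/store by multiple of bytes"
    return data_bits // 8

assert vector_bytes(32, 4) == 16  # e.g. a float32x4 vector occupies 16 bytes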
......@@ -167,8 +167,16 @@ constexpr const char* prefetch_scope = "prefetch_scope";
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
constexpr const char* scan_init_scope = "scan_init_scope";
/*! \brief extern operator scope */
constexpr const char* extern_op_scope = "extern_op_scope";
/*!
* \brief Bind the buffer specification to the region of the op.
* When this scope occurs, stmt.node is an Array<NodeRef> = [buffer, tensor]
* and stmt.value is a tvm_tuple(min0, extent0, min1, extent1, ...).
* The scope indicates that the storage region of the tensor must be bound to the buffer.
* During the storage flattening phase, variables inside the scope that correspond
* to fields of the buffer are replaced by the actual expressions derived from the tensor.
*/
constexpr const char* buffer_bind_scope = "buffer_bind_scope";
// Pipeline related attributes
/*! \brief channel read scope */
constexpr const char* channel_read_scope = "channel_read_scope";
......@@ -195,6 +203,14 @@ namespace intrinsic {
*/
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
* \brief tvm_tuple is not an actual function and cannot be code-generated.
* It is used to represent a tuple structure in the value field of AttrStmt,
* as a hint to optimization passes.
*
* Handle tvm_tuple(value0, value1, ..., value_n);
*/
constexpr const char* tvm_tuple = "tvm_tuple";
/*!
* \brief See pseudo code
*
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
......@@ -250,14 +266,14 @@ constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
* Expr strides,
* Expr ndim,
* Expr dtype,
* Expr byte_offset) {
* Expr elem_offset) {
* ret = alloca stack DLTensor();
* ret->data = data;
* ret->shape = shape;
* ret->strides = strides != 0 ? strides : nullptr;
* ret->ndim = ndim;
* ret->dtype = dtype.type();
* ret->byte_offset = byte_offset;
* ret->byte_offset = elem_offset * sizeof(dtype);
* return ret;
* }
*/
......
......@@ -62,7 +62,7 @@ class OperationNode : public FunctionBaseNode {
virtual Array<Expr> output_shape(size_t i) const = 0;
/*!
* \brief List all the input Tensors.
* \return List if input tensors.
* \return List of input tensors.
*/
virtual Array<Tensor> InputTensors() const = 0;
/*!
......
......@@ -287,6 +287,7 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
def extern(shape, inputs, fcompute,
name="extern", dtype=None):
"""Compute several tensor via extern function.
......@@ -374,7 +375,7 @@ def decl_buffer(shape,
name="buffer",
data=None,
strides=None,
byte_offset=None,
elem_offset=None,
scope="",
offset_alignment=0):
"""Decleare a new symbolic buffer.
......@@ -401,8 +402,9 @@ def decl_buffer(shape,
strides: array of Expr
The stride of the buffer.
byte_offset: Expr, optional
The offset in bytes to data pointer.
elem_offset: Expr, optional
The starting offset of the data within the array,
in terms of number of elements of dtype.
scope: str, optional
The storage scope of the buffer, if not global.
......@@ -423,7 +425,7 @@ def decl_buffer(shape,
to create function that only handles specific case of data structure
and make compiled function benefit from it.
If user pass strides and byte_offset is passed as None
If the user passes strides and elem_offset as None
when constructing the function, then the function will be specialized
for the DLTensor that is compact and aligned.
If the user passes a fully generic symbolic array as the strides,
......@@ -436,7 +438,7 @@ def decl_buffer(shape,
data = var(name, "handle")
return _api_internal._Buffer(
data, dtype, shape, strides, byte_offset, name, scope, offset_alignment)
data, dtype, shape, strides, elem_offset, name, scope, offset_alignment)
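A usage sketch of the new elem_offset parameter, mirroring the LLVM test added further below (hedged: the symbolic offset name Aoffset and the 0.x-era build/binds signature are taken from that test):

import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')
s = tvm.create_schedule(B.op)

# Declare a buffer whose element offset is a free symbol, so the compiled
# function accepts DLTensors whose data pointer starts Aoffset elements in,
# instead of being specialized to a zero offset.
Ab = tvm.decl_buffer(A.shape, A.dtype, elem_offset=tvm.var('Aoffset'), name='A')
f = tvm.build(s, [Ab, B], "llvm", binds={A: Ab})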
def _IterVar(dom, name, iter_type, thread_tag=''):
......@@ -464,11 +466,11 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
if dom is not None:
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
raise ValueError("need to list of ranges")
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])
if not isinstance(dom, _collections.Range):
raise ValueError("dom need to be Range")
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
return _api_internal._IterVar(dom, v, iter_type, thread_tag)
......
......@@ -26,8 +26,6 @@ class Array(NodeBase):
def __len__(self):
return _api_internal._ArraySize(self)
def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']'
@register_node
class Map(NodeBase):
......@@ -52,9 +50,6 @@ class Map(NodeBase):
def __len__(self):
return _api_internal._MapSize(self)
def __repr__(self):
return '{' + (", ".join(str(x[0]) + ": " +str(x[1]) for x in self.items())) + '}'
@register_node
class Range(NodeBase):
......
......@@ -237,6 +237,10 @@ class Broadcast(Expr):
pass
@register_node
class Shuffle(Expr):
pass
@register_node
class Call(Expr):
Extern = 0
ExternCPlusPlus = 1
......
"""Intrinsics and math functions in TVM."""
"""Expression Intrinsics and math functions in TVM."""
from __future__ import absolute_import as _abs
from ._ffi.function import register_func as _register_func
......@@ -20,7 +20,7 @@ def _pack_buffer(buf):
strides,
len(buf.shape),
const(0, dtype=buf.dtype),
buf.byte_offset]
buf.elem_offset]
return _make.Call("handle", "tvm_stack_make_array",
pack_args, _Call.Intrinsic, None, 0)
......
......@@ -73,3 +73,7 @@ class IfThenElse(Stmt):
@register_node
class Evaluate(Stmt):
pass
@register_node
class Prefetch(Stmt):
pass
......@@ -118,6 +118,12 @@ class Operation(NodeBase):
"""Number of outputs of this op."""
return _api_internal._OpNumOutputs(self)
@property
def input_tensors(self):
"""List of input tensors to this op."""
return _api_internal._OpInputTensors(self)
@register_node
class PlaceholderOp(Operation):
"""Placeholder operation."""
......
......@@ -218,6 +218,11 @@ TVM_REGISTER_API("_OpNumOutputs")
*ret = args[0].operator Operation()->num_outputs();
});
TVM_REGISTER_API("_OpInputTensors")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation()->InputTensors();
});
TVM_REGISTER_API("_IterVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IterVarNode::make(
......
......@@ -40,10 +40,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
transa ? A->shape[1] : A->shape[0],
transa ? B->shape[1] : B->shape[0],
1.0f,
static_cast<float*>(B->data), B->shape[1],
static_cast<float*>(A->data), A->shape[1],
reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset),
B->shape[1],
reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset),
A->shape[1],
0.0f,
static_cast<float*>(C->data), C->shape[1]);
reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset),
C->shape[1]);
});
} // namespace contrib
} // namespace tvm
......@@ -4,6 +4,7 @@
*/
#include <tvm/buffer.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
namespace tvm {
......@@ -28,27 +29,43 @@ Buffer decl_buffer(Array<Expr> shape,
name, "", 0);
}
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr base;
// The buffer offset, in number of elements of the
// original dtype, ignoring the number of lanes.
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset;
if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size());
base = index[0];
if (is_zero(base)) {
base = index[0];
} else {
base = base + index[0];
}
for (size_t i = 1; i < index.size(); ++i) {
base = base * n->shape[i] + index[i];
}
} else {
CHECK_EQ(n->strides.size(), index.size());
base = index[0] * n->strides[0];
if (is_zero(base)) {
base = index[0] * n->strides[0];
} else {
base = base + index[0] * n->strides[0];
}
for (size_t i = 1; i < index.size(); ++i) {
base = base + index[i] * n->strides[i];
}
}
if (!is_zero(n->byte_offset)) {
base = base + (n->byte_offset / n->dtype.bytes());
}
return base;
}
// Buffer access offset.
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr offset = ElemOffset(n, index);
if (n->dtype.lanes() != 1) {
offset = offset * make_const(offset.type(), n->dtype.lanes());
}
return offset;
}
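A plain-Python illustration of the offset convention implemented by ElemOffset and BufferOffset above (illustrative only; the shape and index values are made up):

def elem_offset(shape, strides, index, base=0):
    # Compact (no strides): Horner-style accumulation ((i0*s1 + i1)*s2 + i2)...
    # Strided: sum of index[k] * strides[k]; `base` plays the role of elem_offset.
    if not strides:
        off = index[0]
        for dim, i in zip(shape[1:], index[1:]):
            off = off * dim + i
    else:
        off = sum(i * s for i, s in zip(index, strides))
    return base + off

# element (1, 2) of a 2x3 compact buffer sits at offset 1*3 + 2 = 5;
# for a vector dtype, BufferOffset then scales the result by the number of lanes.
assert elem_offset([2, 3], [], [1, 2]) == 5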
Expr Buffer::MakeLoad(Array<Expr> index) const {
const BufferNode* n = operator->();
return ir::Load::make(
......@@ -63,11 +80,58 @@ Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
const_true(n->dtype.lanes()));
}
Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this;
std::vector<Expr> temp;
auto n = std::make_shared<BufferNode>(*operator->());
Expr acc = make_const(n->shape[0].type(), 1);
for (size_t i = n->shape.size(); i != 0 ; --i) {
temp.push_back(acc);
acc = acc * n->shape[i - 1];
}
for (size_t i = temp.size(); i != 0; --i) {
n->strides.push_back(temp[i - 1]);
}
return Buffer(n);
}
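The strides produced by MakeStrideView for a compact buffer are the usual row-major ones; a small sketch of the same accumulation (illustrative only):

def default_strides(shape):
    # Accumulate from the innermost dimension outwards, as MakeStrideView does.
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return list(reversed(strides))

assert default_strides([2, 3, 4]) == [12, 4, 1]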
Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
const BufferNode* n = operator->();
Expr elem_offset = ElemOffset(n, begins);
Array<Expr> strides = n->strides;
if (strides.size() == 0) {
bool can_relax = true;
bool need_stride = false;
// check if stride is needed.
for (size_t i = 0; i < extents.size(); ++i) {
if (!can_relax) {
if (!is_zero(begins[i]) ||
!is_zero(ir::Simplify(extents[i] - n->shape[i]))) {
need_stride = true;
}
}
if (!is_one(extents[i])) can_relax = false;
}
// make stride.
if (need_stride) {
return MakeStrideView().MakeSlice(begins, extents);
}
}
return BufferNode::make(n->data,
n->dtype,
extents,
strides,
elem_offset,
n->name + "_slice",
n->scope,
0);
}
Buffer BufferNode::make(Var data,
Type dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr byte_offset,
Expr elem_offset,
std::string name,
std::string scope,
int offset_alignment) {
......@@ -78,16 +142,13 @@ Buffer BufferNode::make(Var data,
n->strides = std::move(strides);
n->name = std::move(name);
n->scope = std::move(scope);
if (!byte_offset.defined()) {
byte_offset = make_const(n->shape[0].type(), 0);
if (!elem_offset.defined()) {
elem_offset = make_const(n->shape[0].type(), 0);
}
if (offset_alignment != 0) {
CHECK_EQ(offset_alignment % dtype.bytes(), 0)
<< "Offset alignments must be at least " << dtype.bytes();
} else {
offset_alignment = dtype.bytes();
if (offset_alignment == 0) {
offset_alignment = 1;
}
n->byte_offset = byte_offset;
n->elem_offset = elem_offset;
n->offset_alignment = offset_alignment;
return Buffer(n);
}
......
......@@ -42,6 +42,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< ", identity_element=" << op->identity_element
<< ")";
});
} // namespace Internal
} // namespace Halide
......
......@@ -30,7 +30,7 @@ Tensor TensorNode::make(Array<Expr> shape,
Operation op,
int value_index) {
auto n = std::make_shared<TensorNode>();
n->shape = shape;
n->shape = std::move(shape);
n->dtype = dtype;
n->op = op;
n->value_index = value_index;
......
......@@ -251,7 +251,7 @@ Stmt Substitute(Stmt s,
return ir::Substitute(s, temp);
}
// Cross Thread reduction marker.
// Cross Thread reduction
bool IsCrossThreadReduction(const ComputeOpNode* self,
const Stage& stage) {
// Verify correctness of leaf nest.
......@@ -360,6 +360,7 @@ Stmt MakeCrossThreadReduction(
return MergeNest(nest, body);
}
// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
Array<Expr> args;
......@@ -369,60 +370,56 @@ Stmt MakeProvide(const ComputeOpNode* op,
return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}
Stmt ComputeOpNode::BuildProvide(
// loop nest structure for general compute
// This is the loop nest structure used in compute.
// It does not include the loop body.
struct ComputeLoopNest {
// The common number of loops between init and main
size_t num_common_loop;
// Predicates for the initialization loop
std::vector<Expr> init_predicates;
// Initialization nest involved.
std::vector<std::vector<Stmt> > init_nest;
// Value map for the init code
std::unordered_map<IterVar, Expr> init_vmap;
// Predicates for the main update loop
std::vector<Expr> main_predicates;
// The general loop nest
std::vector<std::vector<Stmt> > main_nest;
// Value map for the IterVar.
std::unordered_map<IterVar, Expr> main_vmap;
};
ComputeLoopNest MakeComputeLoopNest(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
if (IsCrossThreadReduction(this, stage)) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
}
size_t size = this->body.size();
Stmt init;
Stmt provide;
if (this->reduce_axis.size() == 0) {
std::vector<Stmt> provides;
for (size_t i = 0; i < size; ++i) {
provides.emplace_back(MakeProvide(this, stage->op.output(i)));
}
provide = Block::make(provides);
} else {
Array<Tensor> source;
for (size_t i = 0; i < size; ++i) {
source.push_back(stage->op.output(i));
}
MakeReduction(this, source, &init, &provide);
const std::unordered_map<IterVar, Range>& dom_map) {
CHECK_EQ(stage->op.operator->(), self);
ComputeLoopNest ret;
// make main loop nest
ret.main_nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap);
ret.main_predicates = op::MakeBoundCheck(stage, dom_map, false,
std::unordered_set<IterVar>(), ret.main_vmap);
for (auto& e : ret.main_predicates) {
e = likely(e);
}
// make loop nest
std::unordered_map<IterVar, Expr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
auto preds = op::MakeBoundCheck(stage, dom_map, false,
std::unordered_set<IterVar>(), value_map);
for (auto& e : preds) e = likely(e);
nest.push_back(op::MakeIfNest(preds));
if (stage->store_predicate.defined()) {
nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
ret.main_predicates.push_back(stage->store_predicate);
}
provide = Substitute(provide, value_map);
if (init.defined()) {
if (self->reduce_axis.size() != 0) {
// try to find the location to insert the initialization.
// Fuse the initialization and provide loop when possible.
std::unordered_map<IterVar, int> update_state;
for (IterVar iv : this->reduce_axis) {
for (IterVar iv : self->reduce_axis) {
update_state[iv] = 2;
}
for (IterVar iv : this->axis) {
for (IterVar iv : self->axis) {
update_state[iv] = 1;
}
// find which iter var is related to reduction and which is related to axis.
schedule::PassDownBitMaskOr(stage, &update_state);
auto leaf_iter_vars = stage->leaf_iter_vars;
std::unordered_map<IterVar, Expr> init_value_map;
// find the first loop that is related to reduction.
size_t begin_loop = leaf_iter_vars.size();
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
......@@ -431,29 +428,69 @@ Stmt ComputeOpNode::BuildProvide(
if ((flag & 2) != 0) {
begin_loop = i; break;
}
init_value_map[iv] = value_map.at(iv);
ret.init_vmap[iv] = ret.main_vmap.at(iv);
}
ret.num_common_loop = begin_loop;
// skip loops that do not relate to the axis.
std::unordered_set<IterVar> skip_iter;
for (auto kv : update_state) {
int flag = kv.second;
if ((flag & 1) == 0) skip_iter.insert(kv.first);
}
auto init_nest = op::MakeLoopNest(
ret.init_nest = op::MakeLoopNest(
stage, dom_map, begin_loop, true,
skip_iter, &init_value_map);
auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
for (auto& e : preds) e = likely(e);
init_nest.push_back(op::MakeIfNest(preds));
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
skip_iter, &(ret.init_vmap));
ret.init_predicates = op::MakeBoundCheck(
stage, dom_map, true, skip_iter, ret.init_vmap);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
} else {
ret.num_common_loop = ret.main_nest.size() - 1;
}
// copy elision here.
return ret;
}
// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
if (IsCrossThreadReduction(this, stage)) {
// specially handle cross thread reduction.
return MakeCrossThreadReduction(this, stage, dom_map);
}
// grab the nest structure
ComputeLoopNest n = MakeComputeLoopNest(this, stage, dom_map);
// Normal loop structure
n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
if (this->reduce_axis.size() != 0) {
// make reduction.
Stmt init, provide;
Array<Tensor> source;
for (size_t i = 0; i < this->body.size(); ++i) {
source.push_back(stage->op.output(i));
}
MakeReduction(this, source, &init, &provide);
init = Substitute(init, n.init_vmap);
init = MergeNest(n.init_nest, init);
// common nest
std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
std::vector<std::vector<Stmt> > common(
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
std::vector<std::vector<Stmt> > reduce(
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
provide = Substitute(provide, n.main_vmap);
provide = MergeNest(reduce, provide);
return MergeNest(common, Block::make(init, provide));
} else {
return MergeNest(nest, provide);
std::vector<Stmt> provides;
for (size_t i = 0; i < this->body.size(); ++i) {
provides.emplace_back(MakeProvide(this, stage->op.output(i)));
}
Stmt provide = Substitute(Block::make(provides), n.main_vmap);
return MergeNest(n.main_nest, provide);
}
}
} // namespace tvm
......@@ -128,8 +128,26 @@ Stmt ExternOpNode::BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map) const {
CHECK_EQ(stage->op.operator->(), this);
return AttrStmt::make(
stage->op, ir::attr::extern_op_scope,
StringImm::make(name), body);
Stmt ret = this->body;
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<Expr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
tuple.push_back(make_const(buffer->shape[k].type(), 0));
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt::make(
bind_spec, attr::buffer_bind_scope,
Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
};
for (size_t i = output_placeholders.size(); i != 0; --i) {
f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
}
for (size_t i = inputs.size(); i != 0; --i) {
f_push_bind(input_placeholders[i - 1], inputs[i - 1]);
}
return ret;
}
} // namespace tvm
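At the Python level, the buffer_bind_scope path above is what tvm.extern now lowers through; a hedged sketch using the cblas routine touched earlier (assuming the contemporaneous tvm.call_packed helper and a matmul packed-function signature with transa/transb flags):

import tvm

n = 128
A = tvm.placeholder((n, n), name='A')
B = tvm.placeholder((n, n), name='B')

# ExternOpNode::BuildProvide now wraps the extern body in one buffer_bind_scope
# AttrStmt per input/output placeholder, binding each placeholder buffer to the
# full region of the corresponding tensor, instead of emitting extern_op_scope.
C = tvm.extern((n, n), [A, B],
               lambda ins, outs: tvm.call_packed(
                   "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], False, False),
               name="matmul")
s = tvm.create_schedule(C.op)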
......@@ -131,9 +131,15 @@ class PackedCallBuilder : public IRMutator {
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
make_const(UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
Expr byte_offset = op->args[5];
if (!is_zero(byte_offset)) {
byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes);
}
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
Convert(Int(64), op->args[5])));
Convert(UInt(64), byte_offset)));
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
prep_seq_.emplace_back(
......
......@@ -12,6 +12,7 @@
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
......@@ -222,9 +223,19 @@ LoweredFunc MakeAPI(Stmt body,
}
}
// Byte_offset field.
f_push(buf->byte_offset,
TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
v_arg->name_hint + ".byte_offset");
int data_bytes = GetVectorBytes(buf->dtype);
int64_t const_offset;
if (arith::GetConst(buf->elem_offset, &const_offset)) {
f_push(make_const(buf->elem_offset.type(), const_offset * data_bytes),
TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
v_arg->name_hint + ".byte_offset");
} else {
f_push(buf->elem_offset,
cast(buf->elem_offset.type(),
(TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))),
v_arg->name_hint + ".elem_offset");
}
// device info.
f_push(device_id,
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId),
......
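The conversion in MakeAPI above scales by the element size in each direction; a toy sketch of the relationship (illustrative only):

def byte_offset(elem_offset, dtype_bits, lanes=1):
    # byte_offset stored in the DLTensor = elem_offset * bytes per (vector) element
    return elem_offset * (dtype_bits * lanes // 8)

def elem_offset_from_bytes(byte_off, dtype_bits, lanes=1):
    # when the offset is not a known constant, MakeAPI instead binds elem_offset
    # by dividing the DLTensor's byte_offset back down
    return byte_off // (dtype_bits * lanes // 8)

assert byte_offset(8, 32) == 32
assert elem_offset_from_bytes(32, 32) == 8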
......@@ -7,8 +7,8 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/operation.h>
#include <unordered_map>
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -31,9 +31,12 @@ class StorageFlattener : public IRMutator {
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) {
return Store::make(it->second, op->value, op->index, op->predicate);
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
!it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<Variable>());
VarExpr buf_var(it->second.node_);
return Store::make(buf_var, op->value, op->index, op->predicate);
} else {
return stmt;
}
......@@ -50,8 +53,8 @@ class StorageFlattener : public IRMutator {
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
return stmt;
} else if (op->attr_key == attr::extern_op_scope) {
return HandleExternOp(op);
} else if (op->attr_key == attr::buffer_bind_scope) {
return HandleBufferBindScope(op);
}
return IRMutator::Mutate_(op, s);
}
......@@ -115,17 +118,20 @@ class StorageFlattener : public IRMutator {
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = extern_buf_remap_.find(op->buffer_var.get());
if (it != extern_buf_remap_.end()) {
return Load::make(op->type, it->second, op->index, op->predicate);
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
!it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<Variable>());
VarExpr buf_var(it->second.node_);
return Load::make(op->type, buf_var, op->index, op->predicate);
} else {
return expr;
}
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = extern_buf_remap_.find(op);
if (it != extern_buf_remap_.end()) {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
return e;
......@@ -150,35 +156,115 @@ class StorageFlattener : public IRMutator {
}
private:
Stmt HandleExternOp(const AttrStmt* op) {
const ExternOpNode* ext_op = op->node.as<ExternOpNode>();
CHECK(ext_op);
Operation func(op->node.node_);
CHECK_EQ(extern_buf_remap_.size(), 0U);
for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
TensorKey key{func, static_cast<int>(i)};
CHECK(buf_map_.count(key))
<< "Cannot find allocated buffer for " << key.f
<< "(" << key.value_index << ")";
extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
buf_map_.at(key).buffer->data;
// Bind the symbol sym to value if it is a Variable;
// otherwise emit an assert that the constant constraint holds.
// hint_name: used for the error message
// add_keys: a list of newly bound keys
// add_asserts: a list of asserts generated during the bind
void BindSymbol(Expr sym,
Expr value,
std::string hint_name,
std::vector<const Variable*>* add_keys,
std::vector<Stmt>* add_asserts) {
if (const Variable* v = sym.as<Variable>()) {
auto it = var_remap_.find(v);
if (it == var_remap_.end()) {
add_keys->push_back(v);
var_remap_[v] = value;
return;
}
}
// add assertions
std::ostringstream os;
os << "BufferBind constaint fail " << hint_name;
add_asserts->emplace_back(
AssertStmt::make(sym == value, os.str()));
}
// Start bind
Stmt HandleBufferBindScope(const AttrStmt* op) {
Array<NodeRef> arr(op->node.node_);
CHECK_EQ(arr.size(), 2U);
const BufferNode* buffer = arr[0].as<BufferNode>();
const TensorNode* tensor = arr[1].as<TensorNode>();
const Call* tuple = op->value.as<Call>();
CHECK(buffer && tensor);
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index};
CHECK(buf_map_.count(key));
const BufferEntry& be = buf_map_.at(key);
CHECK(!be.released);
CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
Array<Expr> begins, extents;
if (be.bounds.size() != 0) {
CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
begins.push_back(
arith::ComputeExpr<Sub>(tuple->args[2 * i], be.bounds[i]->min));
extents.push_back(tuple->args[2 * i + 1]);
}
} else {
for (size_t i = 0; i < tuple->args.size(); i += 2) {
begins.push_back(tuple->args[i]);
extents.push_back(tuple->args[i + 1]);
}
}
Buffer slice = be.buffer.MakeSlice(begins, extents);
if (buffer->strides.size() == 0) {
CHECK_EQ(slice->strides.size(), 0U)
<< "Trying to bind compact buffer to strided one";
} else {
slice = slice.MakeStrideView();
}
CHECK_EQ(slice->strides.size(), buffer->strides.size());
// start binding
std::vector<const Variable*> keys;
std::vector<Stmt> asserts;
BindSymbol(buffer->data, slice->data,
buffer->name + ".data",
&keys, &asserts);
for (size_t i = 0; i < buffer->shape.size(); ++i) {
std::ostringstream field_name;
field_name << buffer->name << ".shape[" << i << ']';
BindSymbol(buffer->shape[i], slice->shape[i],
field_name.str(),
&keys, &asserts);
}
for (size_t i = 0; i < buffer->strides.size(); ++i) {
std::ostringstream field_name;
field_name << buffer->name << ".strides[" << i << ']';
BindSymbol(buffer->strides[i], slice->strides[i],
field_name.str(),
&keys, &asserts);
}
BindSymbol(buffer->elem_offset, slice->elem_offset,
buffer->name + ".elem_offset",
&keys, &asserts);
CHECK_EQ(buffer->scope, slice->scope)
<< "Buffer bind scope mismatch";
// Apply the remaps
Stmt body = this->Mutate(op->body);
for (size_t i = 0; i < asserts.size(); ++i) {
Stmt ret = Simplify(this->Mutate(asserts[i]));
if (const AssertStmt* assert_op = ret.as<AssertStmt>()) {
if (!is_zero(assert_op->condition)) {
body = Block::make(ret, body);
} else {
LOG(FATAL) << "BindBuffer have unmet assertion: " << ret;
}
}
}
for (size_t i = 0; i < ext_op->inputs.size(); ++i) {
TensorKey key{ext_op->inputs[i]->op, ext_op->inputs[i]->value_index};
CHECK(buf_map_.count(key));
extern_buf_remap_[ext_op->input_placeholders[i]->data.get()] =
buf_map_.at(key).buffer->data;
// remove the binds
for (const Variable* op : keys) {
var_remap_.erase(op);
}
Stmt ret = Mutate(op->body);
extern_buf_remap_.clear();
return ret;
return body;
}
// The buffer entry in the flatten map
struct BufferEntry {
// the buffer of storage
Buffer buffer;
// the bounds of realization, can be null
// the bounds of realization; can be null, meaning the whole extent
Region bounds;
// Whether the buffer is external
bool external{false};
......@@ -200,7 +286,9 @@ class StorageFlattener : public IRMutator {
}
};
// The buffer assignment map
std::unordered_map<const Variable*, Var> extern_buf_remap_;
// Variable remap
std::unordered_map<const Variable*, Expr> var_remap_;
// Buffer map
std::unordered_map<TensorKey, BufferEntry> buf_map_;
std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope.
......
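The remap bookkeeping in HandleBufferBindScope follows a bind-then-restore pattern; a generic sketch of that pattern (not the actual pass, which additionally turns non-Variable or already-bound symbols into assertions):

def with_bindings(var_remap, bindings, mutate_body):
    # Record only the keys this scope introduces so nested scopes stay intact,
    # run the body mutation under the extended map, then restore the map.
    added = []
    for sym, value in bindings:
        if sym not in var_remap:
            var_remap[sym] = value
            added.append(sym)
    try:
        return mutate_body()
    finally:
        for sym in added:
            del var_remap[sym]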
......@@ -14,8 +14,13 @@ def test_llvm_add_pipeline():
def check_llvm():
if not tvm.module.enabled("llvm"):
return
# Explicitly use an offset to exercise the code path where an offset is available
Ab = tvm.decl_buffer(
A.shape, A.dtype, elem_offset=tvm.var('Aoffset'),
name='A')
binds = {A : Ab}
# build and invoke the kernel.
f = tvm.build(s, [A, B, C], "llvm")
f = tvm.build(s, [Ab, B, C], "llvm", binds=binds)
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
......@@ -25,6 +30,7 @@ def test_llvm_add_pipeline():
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_llvm()
......
......@@ -168,7 +168,14 @@ def test_tuple_with_different_deps():
assert stmt.node == C.op and len(ret) == 1
def test_tensor_inputs():
x = tvm.placeholder((1,), name='x')
y = tvm.compute(x.shape, lambda i: x[i] + x[i])
assert tuple(y.op.input_tensors) == (x,)
if __name__ == "__main__":
test_tensor_inputs()
test_tensor_reduce_multi_axis()
test_conv1d()
test_tensor_slice()
......
......@@ -72,6 +72,17 @@ def test_auto_inline():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule_const_bound():
n = 128
A = tvm.placeholder((n,), name='A')
A1 = tvm.compute((n,), lambda i: A[i] + 1, name='A1')
s = tvm.create_schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_inline_mixed():
n = tvm.var('n')
A = tvm.placeholder((n, ), name='A')
......@@ -150,6 +161,7 @@ def test_schedule_cache():
if __name__ == "__main__":
test_schedule_const_bound()
test_scan_inline1()
test_scan_inline2()
test_inline_mixed()
......