Commit 7e3d9da4 by Tianqi Chen, committed by GitHub

[STORAGE][BUFFER] Support access ptr for clear access pattern. (#266)

* [STORAGE][BUFFER] Support access ptr for clear access pattern.

* fix lint
parent 1f7712ae
......@@ -55,7 +55,7 @@ tvm.ir_pass
tvm.ir_pass.SplitPipeline
tvm.ir_pass.LowerThreadAllreduce
tvm.ir_pass.LowerIntrin
tvm.ir_pass.LowerPackedCall
tvm.ir_pass.LowerTVMBuiltin
tvm.ir_pass.NarrowChannelAccess
.. automodule:: tvm.ir_pass
......
......@@ -17,6 +17,12 @@ namespace tvm {
// Internal node container Buffer
class BufferNode;
/*! \brief memory access kind */
enum class AccessMask : int {
kRead = 1,
kWrite = 2
};
/*!
* \brief Buffer is a symbolic n-darray structure.
* It is a composition of primitive symbolic types,
......@@ -55,6 +61,12 @@ class Buffer : public NodeRef {
*/
Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
/*!
* \brief Get an access pointer to the entire buffer.
* \param access_mask The access mask.
* \param ptr_type The type of the pointer.
* \return The resulting access pointer expression.
*/
Expr access_ptr(int access_mask, Type ptr_type = Handle()) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
......
......@@ -208,6 +208,22 @@ namespace intrinsic {
*/
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
* \brief Get the head access address with memory access pattern info.
*
* This operator also marks the range of the memory access.
* The offset and extent are in units of the DType (including the vectorization factor).
* rw_mask is a bit mask indicating whether the access is a read (1) or a write (2).
* The access is assumed to happen in the current expression.
*
* PtrType tvm_access_ptr(Expr dtype, DType* data,
* int offset, int extent,
* int rw_mask) {
* // DType == dtype.type();
* return &data[offset];
* }
*/
constexpr const char* tvm_access_ptr = "tvm_access_ptr";
/*!
* \brief tvm_tuple is not an actual function and cannot codegen.
* It is used to represent tuple structure in value field of AttrStmt,
* for the sake of giving hint to optimization.
......
......@@ -334,7 +334,7 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
* \param f The function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerPackedCall(LoweredFunc f);
LoweredFunc LowerTVMBuiltin(LoweredFunc f);
/*!
* \brief Combine context function calls.
......
......@@ -6,6 +6,7 @@
#ifndef TVM_TARGET_INFO_H_
#define TVM_TARGET_INFO_H_
#include <string>
#include "./base.h"
#include "./expr.h"
......@@ -36,5 +37,12 @@ struct MemoryInfoNode : public Node {
/*! \brief Defines memory info */
TVM_DEFINE_NODE_REF(MemoryInfo, MemoryInfoNode);
/*!
* \brief Get the memory info of a given scope.
* \param scope The scope name.
* \return The memory info, or an undefined MemoryInfo if the scope is not registered.
*/
MemoryInfo GetMemoryInfo(const std::string& scope);
} // namespace tvm
#endif // TVM_TARGET_INFO_H_
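GetMemoryInfo resolves a special memory scope by looking up a packed function named "tvm.info.mem." + scope in the global registry (see the target_info.cc hunk further below). A minimal Python sketch of how such an entry could be registered for a tagged scope, for illustration only: tvm.register_func is the standard registry API, while tvm.make.node is assumed to be the available frontend constructor for MemoryInfo nodes, and the field values are placeholders.

    import tvm

    # Hedged sketch: let GetMemoryInfo("global:tag") succeed by registering
    # a packed function under the "tvm.info.mem." prefix.
    @tvm.register_func("tvm.info.mem.global:tag")
    def mem_info_global_tag():
        # Field names follow the MemoryInfoNode usages in this commit
        # (unit_bits, max_simd_bits, max_num_bits); depending on the TVM
        # version additional fields (e.g. head_address) may be required.
        return tvm.make.node("MemoryInfo",
                             unit_bits=8,
                             max_simd_bits=128,
                             max_num_bits=8 * 1024)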
......@@ -321,7 +321,7 @@ def build(sch,
device = "cpu" if target.startswith("llvm") or target == "stackvm" else target
device_type = ndarray.context(device, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
if fdevice:
......
"""The computation schedule api of TVM."""
from __future__ import absolute_import as _abs
from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.function import _init_api
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container
from ._ffi.function import _init_api
@register_node
class Buffer(NodeBase):
......@@ -21,7 +22,49 @@ class Buffer(NodeBase):
--------
decl_buffer : Declare a buffer
"""
pass
READ = 1
WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle"):
"""Get an access pointer to the head of buffer
This is the recommended method to get buffer data
ptress when interacting with external functions.
Parameters
----------
access_mask : int
The access pattern mask. Indicates whether the
access will read or write to the data content.
ptr_type : str, optional
The data type of the result pointer. Do not specify
unless you want to cast the pointer to a specific type.
Examples
--------
.. code-block:: python
from tvm.schedule import Buffer
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
"""
if isinstance(access_mask, string_types):
mask = 0
for value in access_mask:
if value == "r":
mask = mask | Buffer.READ
elif value == "w":
mask = mask | Buffer.WRITE
else:
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type)
@register_node
......
......@@ -156,6 +156,12 @@ TVM_REGISTER_API("_Buffer")
args[8]);
});
TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.access_ptr(args[1], args[2]);
});
TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
......
......@@ -102,7 +102,7 @@ REGISTER_PASS2(SplitPipeline);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerPackedCall);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
} // namespace ir
} // namespace tvm
......@@ -22,6 +22,7 @@ using Halide::Internal::mul_would_overflow;
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
* \param rhs The right operand
* \tparam Op the computation operator
* \return The result.
*/
template<typename OP>
......@@ -29,6 +30,15 @@ inline Expr ComputeExpr(Expr lhs, Expr rhs) {
return OP::make(lhs, rhs);
}
/*!
* \brief Compute a reduction over values with Op.
* \param values The input values.
* \tparam Op The computation operator.
* \return The result.
*/
template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values);
template<typename T>
inline bool GetConst(Expr e, T* out);
......@@ -128,6 +138,16 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b);
}
template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values) {
CHECK_NE(values.size(), 0U);
Expr res = values[0];
for (size_t i = 1; i < values.size(); ++i) {
res = ComputeExpr<Op>(res, values[i]);
}
return res;
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_
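ComputeReduce<Op> folds the binary op over the array from left to right; Buffer::access_ptr in the next hunk uses ComputeReduce<ir::Mul> over the shape to compute the flat extent of a dense buffer. A rough Python analogue of that fold, for illustration only:

    import functools
    import tvm

    def compute_reduce_mul(values):
        # Python analogue of arith::ComputeReduce<ir::Mul>:
        # fold the values with the binary op, left to right.
        assert len(values) > 0
        return functools.reduce(lambda a, b: a * b, values)

    m, n = tvm.var("m"), tvm.var("n")
    extent = compute_reduce_mul([m, n])  # m*n, the flat extent of an (m, n) buffer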
......@@ -6,6 +6,7 @@
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "../arithmetic/compute_expr.h"
namespace tvm {
......@@ -131,6 +132,19 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0);
}
Expr Buffer::access_ptr(int access_mask, Type ptr_type) const {
const BufferNode* self = operator->();
Expr e_dtype = make_zero(self->dtype);
Expr extent = (self->strides.size() == self->shape.size() ?
arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]):
arith::ComputeReduce<ir::Mul>(self->shape));
Array<Expr> acc_args{
e_dtype, self->data, self->elem_offset,
extent, make_const(Int(32), access_mask)};
return ir::Call::make(
ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
}
Buffer BufferNode::make(Var data,
Type dtype,
Array<Expr> shape,
......
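To illustrate the two extent cases handled by Buffer::access_ptr above (dense shape versus explicit strides), here is a small frontend sketch that mirrors the unit test added later in this commit; the structural-equality assertions assume the extent expression is built exactly as in the C++ code:

    import tvm

    m, n = tvm.var("m"), tvm.var("n")

    # Dense buffer: the extent is the product of the shape.
    A = tvm.decl_buffer((m, n), "float32")
    assert tvm.ir_pass.Equal(A.access_ptr("r").args[3], m * n)

    # Strided buffer: the extent is strides[0] * shape[0].
    B = tvm.decl_buffer((m, n), "float32", strides=[n + 1, 1])
    assert tvm.ir_pass.Equal(B.access_ptr("r").args[3], B.strides[0] * m)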
......@@ -3,6 +3,7 @@
* \file target_info.cc
*/
#include <tvm/target_info.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
......@@ -16,4 +17,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(MemoryInfoNode);
MemoryInfo GetMemoryInfo(const std::string& scope) {
std::string fname = "tvm.info.mem." + scope;
const runtime::PackedFunc* f = runtime::Registry::Get(fname);
if (f == nullptr) {
return MemoryInfo();
} else {
return (*f)();
}
}
} // namespace tvm
......@@ -299,10 +299,18 @@ void VerifyTensorizeBody(
for (size_t i = 0; i < body.size(); ++i) {
Expr lhs = CanonicalSimplify(body[i]);
Expr rhs = CanonicalSimplify(intrin_compute->body[i]);
if (lhs.type() != rhs.type()) {
LOG(FATAL)
<< "Failed to match the data type with TensorIntrin "
<< intrin->name << "'s declaration "
<< " provided=" << lhs.type()
<< ", intrin=" << rhs.type();
}
CHECK(Equal(lhs, rhs))
<< "Failed to match the compute with TensorIntrin declaration "
<< " provided:" << lhs
<< ", intrin:" << rhs;
<< "Failed to match the compute with TensorIntrin "
<< intrin->name << "'s declaration "
<< " provided= " << lhs
<< ", intrin= " << rhs;
}
}
......
......@@ -89,29 +89,21 @@ void ArgBinder::BindBuffer(const Buffer& arg,
<< ", provided_alignment=" << value->data_alignment;
}
// bind pointer and offset.
if (is_zero(arg->elem_offset) && !is_zero(value->elem_offset)) {
// If target buffer only accepts pointer, try to fold offset into pointer.
Expr addr = AddressOffset(value->data, value->dtype, value->elem_offset);
if (Bind_(arg->data, addr, arg_name + ".data", true)) {
int offset_factor = arg->data_alignment * 8 / (arg->dtype.bits() * arg->dtype.lanes());
if (offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
} else {
this->Bind(arg->data, value->data, arg_name + ".data");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
if (is_zero(arg->elem_offset)) {
CHECK(is_zero(value->elem_offset))
<< "Trying to bind a Buffer with offset into one without offset";
}
this->Bind(arg->data, value->data, arg_name + ".data");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
if (arg->shape.size() < value->shape.size()) {
CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
size_t diff = value->shape.size() - arg->shape.size();
......
......@@ -100,9 +100,12 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) {
* \param offset the offset index.
*/
inline Expr AddressOffset(Var handle, Type dtype, Expr offset) {
if (dtype.lanes() != 1) {
offset = offset * make_const(offset.type(), dtype.lanes());
}
return Call::make(
Handle(), intrinsic::tvm_address_of,
{Load::make(dtype, handle, offset * make_const(offset.type(), dtype.lanes()),
{Load::make(dtype, handle, offset,
const_true(dtype.lanes()))},
Call::PureIntrinsic);
}
......
/*!
* Copyright (c) 2017 by Contributors
* Lower calls to packed function.
* \file lower_packed_call.cc
* Lower TVM-related builtin intrinsics such as packed calls.
* \file lower_tvm_buildin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/target_info.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
inline Expr ConstInt32(size_t index) {
CHECK_LE(index, std::numeric_limits<int>::max());
return make_const(Int(32), static_cast<int>(index));
......@@ -25,7 +28,7 @@ inline Expr StackAlloca(std::string type, size_t num) {
// Calculate the statistics of the packed function calls.
// This information is needed during codegen.
class PackedCallBuilder : public IRMutator {
class BuiltinLower : public IRMutator {
public:
Stmt Build(Stmt stmt) {
stack_shape_ = Var("stack_shape", Handle());
......@@ -49,6 +52,7 @@ class PackedCallBuilder : public IRMutator {
}
return stmt;
}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
CHECK_EQ(run_shape_stack_, 0);
......@@ -65,6 +69,14 @@ class PackedCallBuilder : public IRMutator {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
// For special memory, remove allocate.
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
return op->body;
}
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->type);
......@@ -127,12 +139,25 @@ class PackedCallBuilder : public IRMutator {
CHECK(!device_type_.defined());
device_type_ = op->value;
return Mutate(op->body);
} else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e);
......@@ -142,6 +167,7 @@ class PackedCallBuilder : public IRMutator {
return IRMutator::Mutate_(op, e);
}
}
Expr Convert(Type t, Expr e) {
if (e.type() != t) {
return Cast::make(t, e);
......@@ -254,6 +280,33 @@ class PackedCallBuilder : public IRMutator {
Int(32), intrinsic::tvm_call_packed_lowered,
packed_args, Call::Intrinsic);
}
// tvm_access_ptr
Expr MakeAccessPtr(const Call* op, const Expr& e) {
// Specially handle the buffer access pointer intrinsic
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
Expr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
return MakeTaggedAccessPtr(
op->type, dtype, offset,
it->second.info.defined() ? it->second.info->unit_bits : 8);
}
CHECK(op->type.is_handle());
// Change to address_of
return AddressOffset(Var(op->args[1].node_), dtype, offset);
}
Expr MakeTaggedAccessPtr(Type ptr_type, Type dtype,
Expr offset, int unit_bits) {
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(unit_bits % dtype_bits, 0);
return Convert(ptr_type,
ir::Simplify(offset / make_const(offset.type(), unit_bits / dtype_bits)));
}
private:
bool IsArrayHandle(const Expr& arg) {
......@@ -284,11 +337,22 @@ class PackedCallBuilder : public IRMutator {
uint64_t max_shape_stack_{0};
uint64_t max_array_stack_{0};
uint64_t max_arg_stack_{0};
// The storage entry.
struct StorageEntry {
// Whether it is tagged memory.
StorageScope scope;
// The memory info if any.
MemoryInfo info;
// Allocation counter
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageEntry> storage_info_;
};
LoweredFunc LowerPackedCall(LoweredFunc f) {
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = PackedCallBuilder().Build(n->body);
n->body = BuiltinLower().Build(n->body);
return LoweredFunc(n);
}
......
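For tagged memory, MakeTaggedAccessPtr above does not produce a raw address; it rescales the element offset into whole units of the special memory (unit_bits per addressable unit, taken from the registered MemoryInfo). A hedged Python sketch of just that index arithmetic:

    def make_tagged_access_ptr(offset, dtype_bits, dtype_lanes, unit_bits):
        # Sketch of BuiltinLower::MakeTaggedAccessPtr: the "pointer" into
        # tagged memory is the element offset divided by elements per unit.
        total_bits = dtype_bits * dtype_lanes
        assert unit_bits % total_bits == 0
        return offset // (unit_bits // total_bits)

    # Scalar float32 (32 bits) in a memory addressed in 128-bit units:
    # element offset 32 maps to unit index 8 (4 elements per unit).
    assert make_tagged_access_ptr(32, 32, 1, 128) == 8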
......@@ -26,6 +26,7 @@ enum AccessType {
kSync,
kAlloc
};
/*! \brief The access entry */
struct AccessEntry {
/*! \brief The buffer variable, if any */
......@@ -44,6 +45,7 @@ struct AccessEntry {
StorageScope scope)
: buffer(buffer), index(index), type(type), scope(scope) {}
};
/*! \brief The access info about a statement */
struct StmtEntry {
/*! \brief The statement */
......@@ -51,6 +53,7 @@ struct StmtEntry {
/*! \brief access patterns in the statement */
std::vector<AccessEntry> access;
};
} // namespace storage
} // namespace ir
} // namespace tvm
......
......@@ -8,11 +8,13 @@
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/target_info.h>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
......@@ -187,11 +189,13 @@ class StoragePlanRewriter : public IRMutator {
std::vector<Stmt> nest;
for (StorageEntry* e : attach_map_.at(nullptr)) {
CHECK_EQ(e->scope.rank, 0);
nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope,
StringImm::make(e->scope.to_string()),
Evaluate::make(0)));
nest.push_back(e->new_alloc);
if (e->new_alloc.defined()) {
nest.emplace_back(AttrStmt::make(
e->alloc_var, attr::storage_scope,
StringImm::make(e->scope.to_string()),
Evaluate::make(0)));
nest.push_back(e->new_alloc);
}
}
stmt = MergeNest(nest, stmt);
}
......@@ -202,23 +206,55 @@ class StoragePlanRewriter : public IRMutator {
op = stmt.as<Store>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt;
return Store::make(it->second->alloc_var, op->value, op->index, op->predicate);
return Store::make(it->second->alloc_var,
op->value,
RemapIndex(op->value.type(), op->index, it->second),
op->predicate);
}
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
return Load::make(op->type, it->second->alloc_var, op->index, op->predicate);
return Load::make(op->type,
it->second->alloc_var,
RemapIndex(op->type, op->index, it->second),
op->predicate);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) {
if (it->second->elem_offset != 0) {
LOG(WARNING) << "Use a merged buffer variable address, could cause error";
}
return it->second->alloc_var;
} else {
return e;
}
}
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
auto it = alloc_map_.find(buffer);
if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e);
const StorageEntry* e = it->second;
Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]);
CHECK_EQ(e->elem_type, dtype.element_of());
CHECK_EQ(e->elem_offset % dtype.lanes(), 0);
if (e->elem_offset != 0) {
offset = make_const(offset.type(), e->elem_offset / dtype.lanes()) + offset;
}
return Call::make(
op->type, op->name,
{op->args[0], e->alloc_var, offset, extent, op->args[4]},
op->call_type);
} else {
return IRMutator::Mutate_(op, e);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
CHECK(op->attr_key != attr::virtual_thread)
<< "InjectVirtualThread before StoragePlan";
......@@ -270,63 +306,132 @@ class StoragePlanRewriter : public IRMutator {
// For shared/local memory it is beginning of the thread extent.
// for global memory it is nullptr, means beginning of everything.
const Node* attach_scope_{nullptr};
// The constant size of the buffer in bytes, only used if it is constant.
size_t const_size{0};
// The constant size of the buffer in bits, only used if it is constant
size_t const_nbits{0};
// The storage scope.
StorageScope scope;
// Allocs that shares this entry.
std::vector<const Allocate*> allocs;
// The children of this entry, not including itself.
std::vector<StorageEntry*> merged_children;
// The replacement allocation, if any.
Stmt new_alloc;
// The var expr of new allocation.
VarExpr alloc_var;
// The replacement allocation
Stmt new_alloc;
// The allocation element type.
Type elem_type;
// This is non-zero if this allocate is folded into another one
// the address becomes alloc_var + sizeof(elem_type) * elem_offset;
size_t elem_offset{0};
};
// Remap the index
Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
CHECK_EQ(dtype.element_of(), e->elem_type);
if (e->elem_offset == 0) return index;
return make_const(index.type(), e->elem_offset) + index;
}
// Prepare the new allocations
void PrepareNewAlloc() {
for (size_t i = 0; i < alloc_vec_.size(); ++i) {
StorageEntry* e = alloc_vec_[i].get();
attach_map_[e->attach_scope_].push_back(e);
}
// find allocation via attach map.
for (auto &kv : attach_map_) {
// find the element with the most amount of bytes.
Type t = e->allocs[0]->type;
for (const Allocate* op : e->allocs) {
if (op->type.bytes() * op->type.lanes() > t.bytes() * t.lanes()) {
t = op->type;
std::vector<StorageEntry*>& vec = kv.second;
// try to find merge candidates for tagged memory
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry* e = vec[i];
if (e->scope.tag.length() != 0) {
CHECK_NE(e->const_nbits, 0U)
<< "Special tagged memory must be const size";
for (size_t j = 0; j < i; ++j) {
if (e->scope == vec[j]->scope) {
vec[j]->merged_children.push_back(e);
break;
}
}
}
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
if (e->allocs.size() == 1) {
// simply use the original allocation.
e->new_alloc = Allocate::make(
e->alloc_var, t, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate::make(0));
} else {
// Build a merged allocation.
int alloc_unit = t.bytes() * t.lanes();
Expr combo_size;
// Start allocation
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry* e = vec[i];
// already merged
if (e->elem_offset != 0) continue;
if (e->merged_children.size() != 0) {
NewAllocTagMerged(e); continue;
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
Type alloc_type = e->allocs[0]->type;
for (const Allocate* op : e->allocs) {
// Get the size
Expr sz = op->extents[0];
for (size_t i = 1; i < op->extents.size(); ++i) {
sz = sz * op->extents[i];
}
int bytes = op->type.bytes() * op->type.lanes();
if (alloc_unit != bytes) {
sz = (sz * make_const(sz.type(), bytes) +
make_const(sz.type(), alloc_unit - 1)) /
make_const(sz.type(), alloc_unit);
if (op->type.lanes() > alloc_type.lanes()) {
alloc_type = op->type;
}
if (combo_size.defined()) {
combo_size = max(combo_size, sz);
} else {
combo_size = sz;
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate::make(0));
} else {
// Build a merged allocation
Expr combo_size;
for (const Allocate* op : e->allocs) {
Expr sz = arith::ComputeReduce<Mul>(op->extents);
if (alloc_type.lanes() != op->type.lanes()) {
sz = (sz * make_const(sz.type(), op->type.lanes()) +
make_const(sz.type(), alloc_type.lanes() - 1)) /
make_const(sz.type(), alloc_type.lanes());
}
if (combo_size.defined()) {
combo_size = max(combo_size, sz);
} else {
combo_size = sz;
}
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate::make(0));
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
e->alloc_var, t, {combo_size}, const_true(),
Evaluate::make(0));
}
attach_map_[e->attach_scope_].push_back(e);
}
}
// New allocation for merged data
void NewAllocTagMerged(StorageEntry* e) {
CHECK_NE(e->scope.tag.length(), 0U);
// allocate with element type.
CHECK_NE(e->const_nbits, 0U);
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
size_t align = 1;
if (info.defined()) {
align = (info->max_simd_bits + e->elem_type.bits() - 1) / e->elem_type.bits();
}
size_t total_elem = e->const_nbits / e->elem_type.bits();
if (total_elem % align != 0) {
total_elem += align - (total_elem % align);
}
e->alloc_var = e->allocs[0]->buffer_var;
for (StorageEntry* child : e->merged_children) {
CHECK_NE(e->const_nbits, 0U);
CHECK_NE(total_elem, 0U);
size_t num_elem = child->const_nbits / child->elem_type.bits();
child->elem_offset = total_elem;
child->alloc_var = e->alloc_var;
total_elem += num_elem;
if (total_elem % align != 0) {
total_elem += align - (total_elem % align);
}
}
Expr alloc_size = make_const(e->allocs[0]->extents[0].type(), total_elem);
e->new_alloc = Allocate::make(
e->alloc_var, e->elem_type, {alloc_size}, const_true(),
Evaluate::make(0));
if (info.defined()) {
CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
}
// Find the free location of each variable.
......@@ -382,12 +487,13 @@ class StoragePlanRewriter : public IRMutator {
// Allocate new storage entry.
StorageEntry* NewAlloc(const Allocate* op,
const StorageScope& scope,
size_t const_size) {
size_t const_nbits) {
// Re-use not successful, allocate a new buffer.
std::unique_ptr<StorageEntry> entry(new StorageEntry());
entry->attach_scope_ = thread_scope_;
entry->scope = scope;
entry->const_size = const_size;
entry->elem_type = op->type.element_of();
entry->const_nbits = const_nbits;
StorageEntry* e = entry.get();
alloc_vec_.emplace_back(std::move(entry));
return e;
......@@ -397,31 +503,35 @@ class StoragePlanRewriter : public IRMutator {
// skip plan for local variable,
// compiler can do a better job with register allocation.
const size_t match_range = 16;
size_t const_size = static_cast<size_t>(
op->constant_allocation_size()) * op->type.bytes() * op->type.lanes();
size_t const_nbits = static_cast<size_t>(
op->constant_allocation_size() * op->type.bits() * op->type.lanes());
if (scope.rank > 1 || op->type.is_handle()) {
return NewAlloc(op, scope, const_size);
return NewAlloc(op, scope, const_nbits);
}
// disable reuse of small arrays
if (const_size > 0 && const_size <= 32) {
return NewAlloc(op, scope, const_size);
// disable reuse of small arrays; they will be lowered to registers in LLVM
if (const_nbits > 0 &&
const_nbits <= 32 &&
scope.tag.length() == 0) {
return NewAlloc(op, scope, const_nbits);
}
if (const_size != 0) {
if (const_nbits != 0) {
// constant allocation.
auto begin = const_free_map_.lower_bound(const_size / match_range);
auto mid = const_free_map_.lower_bound(const_size);
auto end = const_free_map_.upper_bound(const_size * match_range);
auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits);
auto end = const_free_map_.upper_bound(const_nbits * match_range);
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
if (it->second->scope != scope) continue;
e->const_size = std::max(const_size, e->const_size);
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
}
for (auto it = mid; it != begin;) {
--it;
StorageEntry *e = it->second;
if (it->second->scope != scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
const_free_map_.erase(it);
return e;
}
......@@ -431,11 +541,12 @@ class StoragePlanRewriter : public IRMutator {
it != sym_free_list_.end(); ++it) {
StorageEntry* e = *it;
if (e->scope != scope) continue;
if (e->elem_type != op->type.element_of()) continue;
sym_free_list_.erase(it);
return e;
}
}
return NewAlloc(op, scope, const_size);
return NewAlloc(op, scope, const_nbits);
}
// simulated free.
void Free(const Variable* var) {
......@@ -445,10 +556,10 @@ class StoragePlanRewriter : public IRMutator {
// Disable sharing of local memory.
if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
// disable reuse of small arrays
if (e->const_size > 0 && e->const_size <= 32) return;
if (e->const_nbits > 0 && e->const_nbits <= 32) return;
// normal free.
if (e->const_size != 0) {
const_free_map_.insert({e->const_size, e});
if (e->const_nbits != 0) {
const_free_map_.insert({e->const_nbits, e});
} else {
sym_free_list_.push_back(e);
}
......
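NewAllocTagMerged above packs the merged children behind the parent allocation, rounding each running total up to the SIMD alignment of the tagged memory. A Python sketch of that element-offset bookkeeping, assuming MemoryInfo is defined so the alignment rule applies:

    def plan_tag_merged(parent_nbits, children_nbits, elem_bits, max_simd_bits):
        # Sketch of StoragePlanRewriter::NewAllocTagMerged: compute the element
        # offset of each merged child and the extent of the combined buffer.
        align = (max_simd_bits + elem_bits - 1) // elem_bits
        total_elem = parent_nbits // elem_bits
        if total_elem % align:
            total_elem += align - total_elem % align
        offsets = []
        for nbits in children_nbits:
            offsets.append(total_elem)  # child starts here, in elements
            total_elem += nbits // elem_bits
            if total_elem % align:
                total_elem += align - total_elem % align
        return offsets, total_elem

    # A 24-element float32 parent plus two 8-element children,
    # aligned to 128-bit (4-element) boundaries:
    offsets, total = plan_tag_merged(24 * 32, [8 * 32, 8 * 32], 32, 128)
    assert offsets == [24, 32] and total == 40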
......@@ -17,7 +17,7 @@ namespace runtime {
struct StorageScope {
/*! \brief The rank of the storage */
int rank{0};
/*! \brief tag for special memory, if any */
/*! \brief tag for special purpose memory. */
std::string tag;
// comparator
inline bool operator==(const StorageScope& other) const {
......
......@@ -17,7 +17,7 @@ def lower(s, args, name="mydot"):
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
return fapi
......
......@@ -32,7 +32,7 @@ def test_add_pipeline():
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
def check_target(device, host="stackvm"):
if not tvm.module.enabled(host):
......
......@@ -20,7 +20,7 @@ def test_stack_vm_basic():
Ab = tvm.decl_buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
run_jit(fapi, lambda f: f(a))
......@@ -42,7 +42,7 @@ def test_stack_vm_loop():
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
......@@ -65,7 +65,7 @@ def test_stack_vm_cond():
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
import tvm
from tvm.schedule import Buffer
def test_buffer():
m = tvm.var('m')
......@@ -11,6 +12,17 @@ def test_buffer():
assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n)
def test_buffer_access_ptr():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1, 1])
aptr = Ab.access_ptr("rw")
assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m)
assert aptr.args[0].dtype == Ab.dtype
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
aptr = Ab.access_ptr("w")
assert aptr.args[4].value == Buffer.WRITE
if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
......@@ -39,7 +39,7 @@ def test_dso_module_load():
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
m = tvm.codegen.build_module(fapi, "llvm")
for name in names:
m.save(name)
......
......@@ -30,6 +30,39 @@ def test_storage_share():
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2
def test_storage_combine():
n = 8
A = tvm.placeholder((4,), name='A')
num_stage = 5
B = A
stages = []
for t in range(num_stage):
B = tvm.compute((n, ), lambda i: B[i] + (t+1), name='A%d' % t)
stages.append(B)
s = tvm.create_schedule(B.op)
for S in stages[:-1]:
s[S].set_scope("global:tag")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
assert (n.extents[0].value == 16)
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 1
def test_storage_share_gpu():
m = tvm.var('m')
A = [tvm.placeholder((m), name='A')]
......@@ -67,5 +100,6 @@ def test_storage_share_gpu():
if __name__ == "__main__":
test_storage_combine()
test_storage_share_gpu()
test_storage_share()
......@@ -25,7 +25,7 @@ def test_dltensor_compatible():
A[i + 1] = A[i] + 1
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.codegen.build_module(fapi, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
......
......@@ -17,20 +17,26 @@ def intrin_gemv(m, n):
k = tvm.reduce_axis((0, n), name='k')
z = tvm.compute((m,), lambda i:
tvm.sum(w[i, k] * x[k], axis=k), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype, name="W",
Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W",
offset_factor=16,
strides=[tvm.var('ldw'), 1])
def intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
ww_ptr = ww.access_ptr("r")
xx_ptr = xx.access_ptr("r")
zz_ptr = zz.access_ptr("w")
body = tvm.call_packed(
"gemv", ww.data, xx.data, zz.data, n, ww.strides[0])
"gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
reset = tvm.call_packed(
"fill_zero", outs[0].data, n)
"fill_zero", zz_ptr, n)
update = tvm.call_packed(
"gemv_add", ww.data, xx.data, zz.data, n, ww.strides[0])
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, reset, update
with tvm.build_config(data_alignment=16):
with tvm.build_config(data_alignment=16,
offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
......@@ -93,6 +99,7 @@ def test_tensorize_matmul():
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
def check_rfactor(factor, rfactor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
......
......@@ -39,7 +39,7 @@ def test_add_pipeline():
s[C].bind(px, tvm.thread_axis("pipeline"))
fapi = lower(s, [A, B, C], "myadd")
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
print("------")
def check_target(device, host="stackvm"):
......