Commit 7e3d9da4 by Tianqi Chen Committed by GitHub

[STORAGE][BUFFER] Support access ptr for clear access pattern. (#266)

* [STORAGE][BUFFER] Support access ptr for clear access pattern.

* fix lint
parent 1f7712ae
......@@ -55,7 +55,7 @@ tvm.ir_pass
tvm.ir_pass.SplitPipeline
tvm.ir_pass.LowerThreadAllreduce
tvm.ir_pass.LowerIntrin
tvm.ir_pass.LowerPackedCall
tvm.ir_pass.LowerTVMBuiltin
tvm.ir_pass.NarrowChannelAccess
.. automodule:: tvm.ir_pass
......
......@@ -17,6 +17,12 @@ namespace tvm {
// Internal node container Buffer
class BufferNode;
/*! \brief memory access kind */
enum class AccessMask : int {
  /*! \brief the access reads the buffer content (bit 0) */
  kRead = 1 << 0,
  /*! \brief the access writes the buffer content (bit 1) */
  kWrite = 1 << 1
};
/*!
* \brief Buffer is a symbolic n-darray structure.
* It is a composition of primitive symbolic types,
......@@ -55,6 +61,12 @@ class Buffer : public NodeRef {
*/
Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
/*!
* \brief Get access ptr to the entire buffer.
* \param access_mask The access mask
* \param ptr_type The type of the pointer.
*/
Expr access_ptr(int access_mask, Type ptr_type = Handle()) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
......
......@@ -208,6 +208,22 @@ namespace intrinsic {
*/
constexpr const char* tvm_address_of = "tvm_address_of";
/*!
* \brief Get head access address with memory access pattern info.
*
* This operator also marks range of the memory access
* The offset and extent are in units of the DType (including vectorization factor).
* rw_mask is a bit_mask setting whether the access is a read(1) or write(2).
* The access is assumed to happen in the current expression.
*
* PtrType tvm_access_ptr(Expr dtype, DType* data,
* int offset, int extent,
* int rw_mask) {
* // DType == dtype.type();
* return &data[offset];
* }
*/
constexpr const char* tvm_access_ptr = "tvm_access_ptr";
/*!
* \brief tvm_tuple is not an actual function and cannot codegen.
* It is used to represent tuple structure in value field of AttrStmt,
* for the sake of giving hint to optimization.
......
......@@ -334,7 +334,7 @@ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
* \param f The function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerPackedCall(LoweredFunc f);
LoweredFunc LowerTVMBuiltin(LoweredFunc f);
/*!
* \brief Combine context function calls.
......
......@@ -6,6 +6,7 @@
#ifndef TVM_TARGET_INFO_H_
#define TVM_TARGET_INFO_H_
#include <string>
#include "./base.h"
#include "./expr.h"
......@@ -36,5 +37,12 @@ struct MemoryInfoNode : public Node {
/*! \brief Defines memory info */
TVM_DEFINE_NODE_REF(MemoryInfo, MemoryInfoNode);
/*!
* \brief get memory info given scope
* \param scope The scope name.
* \return info The memory info.
*/
MemoryInfo GetMemoryInfo(const std::string& scope);
} // namespace tvm
#endif // TVM_TARGET_INFO_H_
......@@ -321,7 +321,7 @@ def build(sch,
device = "cpu" if target.startswith("llvm") or target == "stackvm" else target
device_type = ndarray.context(device, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
if fdevice:
......
"""The computation schedule api of TVM."""
from __future__ import absolute_import as _abs
from ._ffi.base import string_types
from ._ffi.node import NodeBase, register_node
from ._ffi.function import _init_api
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container
from ._ffi.function import _init_api
@register_node
class Buffer(NodeBase):
......@@ -21,7 +22,49 @@ class Buffer(NodeBase):
--------
decl_buffer : Declare a buffer
"""
pass
READ = 1
WRITE = 2
def access_ptr(self, access_mask, ptr_type="handle"):
"""Get an access pointer to the head of buffer
This is the recommended method to get buffer data
address when interacting with external functions.
Parameters
----------
access_mask : int
The access pattern MASK. Indicate whether the
access will read or write to the data content.
ptr_type : str, optional
The data type of the result pointer. Do not specify
unless we want to cast pointer to specific type.
Examples
--------
.. code-block:: python
from tvm.schedule import Buffer
# Get access ptr for read
buffer.access_ptr("r")
# Get access ptr for read/write with bitmask
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# Get access ptr for read/write with str flag
buffer.access_ptr("rw")
"""
if isinstance(access_mask, string_types):
mask = 0
for value in access_mask:
if value == "r":
mask = mask | Buffer.READ
elif value == "w":
mask = mask | Buffer.WRITE
else:
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type)
@register_node
......
......@@ -156,6 +156,12 @@ TVM_REGISTER_API("_Buffer")
args[8]);
});
// Frontend entry point for Buffer::access_ptr.
// Packed arguments: args[0] = buffer, args[1] = access mask, args[2] = pointer type.
TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    const Buffer buffer = args[0].operator Buffer();
    *ret = buffer.access_ptr(args[1], args[2]);
  });
TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
......
......@@ -102,7 +102,7 @@ REGISTER_PASS2(SplitPipeline);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerPackedCall);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
} // namespace ir
} // namespace tvm
......@@ -22,6 +22,7 @@ using Halide::Internal::mul_would_overflow;
* \brief Compute the expression with the given binary op.
* \param lhs The left operand
* \param rhs The right operand
* \tparam Op the computation operator
* \return The result.
*/
template<typename OP>
......@@ -29,6 +30,15 @@ inline Expr ComputeExpr(Expr lhs, Expr rhs) {
return OP::make(lhs, rhs);
}
/*!
* \brief Compute a reduction with Op
* \param values The input values.
* \tparam Op The computation operator
* \return The result.
*/
template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values);
template<typename T>
inline bool GetConst(Expr e, T* out);
......@@ -128,6 +138,16 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b);
}
// Left-fold the values with Op: ((v0 Op v1) Op v2) ...
// The input must contain at least one element.
template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values) {
  CHECK_NE(values.size(), 0U);
  const size_t count = values.size();
  Expr folded = values[0];
  size_t pos = 1;
  while (pos < count) {
    folded = ComputeExpr<Op>(folded, values[pos]);
    ++pos;
  }
  return folded;
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_
......@@ -6,6 +6,7 @@
#include <tvm/runtime/device_api.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "../arithmetic/compute_expr.h"
namespace tvm {
......@@ -131,6 +132,19 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0);
}
/*
 * Build a tvm_access_ptr intrinsic call that covers the whole buffer.
 *
 * The call carries, in order: a zero value of the buffer dtype (type tag),
 * the data variable, the element offset, the extent of the access and the
 * access mask as an int32 constant.
 *
 * \param access_mask Bit mask describing the access (read and/or write).
 * \param ptr_type The type of the resulting pointer expression.
 * \return The intrinsic call expression.
 */
Expr Buffer::access_ptr(int access_mask, Type ptr_type) const {
  const BufferNode* self = operator->();
  Expr e_dtype = make_zero(self->dtype);
  Expr extent;
  if (self->shape.size() == 0) {
    // Zero-dimensional buffer: shape and strides are both empty, so there is
    // no leading dimension to index. The access covers a single element.
    extent = make_const(Int(32), 1);
  } else if (self->strides.size() == self->shape.size()) {
    // Explicit strides: leading stride times leading dimension.
    extent = arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]);
  } else {
    // No usable strides: extent is the product of all dimensions.
    extent = arith::ComputeReduce<ir::Mul>(self->shape);
  }
  Array<Expr> acc_args{
    e_dtype, self->data, self->elem_offset,
        extent, make_const(Int(32), access_mask)};
  return ir::Call::make(
      ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
}
Buffer BufferNode::make(Var data,
Type dtype,
Array<Expr> shape,
......
......@@ -3,6 +3,7 @@
* \file target_info.cc
*/
#include <tvm/target_info.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
......@@ -16,4 +17,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(MemoryInfoNode);
// Query the global function "tvm.info.mem.<scope>" for the memory info.
// Returns a default-constructed MemoryInfo when no such function is registered.
MemoryInfo GetMemoryInfo(const std::string& scope) {
  const std::string key = "tvm.info.mem." + scope;
  const runtime::PackedFunc* query = runtime::Registry::Get(key);
  if (query != nullptr) {
    return (*query)();
  }
  return MemoryInfo();
}
} // namespace tvm
......@@ -299,10 +299,18 @@ void VerifyTensorizeBody(
for (size_t i = 0; i < body.size(); ++i) {
Expr lhs = CanonicalSimplify(body[i]);
Expr rhs = CanonicalSimplify(intrin_compute->body[i]);
if (lhs.type() != rhs.type()) {
LOG(FATAL)
<< "Failed to match the data type with TensorIntrin "
<< intrin->name << "'s declaration "
<< " provided=" << lhs.type()
<< ", intrin=" << rhs.type();
}
CHECK(Equal(lhs, rhs))
<< "Failed to match the compute with TensorIntrin declaration "
<< " provided:" << lhs
<< ", intrin:" << rhs;
<< "Failed to match the compute with TensorIntrin "
<< intrin->name << "'s declaration "
<< " provided= " << lhs
<< ", intrin= " << rhs;
}
}
......
......@@ -89,19 +89,11 @@ void ArgBinder::BindBuffer(const Buffer& arg,
<< ", provided_alignment=" << value->data_alignment;
}
// bind pointer and offset.
if (is_zero(arg->elem_offset) && !is_zero(value->elem_offset)) {
// If target buffer only accepts pointer, try to fold offset into pointer.
Expr addr = AddressOffset(value->data, value->dtype, value->elem_offset);
if (Bind_(arg->data, addr, arg_name + ".data", true)) {
int offset_factor = arg->data_alignment * 8 / (arg->dtype.bits() * arg->dtype.lanes());
if (offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
if (is_zero(arg->elem_offset)) {
CHECK(is_zero(value->elem_offset))
<< "Trying to bind a Buffer with offset into one without offset";
}
} else {
this->Bind(arg->data, value->data, arg_name + ".data");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
......@@ -111,7 +103,7 @@ void ArgBinder::BindBuffer(const Buffer& arg,
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
if (arg->shape.size() < value->shape.size()) {
CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
size_t diff = value->shape.size() - arg->shape.size();
......
......@@ -100,9 +100,12 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) {
* \param offset the offset index.
*/
inline Expr AddressOffset(Var handle, Type dtype, Expr offset) {
if (dtype.lanes() != 1) {
offset = offset * make_const(offset.type(), dtype.lanes());
}
return Call::make(
Handle(), intrinsic::tvm_address_of,
{Load::make(dtype, handle, offset * make_const(offset.type(), dtype.lanes()),
{Load::make(dtype, handle, offset,
const_true(dtype.lanes()))},
Call::PureIntrinsic);
}
......
/*!
* Copyright (c) 2017 by Contributors
* Lower calls to packed function.
* \file lower_packed_call.cc
* Lower TVM related builtin intrinsics such as packed call.
* \file lower_tvm_buildin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/target_info.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
inline Expr ConstInt32(size_t index) {
CHECK_LE(index, std::numeric_limits<int>::max());
return make_const(Int(32), static_cast<int>(index));
......@@ -25,7 +28,7 @@ inline Expr StackAlloca(std::string type, size_t num) {
// Calculate the statistics of packed function.
// These information are needed during codegen.
class PackedCallBuilder : public IRMutator {
class BuiltinLower : public IRMutator {
public:
Stmt Build(Stmt stmt) {
stack_shape_ = Var("stack_shape", Handle());
......@@ -49,6 +52,7 @@ class PackedCallBuilder : public IRMutator {
}
return stmt;
}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
CHECK_EQ(run_shape_stack_, 0);
......@@ -65,6 +69,14 @@ class PackedCallBuilder : public IRMutator {
// Lower allocate to device allocate when needed.
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
// For special memory, remove allocate.
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
return op->body;
}
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->type);
......@@ -127,12 +139,25 @@ class PackedCallBuilder : public IRMutator {
CHECK(!device_type_.defined());
device_type_ = op->value;
return Mutate(op->body);
} else if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return IRMutator::Mutate_(op, s);
} else {
return IRMutator::Mutate_(op, s);
}
}
Expr Mutate_(const Call* op, const Expr &e) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
return MakeShape(op, e);
......@@ -142,6 +167,7 @@ class PackedCallBuilder : public IRMutator {
return IRMutator::Mutate_(op, e);
}
}
Expr Convert(Type t, Expr e) {
if (e.type() != t) {
return Cast::make(t, e);
......@@ -254,6 +280,33 @@ class PackedCallBuilder : public IRMutator {
Int(32), intrinsic::tvm_call_packed_lowered,
packed_args, Call::Intrinsic);
}
// tvm_access_ptr
// Lower a tvm_access_ptr call (5 arguments: dtype tag, data, offset, ...).
// Buffers recorded with a non-empty scope tag are lowered through
// MakeTaggedAccessPtr; all others become an address-of expression.
Expr MakeAccessPtr(const Call* op, const Expr& e) {
  // Mutate the arguments first, then work on the rewritten call.
  Expr mutated = IRMutator::Mutate_(op, e);
  op = mutated.as<Call>();
  CHECK_EQ(op->args.size(), 5U);
  Expr data = op->args[1];
  Type elem_type = op->args[0].type();
  Expr elem_offset = op->args[2];
  auto entry = storage_info_.find(data.as<Variable>());
  const bool tagged =
      entry != storage_info_.end() && entry->second.scope.tag.length() != 0;
  if (!tagged) {
    CHECK(op->type.is_handle());
    // Change to address_of
    return AddressOffset(Var(data.node_), elem_type, elem_offset);
  }
  // Use 8 as the unit size when no memory info is attached.
  int unit_bits = 8;
  if (entry->second.info.defined()) {
    unit_bits = entry->second.info->unit_bits;
  }
  return MakeTaggedAccessPtr(op->type, elem_type, elem_offset, unit_bits);
}
// Convert an element offset into an offset counted in memory units.
// unit_bits must be a multiple of the total bit width of dtype.
Expr MakeTaggedAccessPtr(Type ptr_type, Type dtype,
                         Expr offset, int unit_bits) {
  int dtype_bits = dtype.bits() * dtype.lanes();
  CHECK_EQ(unit_bits % dtype_bits, 0);
  // Number of dtype elements that fit into one memory unit.
  Expr elems_per_unit = make_const(offset.type(), unit_bits / dtype_bits);
  Expr unit_index = ir::Simplify(offset / elems_per_unit);
  return Convert(ptr_type, unit_index);
}
private:
bool IsArrayHandle(const Expr& arg) {
......@@ -284,11 +337,22 @@ class PackedCallBuilder : public IRMutator {
uint64_t max_shape_stack_{0};
uint64_t max_array_stack_{0};
uint64_t max_arg_stack_{0};
// The storage entry.
// Record kept per buffer variable that carries a storage_scope attribute.
struct StorageEntry {
// The parsed storage scope; a non-empty scope.tag marks special (tagged) memory.
StorageScope scope;
// The memory info, looked up with GetMemoryInfo for tagged scopes only;
// left default-constructed otherwise.
MemoryInfo info;
// Number of Allocate nodes seen for this buffer while it is tagged memory;
// checked to never exceed 1.
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageEntry> storage_info_;
};
LoweredFunc LowerPackedCall(LoweredFunc f) {
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = PackedCallBuilder().Build(n->body);
n->body = BuiltinLower().Build(n->body);
return LoweredFunc(n);
}
......
......@@ -26,6 +26,7 @@ enum AccessType {
kSync,
kAlloc
};
/*! \brief The access entry */
struct AccessEntry {
/*! \brief The buffer variable, if any */
......@@ -44,6 +45,7 @@ struct AccessEntry {
StorageScope scope)
: buffer(buffer), index(index), type(type), scope(scope) {}
};
/*! \brief The access info about a statement */
struct StmtEntry {
/*! \brief The statement */
......@@ -51,6 +53,7 @@ struct StmtEntry {
/*! \brief access patterns in the statement */
std::vector<AccessEntry> access;
};
} // namespace storage
} // namespace ir
} // namespace tvm
......
......@@ -17,7 +17,7 @@ namespace runtime {
struct StorageScope {
/*! \brief The rank of the storage */
int rank{0};
/*! \brief tag for special memory, if any */
/*! \brief tag for special purpose memory. */
std::string tag;
// comparator
inline bool operator==(const StorageScope& other) const {
......
......@@ -17,7 +17,7 @@ def lower(s, args, name="mydot"):
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
return fapi
......
......@@ -32,7 +32,7 @@ def test_add_pipeline():
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
def check_target(device, host="stackvm"):
if not tvm.module.enabled(host):
......
......@@ -20,7 +20,7 @@ def test_stack_vm_basic():
Ab = tvm.decl_buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
run_jit(fapi, lambda f: f(a))
......@@ -42,7 +42,7 @@ def test_stack_vm_loop():
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
......@@ -65,7 +65,7 @@ def test_stack_vm_cond():
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
import tvm
from tvm.schedule import Buffer
def test_buffer():
m = tvm.var('m')
......@@ -11,6 +12,17 @@ def test_buffer():
assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n)
def test_buffer_access_ptr():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1 , 1])
aptr = Ab.access_ptr("rw")
assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m)
assert aptr.args[0].dtype == Ab.dtype
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
aptr = Ab.access_ptr("w")
assert aptr.args[4].value == Buffer.WRITE
if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
......@@ -39,7 +39,7 @@ def test_dso_module_load():
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
m = tvm.codegen.build_module(fapi, "llvm")
for name in names:
m.save(name)
......
......@@ -30,6 +30,39 @@ def test_storage_share():
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 2
def test_storage_combine():
n = 8
A = tvm.placeholder((4,), name='A')
num_stage = 5
B = A
stages = []
for t in range(num_stage):
B = tvm.compute((n, ), lambda i: B[i] + (t+1), name='A%d' % t)
stages.append(B)
s = tvm.create_schedule(B.op)
for S in stages[:-1]:
s[S].set_scope("global:tag")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
stmt = tvm.ir_pass.StorageRewrite(stmt)
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
num_alloc[0] += 1
assert (n.extents[0].value == 16)
tvm.ir_pass.PostOrderVisit(stmt, verify)
assert num_alloc[0] == 1
def test_storage_share_gpu():
m = tvm.var('m')
A = [tvm.placeholder((m), name='A')]
......@@ -67,5 +100,6 @@ def test_storage_share_gpu():
if __name__ == "__main__":
test_storage_combine()
test_storage_share_gpu()
test_storage_share()
......@@ -25,7 +25,7 @@ def test_dltensor_compatible():
A[i + 1] = A[i] + 1
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.codegen.build_module(fapi, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
......
......@@ -17,20 +17,26 @@ def intrin_gemv(m, n):
k = tvm.reduce_axis((0, n), name='k')
z = tvm.compute((m,), lambda i:
tvm.sum(w[i, k] * x[k], axis=k), name='z')
Wb = tvm.decl_buffer(w.shape, w.dtype, name="W",
Wb = tvm.decl_buffer(w.shape, w.dtype,
name="W",
offset_factor=16,
strides=[tvm.var('ldw'), 1])
def intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]
ww_ptr = ww.access_ptr("r")
xx_ptr = xx.access_ptr("r")
zz_ptr = zz.access_ptr("w")
body = tvm.call_packed(
"gemv", ww.data, xx.data, zz.data, n, ww.strides[0])
"gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
reset = tvm.call_packed(
"fill_zero", outs[0].data, n)
"fill_zero", zz_ptr, n)
update = tvm.call_packed(
"gemv_add", ww.data, xx.data, zz.data, n, ww.strides[0])
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
return body, reset, update
with tvm.build_config(data_alignment=16):
with tvm.build_config(data_alignment=16,
offset_factor=16):
return tvm.decl_tensor_intrin(z.op, intrin_func,
binds={w: Wb})
......@@ -93,6 +99,7 @@ def test_tensorize_matmul():
stmt = tvm.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [A, B, C])
def check_rfactor(factor, rfactor):
s = tvm.create_schedule(C.op)
x, y = C.op.axis
......
......@@ -39,7 +39,7 @@ def test_add_pipeline():
s[C].bind(px, tvm.thread_axis("pipeline"))
fapi = lower(s, [A, B, C], "myadd")
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
print("------")
def check_target(device, host="stackvm"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment