Commit f2b91392 by Tianqi Chen Committed by GitHub

Support rank-0 tensor (#687)

* Support rank-0 tensor

* fix lint
parent df4962e2
...@@ -124,6 +124,11 @@ class BufferNode : public Node { ...@@ -124,6 +124,11 @@ class BufferNode : public Node {
v->Visit("offset_factor", &offset_factor); v->Visit("offset_factor", &offset_factor);
} }
/*! \return preferred index type for this buffer node */
Type DefaultIndexType() const {
return shape.size() != 0 ? shape[0].type() : Int(32);
}
// User can specify data_alignment and offset_factor to be 0 // User can specify data_alignment and offset_factor to be 0
// A default value will be picked. // A default value will be picked.
TVM_DLL static Buffer make(Var ptr, TVM_DLL static Buffer make(Var ptr,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./tensor.h"
#include "./runtime/packed_func.h" #include "./runtime/packed_func.h"
namespace tvm { namespace tvm {
...@@ -116,6 +117,9 @@ inline TVMArgValue::operator Halide::Expr() const { ...@@ -116,6 +117,9 @@ inline TVMArgValue::operator Halide::Expr() const {
if (sptr->is_type<IterVarNode>()) { if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var; return IterVar(sptr)->var;
} }
if (sptr->is_type<TensorNode>()) {
return Tensor(sptr)();
}
CHECK(NodeTypeChecker<Expr>::Check(sptr.get())) CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>() << "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key(); << " but get " << sptr->type_key();
......
...@@ -188,7 +188,7 @@ inline bool Tensor::operator==(const Tensor& other) const { ...@@ -188,7 +188,7 @@ inline bool Tensor::operator==(const Tensor& other) const {
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
inline Expr operator Op (const Tensor::Slice& a) { \ inline Expr operator Op (const Tensor::Slice& a) { \
return Op a.operator Expr() ; \ return Op a.operator Expr() ; \
} } \
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
template<typename T> \ template<typename T> \
......
...@@ -177,13 +177,14 @@ class NDArrayBase(_NDArrayBase): ...@@ -177,13 +177,14 @@ class NDArrayBase(_NDArrayBase):
shape = shape + (t.lanes,) shape = shape + (t.lanes,)
t.lanes = 1 t.lanes = 1
dtype = str(t) dtype = str(t)
source_array = np.ascontiguousarray(source_array, dtype=dtype)
if source_array.shape != shape: if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format( raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape)) source_array.shape, shape))
source_array = np.ascontiguousarray(source_array, dtype=dtype)
assert source_array.flags['C_CONTIGUOUS'] assert source_array.flags['C_CONTIGUOUS']
data = source_array.ctypes.data_as(ctypes.c_void_p) data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes)) check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self return self
...@@ -212,7 +213,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -212,7 +213,7 @@ class NDArrayBase(_NDArrayBase):
np_arr = np.empty(shape, dtype=dtype) np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS'] assert np_arr.flags['C_CONTIGUOUS']
data = np_arr.ctypes.data_as(ctypes.c_void_p) data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize) nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes)) check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
return np_arr return np_arr
......
...@@ -462,7 +462,6 @@ def decl_buffer(shape, ...@@ -462,7 +462,6 @@ def decl_buffer(shape,
elem_offset = var('%s_elem_offset' % name, shape[0].dtype) elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
if data is None: if data is None:
data = var(name, "handle") data = var(name, "handle")
return _api_internal._Buffer( return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope, data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor) data_alignment, offset_factor)
......
...@@ -32,7 +32,7 @@ class TensorSlice(NodeGeneric, _expr.ExprOp): ...@@ -32,7 +32,7 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
itervar_cls = None itervar_cls = None
@register_node @register_node
class Tensor(NodeBase): class Tensor(NodeBase, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor""" """Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices): def __call__(self, *indices):
ndim = self.ndim ndim = self.ndim
...@@ -60,7 +60,13 @@ class Tensor(NodeBase): ...@@ -60,7 +60,13 @@ class Tensor(NodeBase):
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Tensor): if not isinstance(other, Tensor):
if isinstance(other, _expr.ExprOp):
return _expr.EqualOp(self, other)
return False return False
if self.ndim == 0 and other.ndim == 0:
raise ValueError("Equal == comparison among rank-0 tensor is ambiguous, "
"use Tensor.equal for content expression equvalence, "
"use Tensor.same_as for exact reference comparison")
return _api_internal._TensorEqual(self, other) return _api_internal._TensorEqual(self, other)
@property @property
......
...@@ -33,11 +33,14 @@ inline Expr ComputeExpr(Expr lhs, Expr rhs) { ...@@ -33,11 +33,14 @@ inline Expr ComputeExpr(Expr lhs, Expr rhs) {
/*! /*!
* \brief Compute an reduction with Op * \brief Compute an reduction with Op
* \param values The input values. * \param values The input values.
* \param empty_value The value when return if it is empty, can be Expr()
* which will cause an error to be rasied.
* \tparam Op The computation operator * \tparam Op The computation operator
* \return The result. * \return The result.
*/ */
template<typename Op> template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values); inline Expr ComputeReduce(
const Array<Expr>& values, Expr empty_value);
template<typename T> template<typename T>
inline bool GetConst(Expr e, T* out); inline bool GetConst(Expr e, T* out);
...@@ -139,8 +142,11 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) { ...@@ -139,8 +142,11 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
} }
template<typename Op> template<typename Op>
inline Expr ComputeReduce(const Array<Expr>& values) { inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
CHECK_NE(values.size(), 0U); if (values.size() == 0U) {
CHECK(empty_value.defined());
return empty_value;
}
Expr res = values[0]; Expr res = values[0];
for (size_t i = 1; i < values.size(); ++i) { for (size_t i = 1; i < values.size(); ++i) {
res = ComputeExpr<Op>(res, values[i]); res = ComputeExpr<Op>(res, values[i]);
......
...@@ -11,15 +11,6 @@ ...@@ -11,15 +11,6 @@
namespace tvm { namespace tvm {
Array<Expr> GetStrides(Array<Expr> shape) {
CHECK_NE(shape.size(), 0U);
std::vector<Expr> vec{make_const(shape[0].type(), 1)};
for (size_t i = shape.size() - 1; i != 0; --i) {
vec.push_back(shape[i - 1] * vec.back());
}
return Array<Expr>(vec.rbegin(), vec.rend());
}
Array<Expr> SimplifyArray(Array<Expr> array) { Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) { for (size_t i = 0; i < array.size(); ++i) {
array.Set(i, ir::Simplify(array[i])); array.Set(i, ir::Simplify(array[i]));
...@@ -235,11 +226,13 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) { ...@@ -235,11 +226,13 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset; Expr base = n->elem_offset;
if (n->strides.size() == 0) { if (n->strides.size() == 0) {
CHECK_EQ(n->shape.size(), index.size()); CHECK_EQ(n->shape.size(), index.size());
if (n->shape.size() != 0) {
if (is_zero(base)) { if (is_zero(base)) {
base = index[0]; base = index[0];
} else { } else {
base = base + index[0]; base = base + index[0];
} }
}
base = MergeMulMod(base); base = MergeMulMod(base);
for (size_t i = 1; i < index.size(); ++i) { for (size_t i = 1; i < index.size(); ++i) {
base = MergeMulMod(base * n->shape[i] + index[i]); base = MergeMulMod(base * n->shape[i] + index[i]);
...@@ -294,9 +287,10 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const { ...@@ -294,9 +287,10 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
Buffer Buffer::MakeStrideView() const { Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this; if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this;
std::vector<Expr> temp; std::vector<Expr> temp;
auto n = std::make_shared<BufferNode>(*operator->()); auto n = std::make_shared<BufferNode>(*operator->());
Expr acc = make_const(n->shape[0].type(), 1); Expr acc = make_const(n->DefaultIndexType(), 1);
for (size_t i = n->shape.size(); i != 0 ; --i) { for (size_t i = n->shape.size(); i != 0 ; --i) {
temp.push_back(acc); temp.push_back(acc);
acc = acc * n->shape[i - 1]; acc = acc * n->shape[i - 1];
...@@ -344,9 +338,16 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const { ...@@ -344,9 +338,16 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const { Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const {
const BufferNode* self = operator->(); const BufferNode* self = operator->();
Expr e_dtype; Expr e_dtype;
Expr extent = (self->strides.size() == self->shape.size() ? Expr extent;
arith::ComputeExpr<ir::Mul>(self->strides[0], self->shape[0]): if (self->shape.size() == 0) {
arith::ComputeReduce<ir::Mul>(self->shape)); extent = make_const(self->DefaultIndexType(), 1);
} else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0;
extent = arith::ComputeExpr<ir::Mul>(
self->strides[highest_dim], self->shape[highest_dim]);
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
}
Expr elem_offset = self->elem_offset; Expr elem_offset = self->elem_offset;
if (content_lanes > 1) { if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes)); e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
...@@ -383,7 +384,7 @@ Buffer BufferNode::make(Var data, ...@@ -383,7 +384,7 @@ Buffer BufferNode::make(Var data,
} }
n->scope = std::move(scope); n->scope = std::move(scope);
if (!elem_offset.defined()) { if (!elem_offset.defined()) {
elem_offset = make_const(n->shape[0].type(), 0); elem_offset = make_const(n->DefaultIndexType(), 0);
} }
if (data_alignment <= 0) { if (data_alignment <= 0) {
data_alignment = runtime::kAllocAlignment; data_alignment = runtime::kAllocAlignment;
......
...@@ -196,7 +196,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -196,7 +196,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
nop)); nop));
if (buffer->strides.size() == 0) { if (buffer->strides.size() == 0) {
// Assert the buffer is compact // Assert the buffer is compact
Type stype = buffer->shape[0].type(); Type stype = buffer->DefaultIndexType();
Expr expect_stride = make_const(stype, 1); Expr expect_stride = make_const(stype, 1);
Array<Expr> conds; Array<Expr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) { for (size_t i = buffer->shape.size(); i != 0; --i) {
...@@ -211,14 +211,16 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -211,14 +211,16 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
std::ostringstream stride_err_msg; std::ostringstream stride_err_msg;
stride_err_msg << arg_name << ".strides:" stride_err_msg << arg_name << ".strides:"
<< " expected to be compact array"; << " expected to be compact array";
if (conds.size() != 0) {
Stmt check = Stmt check =
AssertStmt::make(arith::ComputeReduce<ir::And>(conds), AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
stride_err_msg.str(), Evaluate::make(0)); stride_err_msg.str(), Evaluate::make(0));
Expr is_null = Call::make( Expr is_null = Call::make(
Bool(1), intrinsic::tvm_handle_is_null, Bool(1), intrinsic::tvm_handle_is_null,
{v_strides}, Call::PureIntrinsic); {v_strides}, Call::PureIntrinsic);
check = IfThenElse::make(Not::make(is_null), check, Stmt()); check = IfThenElse::make(Not::make(is_null), check, Stmt());
init_nest_.emplace_back(Block::make(check, Evaluate::make(0))); init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
}
} else { } else {
for (size_t k = 0; k < buffer->strides.size(); ++k) { for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name; std::ostringstream field_name;
......
...@@ -81,7 +81,8 @@ class DoubleBufferInjector : public IRMutator { ...@@ -81,7 +81,8 @@ class DoubleBufferInjector : public IRMutator {
Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt Mutate_(const Allocate* op, const Stmt& s) final {
auto it = dbuffer_info_.find(op->buffer_var.get()); auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) { if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes(); it->second.stride = arith::ComputeReduce<Mul>
(op->extents, Expr()) * op->type.lanes();
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>(); op = stmt.as<Allocate>();
Array<Expr> new_extents{make_const(op->extents[0].type(), 2)}; Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
......
...@@ -376,7 +376,8 @@ class VTInjector : public IRMutator { ...@@ -376,7 +376,8 @@ class VTInjector : public IRMutator {
// always rewrite if not allow sharing. // always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension. // place v on highest dimension.
Expr stride = arith::ComputeReduce<Mul>(op->extents) * op->type.lanes(); Expr stride = arith::ComputeReduce<Mul>(
op->extents, Expr()) * op->type.lanes();
Array<Expr> other; Array<Expr> other;
other.push_back(make_const(op->extents[0].type(), num_threads_)); other.push_back(make_const(op->extents[0].type(), num_threads_));
for (Expr e : extents) { for (Expr e : extents) {
......
...@@ -147,10 +147,11 @@ class StorageFlattener : public IRMutator { ...@@ -147,10 +147,11 @@ class StorageFlattener : public IRMutator {
} }
} }
Array<Expr> strides; Array<Expr> strides;
if (dim_align_.count(key) != 0) { if (dim_align_.count(key) != 0 && shape.size() != 0) {
std::vector<Expr> rstrides; std::vector<Expr> rstrides;
const std::vector<DimAlignInfo>& avec = dim_align_[key]; const std::vector<DimAlignInfo>& avec = dim_align_[key];
Expr stride = make_const(shape[0].type(), 1); int first_dim = 0;
Expr stride = make_const(shape[first_dim].type(), 1);
for (size_t i = shape.size(); i != 0; --i) { for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1; size_t dim = i - 1;
if (dim < avec.size() && avec[dim].align_factor != 0) { if (dim < avec.size() && avec[dim].align_factor != 0) {
...@@ -164,6 +165,7 @@ class StorageFlattener : public IRMutator { ...@@ -164,6 +165,7 @@ class StorageFlattener : public IRMutator {
} }
strides = Array<Expr>(rstrides.rbegin(), rstrides.rend()); strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
} }
e.buffer = BufferNode::make( e.buffer = BufferNode::make(
Var(key.GetName(), Handle()), Var(key.GetName(), Handle()),
op->type, shape, strides, Expr(), op->type, shape, strides, Expr(),
...@@ -176,13 +178,18 @@ class StorageFlattener : public IRMutator { ...@@ -176,13 +178,18 @@ class StorageFlattener : public IRMutator {
Stmt ret; Stmt ret;
if (strides.size() != 0) { if (strides.size() != 0) {
int first_dim = 0;
ret = Allocate::make( ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->data, e.buffer->dtype,
{arith::ComputeExpr<Mul>(e.buffer->strides[0], e.buffer->shape[0])}, {arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
make_const(Bool(e.buffer->dtype.lanes()), true), body); make_const(Bool(e.buffer->dtype.lanes()), true), body);
} else { } else {
shape = e.buffer->shape;
if (shape.size() == 0) {
shape.push_back(make_const(Int(32), 1));
}
ret = Allocate::make( ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape, e.buffer->data, e.buffer->dtype, shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body); make_const(Bool(e.buffer->dtype.lanes()), true), body);
} }
ret = AttrStmt::make( ret = AttrStmt::make(
......
...@@ -405,7 +405,7 @@ class StoragePlanRewriter : public IRMutator { ...@@ -405,7 +405,7 @@ class StoragePlanRewriter : public IRMutator {
// Build a merged allocation // Build a merged allocation
Expr combo_size; Expr combo_size;
for (const Allocate* op : e->allocs) { for (const Allocate* op : e->allocs) {
Expr sz = arith::ComputeReduce<Mul>(op->extents); Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
if (alloc_type.lanes() != op->type.lanes()) { if (alloc_type.lanes() != op->type.lanes()) {
sz = (sz * make_const(sz.type(), op->type.lanes()) + sz = (sz * make_const(sz.type(), op->type.lanes()) +
make_const(sz.type(), alloc_type.lanes() - 1)) / make_const(sz.type(), alloc_type.lanes() - 1)) /
......
...@@ -352,9 +352,13 @@ int TVMArrayAlloc(const tvm_index_t* shape, ...@@ -352,9 +352,13 @@ int TVMArrayAlloc(const tvm_index_t* shape,
arr->dtype.code = static_cast<uint8_t>(dtype_code); arr->dtype.code = static_cast<uint8_t>(dtype_code);
arr->dtype.bits = static_cast<uint8_t>(dtype_bits); arr->dtype.bits = static_cast<uint8_t>(dtype_bits);
arr->dtype.lanes = static_cast<uint16_t>(dtype_lanes); arr->dtype.lanes = static_cast<uint16_t>(dtype_lanes);
if (ndim != 0) {
tvm_index_t* shape_copy = new tvm_index_t[ndim]; tvm_index_t* shape_copy = new tvm_index_t[ndim];
std::copy(shape, shape + ndim, shape_copy); std::copy(shape, shape + ndim, shape_copy);
arr->shape = shape_copy; arr->shape = shape_copy;
} else {
arr->shape = nullptr;
}
// ctx // ctx
arr->ctx.device_type = static_cast<DLDeviceType>(device_type); arr->ctx.device_type = static_cast<DLDeviceType>(device_type);
arr->ctx.device_id = device_id; arr->ctx.device_id = device_id;
......
...@@ -370,8 +370,10 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { ...@@ -370,8 +370,10 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) {
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim); std::vector<int64_t> shape(tensor.ndim);
if (tensor.ndim != 0) {
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
<< "Invalid DLTensor file format"; << "Invalid DLTensor file format";
}
CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch"; CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch";
CHECK(tensor.dtype.bits == dst->dtype.bits && CHECK(tensor.dtype.bits == dst->dtype.bits &&
tensor.dtype.code == dst->dtype.code && tensor.dtype.code == dst->dtype.code &&
......
...@@ -47,10 +47,10 @@ Expr InjectPredicate(const Array<Expr>& predicates, ...@@ -47,10 +47,10 @@ Expr InjectPredicate(const Array<Expr>& predicates,
const Reduce* reduce = body.as<Reduce>(); const Reduce* reduce = body.as<Reduce>();
if (reduce) { if (reduce) {
std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce); std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce);
n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates); n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
return Expr(n); return Expr(n);
} }
return Select::make(arith::ComputeReduce<ir::And>(predicates), return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
body, body,
make_zero(body.type())); make_zero(body.type()));
} }
...@@ -467,7 +467,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, ...@@ -467,7 +467,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
const Reduce* reduce = compute_op->body[idx].as<Reduce>(); const Reduce* reduce = compute_op->body[idx].as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions"; CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition); predicates.push_back(reduce->condition);
Expr predicate = arith::ComputeReduce<ir::And>(predicates); Expr predicate = arith::ComputeReduce<ir::And>(predicates, Expr());
std::unordered_map<const Variable*, Expr> vsub; std::unordered_map<const Variable*, Expr> vsub;
......
...@@ -5,8 +5,8 @@ import numpy as np ...@@ -5,8 +5,8 @@ import numpy as np
def test_add_pipeline(): def test_add_pipeline():
n = tvm.var('n') n = tvm.var('n')
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(), name='C')
D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D') D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
s = tvm.create_schedule(D.op) s = tvm.create_schedule(D.op)
...@@ -48,7 +48,7 @@ def test_add_pipeline(): ...@@ -48,7 +48,7 @@ def test_add_pipeline():
# launch the kernel. # launch the kernel.
n = 1027 n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx) d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, d) f(a, b, d)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -72,7 +72,7 @@ def test_add_pipeline(): ...@@ -72,7 +72,7 @@ def test_add_pipeline():
# launch the kernel. # launch the kernel.
n = 1027 n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx) d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, d) f(a, b, d)
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -84,5 +84,6 @@ def test_add_pipeline(): ...@@ -84,5 +84,6 @@ def test_add_pipeline():
check_target("nvptx", host="llvm") check_target("nvptx", host="llvm")
check_target("rocm", host="llvm") check_target("rocm", host="llvm")
if __name__ == "__main__": if __name__ == "__main__":
test_add_pipeline() test_add_pipeline()
...@@ -273,7 +273,32 @@ def test_llvm_bool(): ...@@ -273,7 +273,32 @@ def test_llvm_bool():
check_llvm(64) check_llvm(64)
def test_rank_zero():
def check_llvm(n):
if not tvm.module.enabled("llvm"):
return
A = tvm.placeholder((n, ), name='A')
scale = tvm.placeholder((), name='scale')
k = tvm.reduce_axis((0, n), name="k")
C = tvm.compute((), lambda : tvm.sum(A[k] * scale, axis=k), name="C")
D = tvm.compute((), lambda : C + 1)
s = tvm.create_schedule(D.op)
# build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
sc = tvm.nd.array(
np.random.randint(0, 2, size=()).astype(scale.dtype), ctx)
d = tvm.nd.empty((), D.dtype, ctx)
f(a, sc, d)
d_np = np.sum(a.asnumpy()) * sc.asnumpy() + 1
np.testing.assert_allclose(d.asnumpy(), d_np)
check_llvm(64)
if __name__ == "__main__": if __name__ == "__main__":
test_rank_zero()
test_llvm_bool() test_llvm_bool()
test_llvm_persist_parallel() test_llvm_persist_parallel()
test_llvm_select() test_llvm_select()
......
...@@ -19,6 +19,17 @@ def test_tensor(): ...@@ -19,6 +19,17 @@ def test_tensor():
assert(T[0][0][0].astype('float16').dtype == 'float16') assert(T[0][0][0].astype('float16').dtype == 'float16')
def test_rank_zero():
m = tvm.var('m')
A = tvm.placeholder((m,), name='A')
scale = tvm.placeholder((), name='s')
k = tvm.reduce_axis((0, m), name="k")
T = tvm.compute((), lambda : tvm.sum(A[k] * scale(), axis=k))
print(T)
print(T.op.body)
assert(tuple(T.shape) == ())
def test_conv1d(): def test_conv1d():
n = tvm.var('n') n = tvm.var('n')
A = tvm.placeholder((n+2), name='A') A = tvm.placeholder((n+2), name='A')
...@@ -173,7 +184,9 @@ def test_tensor_inputs(): ...@@ -173,7 +184,9 @@ def test_tensor_inputs():
y = tvm.compute(x.shape, lambda i: x[i] + x[i]) y = tvm.compute(x.shape, lambda i: x[i] + x[i])
assert tuple(y.op.input_tensors) == (x,) assert tuple(y.op.input_tensors) == (x,)
if __name__ == "__main__": if __name__ == "__main__":
test_rank_zero()
test_tensor_inputs() test_tensor_inputs()
test_tensor_reduce_multi_axis() test_tensor_reduce_multi_axis()
test_conv1d() test_conv1d()
......
...@@ -63,7 +63,15 @@ def test_byte_array(): ...@@ -63,7 +63,15 @@ def test_byte_array():
f(a) f(a)
def test_empty_array():
def myfunc(ss):
assert tuple(ss) == ()
x = tvm.convert(())
tvm.convert(myfunc)(x)
if __name__ == "__main__": if __name__ == "__main__":
test_empty_array()
test_get_global() test_get_global()
test_get_callback_with_node() test_get_callback_with_node()
test_convert() test_convert()
......
...@@ -25,7 +25,7 @@ def dense(data, weight, bias=None): ...@@ -25,7 +25,7 @@ def dense(data, weight, bias=None):
""" """
assert len(data.shape) == 2 and len(weight.shape) == 2, \ assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim dense" "only support 2-dim dense"
if bias: if bias is not None:
assert len(bias.shape) == 1 assert len(bias.shape) == 1
batch, in_dim = data.shape batch, in_dim = data.shape
out_dim, _ = weight.shape out_dim, _ = weight.shape
...@@ -33,7 +33,7 @@ def dense(data, weight, bias=None): ...@@ -33,7 +33,7 @@ def dense(data, weight, bias=None):
matmul = tvm.compute((batch, out_dim), \ matmul = tvm.compute((batch, out_dim), \
lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \ lambda i, j: tvm.sum(data[i, k] * weight[j, k], axis=k), \
tag='dense') tag='dense')
if bias: if bias is not None:
matmul = tvm.compute((batch, out_dim), \ matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \ lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST) tag=tag.BROADCAST)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment