Commit 1146495f by Tianqi Chen (committed by GitHub)

[RUNTIME][PASS] Allow declare vector type array (#302)

* [RUNTIME][PASS] Allow declare vector type array

* fix bcast

* [BUFFER] Enable vload/store function in buffer

* ok
parent 1e48b02f
@@ -33,19 +33,6 @@ class Buffer : public NodeRef {
   Buffer() {}
   explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
-  /*!
-   * \brief Generate a load expression loading the index location of buffer.
-   * \param index The index to the buffer.
-   * \return The load expression.
-   */
-  Expr MakeLoad(Array<Expr> index) const;
-  /*!
-   * \brief Generate a store statement.
-   * \param index The index to the buffer.
-   * \param value The value to be stored.
-   * \return The load expression.
-   */
-  Stmt MakeStore(Array<Expr> index, Expr value) const;
   /*!
    * \brief Return a new buffer that is equivalent with current one
    * but always add stride field.
    * \return The strided version of the buffer.
@@ -67,6 +54,18 @@ class Buffer : public NodeRef {
    */
   Expr access_ptr(int access_mask, Type ptr_type = Handle()) const;
+  /*!
+   * \brief Create an Expr that does a vector load at begin index.
+   * \param begin The beginning index
+   * \param dtype The data type to be loaded.
+   */
+  Expr vload(Array<Expr> begin, Type dtype) const;
+  /*!
+   * \brief Create a Stmt that does a vector store at begin index.
+   * \param begin The beginning index
+   * \param value The value to be stored.
+   */
+  Stmt vstore(Array<Expr> begin, Expr value) const;
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */
......
@@ -142,17 +142,22 @@ class NDArrayBase(_NDArrayBase):
         if value.handle is not self.handle:
             value.copyto(self)
         elif isinstance(value, (np.ndarray, np.generic)):
-            self._sync_copyfrom(value)
+            self.copyfrom(value)
         else:
             raise TypeError('type %s not supported' % str(type(value)))

-    def _sync_copyfrom(self, source_array):
-        """Peform an synchronize copy from the array.
+    def copyfrom(self, source_array):
+        """Perform a synchronized copy from the array.

         Parameters
         ----------
         source_array : array_like
             The data source we should like to copy from.
+
+        Returns
+        -------
+        arr : NDArray
+            Reference to self.
         """
         if not isinstance(source_array, np.ndarray):
             try:
@@ -160,13 +165,21 @@ class NDArrayBase(_NDArrayBase):
             except:
                 raise TypeError('array must be an array_like data,' +
                                 'type %s is not supported' % str(type(source_array)))
-        source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
-        if source_array.shape != self.shape:
-            raise ValueError('array shape do not match the shape of NDArray')
+        t = TVMType(self.dtype)
+        shape, dtype = self.shape, self.dtype
+        if t.lanes > 1:
+            shape = shape + (t.lanes,)
+            t.lanes = 1
+            dtype = str(t)
+        source_array = np.ascontiguousarray(source_array, dtype=dtype)
+        if source_array.shape != shape:
+            raise ValueError("array shape does not match the shape of NDArray: {0} vs {1}".format(
+                source_array.shape, shape))
         assert source_array.flags['C_CONTIGUOUS']
         data = source_array.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
+        return self

     def asnumpy(self):
         """Convert this array to numpy array
@@ -176,7 +189,13 @@ class NDArrayBase(_NDArrayBase):
         np_arr : numpy.ndarray
             The corresponding numpy array.
         """
-        np_arr = np.empty(self.shape, dtype=self.dtype)
+        t = TVMType(self.dtype)
+        shape, dtype = self.shape, self.dtype
+        if t.lanes > 1:
+            shape = shape + (t.lanes,)
+            t.lanes = 1
+            dtype = str(t)
+        np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags['C_CONTIGUOUS']
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize)
@@ -188,7 +207,7 @@ class NDArrayBase(_NDArrayBase):

         Parameters
         ----------
-        target : tvm.NDArray
+        target : NDArray
             The target array to be copied, must have same shape as this array.
         """
         if isinstance(target, TVMContext):
......
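For vector dtypes, the numpy view gains a trailing lanes axis. A minimal sketch of the round trip (shapes follow the lanes logic above; the 0.x-era tvm.nd API is assumed):

    import numpy as np
    import tvm

    # An NDArray of dtype "float32x2" with shape (64,) maps to a numpy
    # array of dtype float32 with shape (64, 2).
    a = tvm.nd.empty((64,), "float32x2")
    a.copyfrom(np.random.uniform(size=(64, 2)))  # copyfrom now returns self
    assert a.asnumpy().shape == (64, 2)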
@@ -41,30 +41,36 @@ class TVMType(ctypes.Structure):
         2 : 'float',
         4 : 'handle'
     }
-    def __init__(self, type_str, lanes=1):
+    def __init__(self, type_str):
         super(TVMType, self).__init__()
         if isinstance(type_str, np.dtype):
             type_str = str(type_str)
-        if type_str.startswith("int"):
+        arr = type_str.split("x")
+        head = arr[0]
+        self.lanes = int(arr[1]) if len(arr) > 1 else 1
+        bits = 32
+        if head.startswith("int"):
             self.type_code = 0
-            bits = int(type_str[3:])
-        elif type_str.startswith("uint"):
+            head = head[3:]
+        elif head.startswith("uint"):
             self.type_code = 1
-            bits = int(type_str[4:])
-        elif type_str.startswith("float"):
+            head = head[4:]
+        elif head.startswith("float"):
             self.type_code = 2
-            bits = int(type_str[5:])
-        elif type_str.startswith("handle"):
+            head = head[5:]
+        elif head.startswith("handle"):
             self.type_code = 4
             bits = 64
+            head = ""
         else:
             raise ValueError("Do not know how to handle type %s" % type_str)
-        bits = 32 if bits == 0 else bits
+        bits = int(head) if head else bits
         if (bits & (bits - 1)) != 0 or bits < 8:
             raise ValueError("Do not know how to handle type %s" % type_str)
         self.bits = bits
-        self.lanes = lanes

     def __repr__(self):
         x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
......
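A quick check of the new "<base>x<lanes>" parsing, following the branches above:

    from tvm._ffi.runtime_ctypes import TVMType

    t = TVMType("float32x4")
    assert (t.type_code, t.bits, t.lanes) == (2, 32, 4)  # float, 32 bits, 4 lanes
    assert TVMType("int8").lanes == 1                    # no "x" suffix -> 1 lane
    assert TVMType("handle").bits == 64                  # handle is always 64-bit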
@@ -10,6 +10,7 @@ from ._ffi.node import convert_to_node as _convert_to_node
 from ._ffi.function import Function
 from ._ffi.function import _init_api, register_func, get_global_func
 from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
+from ._ffi.runtime_ctypes import TVMType
 from . import _api_internal
 from . import make as _make
 from . import expr as _expr
@@ -546,22 +547,6 @@ def reduce_axis(dom, name="rv"):
     """
     return _IterVar(dom, name, 2)

-def cast(dtype, expr):
-    """Cast an expression to other type
-
-    Parameters
-    ----------
-    dtype : str, optional
-        The type of new expression
-
-    expr : Expr
-        The expression
-
-    Returns
-    -------
-    expr : Expr
-        Expression with new type
-    """
-    return _make.Cast(dtype, expr)

 def select(cond, t, f):
     """Construct a select branch
......
@@ -97,10 +97,11 @@ class ExprOp(object):
         return _make.EQ(self, other)

     def astype(self, dtype):
-        """Cast the expression to other type
+        """Cast the expression to other type.

         Parameters
         ----------
-        dtype : str, optional
+        dtype : str
             The type of new expression

         Returns
@@ -108,7 +109,7 @@ class ExprOp(object):
         expr : Expr
             Expression with new type
         """
-        return _make.Cast(dtype, self)
+        return _make.static_cast(dtype, self)

 class Expr(NodeBase, ExprOp):
......
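With this change, astype on a scalar expression yields a Broadcast rather than a Cast when the target is a vector type of the same element type (behavior grounded in the test_cast test added later in this commit):

    import tvm

    x = tvm.var('x', dtype="float32")
    y = x.astype("int32")       # element type changes: Cast node
    z = x.astype("float32x4")   # same element type, more lanes: Broadcast node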
@@ -9,6 +9,7 @@ from . import ir_pass as _pass
 from . import container as _container
 from ._ffi.base import string_types
 from ._ffi.node import NodeGeneric
+from ._ffi.runtime_ctypes import TVMType
 from .expr import Call as _Call

 class WithScope(object):
@@ -56,7 +57,14 @@ class BufferVar(NodeGeneric):
     def asnode(self):
         return self._buffer_var

+    @property
+    def dtype(self):
+        return self._content_type
+
     def __getitem__(self, index):
+        t = TVMType(self._content_type)
+        if t.lanes > 1:
+            index = _make.Ramp(index * t.lanes, 1, t.lanes)
         return _make.Load(self._content_type, self._buffer_var, index)

     def __setitem__(self, index, value):
@@ -65,6 +73,9 @@ class BufferVar(NodeGeneric):
             raise ValueError(
                 "data type does not match content type %s vs %s" % (
                     value.dtype, self._content_type))
+        t = TVMType(self._content_type)
+        if t.lanes > 1:
+            index = _make.Ramp(index * t.lanes, 1, t.lanes)
         self._builder.emit(_make.Store(self._buffer_var, value, index))
......
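With a vector content type, BufferVar indexing now scales the index by the lane count and emits a Ramp, so each subscript covers one full vector. A small sketch:

    import tvm

    ib = tvm.ir_builder.create()
    A = ib.pointer("float32x4", name="A")
    j = tvm.var('j')
    # A[j] now lowers to Load(float32x4, A, Ramp(j*4, 1, 4)),
    # i.e. one full vector per index step.
    v = A[j]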
@@ -7,6 +7,7 @@ Each api is a PackedFunc that can be called in a positional argument manner.
 You can use make function to build the IR node.
 """
 from ._ffi.function import _init_api
+from ._ffi.runtime_ctypes import TVMType
 from . import stmt as _stmt

 def range_by_min_extent(min_value, extent):
@@ -30,6 +31,34 @@ def range_by_min_extent(min_value, extent):
     return _range_by_min_extent(min_value, extent)

+def static_cast(dtype, expr):
+    """Cast expr to dtype.
+
+    If expr is a scalar and dtype is a corresponding vector
+    type, a Broadcast is generated. Otherwise it is a Cast.
+
+    Parameters
+    ----------
+    dtype : str
+        The target data type.
+
+    expr : Expr
+        The expression to be cast.
+
+    Returns
+    -------
+    casted : Expr
+        The cast expression.
+    """
+    target_type = TVMType(dtype)
+    src_type = TVMType(expr.dtype)
+    if target_type.type_code == src_type.type_code \
+            and src_type.lanes == 1 \
+            and target_type.lanes > 1:
+        return Broadcast(expr, target_type.lanes)
+    return Cast(dtype, expr)

 def node(type_key, **kwargs):
     """Make a new DSL node by its type key and fields
......
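A behavior sketch of static_cast, grounded in the branch above (Broadcast and Cast are the node constructors registered into tvm.make):

    import tvm
    from tvm import make as _make

    x = tvm.var('x', dtype="float32")
    b = _make.static_cast("float32x4", x)  # scalar -> vector: Broadcast
    c = _make.static_cast("int32", x)      # element type change: Cast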
@@ -126,8 +126,6 @@ def array(arr, ctx=cpu(0)):
     """
     if not isinstance(arr, _np.ndarray):
         arr = _np.array(arr)
-    ret = empty(arr.shape, arr.dtype, ctx)
-    ret[:] = arr
-    return ret
+    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)

 _set_class_ndarray(NDArray)
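tvm.nd.array now routes through the chainable copyfrom, which also handles vector dtypes; user-facing usage is unchanged:

    import numpy as np
    import tvm

    a = tvm.nd.array(np.arange(4, dtype="float32"))
    assert (a.asnumpy() == np.arange(4, dtype="float32")).all()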
@@ -65,6 +65,46 @@ class Buffer(NodeBase):
             access_mask = mask
         return _api_internal._BufferAccessPtr(self, access_mask, ptr_type)

+    def vload(self, begin, dtype=None):
+        """Generate an Expr that loads dtype from begin index.
+
+        Parameters
+        ----------
+        begin : Array of Expr
+            The beginning index, in units of Buffer.dtype.
+
+        dtype : str
+            The data type to be loaded; can be a vector type
+            whose lane count is a multiple of Buffer.dtype's lanes.
+
+        Returns
+        -------
+        load : Expr
+            The corresponding load expression.
+        """
+        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
+        dtype = dtype if dtype else self.dtype
+        return _api_internal._BufferVLoad(self, begin, dtype)
+
+    def vstore(self, begin, value):
+        """Generate a Stmt that stores value into begin index.
+
+        Parameters
+        ----------
+        begin : Array of Expr
+            The beginning index, in units of Buffer.dtype.
+
+        value : Expr
+            The value to be stored.
+
+        Returns
+        -------
+        store : Stmt
+            The corresponding store stmt.
+        """
+        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
+        return _api_internal._BufferVStore(self, begin, value)

 @register_node
 class Split(NodeBase):
......
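A usage sketch of the new Buffer methods (tvm.decl_buffer and the scalar begin-index convenience are assumed from the code above):

    import tvm

    A = tvm.placeholder((128,), name='A', dtype="float32")
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='Ab')
    i = tvm.var('i')
    v = Ab.vload(i * 4, "float32x4")   # Load with a 4-lane Ramp index
    st = Ab.vstore(i * 4, v)           # matching vector Store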
...@@ -162,6 +162,18 @@ TVM_REGISTER_API("_BufferAccessPtr") ...@@ -162,6 +162,18 @@ TVM_REGISTER_API("_BufferAccessPtr")
.access_ptr(args[1], args[2]); .access_ptr(args[1], args[2]);
}); });
TVM_REGISTER_API("_BufferVLoad")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.vload(args[1], args[2]);
});
TVM_REGISTER_API("_BufferVStore")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.vstore(args[1], args[2]);
});
TVM_REGISTER_API("_Tensor") TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0], *ret = TensorNode::make(args[0],
......
@@ -61,27 +61,38 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   return base;
 }

-// Buffer access offset.
-inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
+inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
   Expr offset = ElemOffset(n, index);
   if (n->dtype.lanes() != 1) {
-    offset = offset * make_const(offset.type(), n->dtype.lanes());
+    offset = offset * make_const(offset.type(), dtype.lanes());
+  }
+  if (dtype.lanes() != 1) {
+    return ir::Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
+  } else {
+    return offset;
   }
-  return offset;
 }

-Expr Buffer::MakeLoad(Array<Expr> index) const {
+Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
   const BufferNode* n = operator->();
+  CHECK(dtype.element_of() == n->dtype.element_of() &&
+        dtype.lanes() % n->dtype.lanes() == 0)
+      << "Cannot load " << dtype
+      << " from buffer of " << n->dtype;
   return ir::Load::make(
-      n->dtype, n->data, BufferOffset(n, index),
-      const_true(n->dtype.lanes()));
+      dtype, n->data, BufferOffset(n, begin, dtype),
+      const_true(dtype.lanes()));
 }

-Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
+Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
   const BufferNode* n = operator->();
-  CHECK_EQ(value.type(), n->dtype);
+  Type dtype = value.type();
+  CHECK(dtype.element_of() == n->dtype.element_of() &&
+        dtype.lanes() % n->dtype.lanes() == 0)
+      << "Cannot store " << dtype
+      << " to buffer of " << n->dtype;
+  return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
+                         const_true(dtype.lanes()));
 }

 Buffer Buffer::MakeStrideView() const {
......
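In effect, the C++ above turns a vector access into a unit-stride Ramp over the element offset. A Python model of the index each vload produces (names are illustrative, not TVM API):

    def buffer_offset(elem_offset, buf_lanes, dtype_lanes):
        """Model of BufferOffset: returns the Ramp index as a tuple
        (base, stride, lanes), or a scalar offset for scalar dtype."""
        offset = elem_offset
        if buf_lanes != 1:
            offset = offset * dtype_lanes
        if dtype_lanes != 1:
            return ('Ramp', offset, 1, dtype_lanes)
        return ('Scalar', offset)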
@@ -75,7 +75,7 @@ class StorageFlattener : public IRMutator {
     const BufferEntry& e = it->second;
     CHECK(!e.released)
         << "Read a buffer that is already out of scope";
-    return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
+    return e.buffer.vstore(e.RelIndex(op->args), op->value);
   }

   Stmt Mutate_(const Realize* op, const Stmt& s) final {
@@ -165,7 +165,7 @@ class StorageFlattener : public IRMutator {
     const BufferEntry& e = it->second;
     CHECK(!e.released)
         << "Read a buffer that is already out of scope";
-    return e.buffer.MakeLoad(e.RelIndex(op->args));
+    return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
   } else {
     return expr;
   }
@@ -216,7 +216,7 @@ class StorageFlattener : public IRMutator {
       stmt = For::make(
           vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
     } else {
-      Expr load = e.buffer.MakeLoad(e.RelIndex(args));
+      Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
       Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
       Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
       stmt = Evaluate::make(prefetch);
......
@@ -594,8 +594,70 @@ class StoragePlanRewriter : public IRMutator {
   std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
 };

+// Turn alloc into vector alloc
+// if all its access is the same vector type.
+class VectorAllocRewriter : public IRMutator {
+ public:
+  Expr Mutate_(const Load* op, const Expr& e) final {
+    UpdateTypeMap(op->buffer_var.get(), op->type);
+    return IRMutator::Mutate_(op, e);
+  }
+  Stmt Mutate_(const Store* op, const Stmt& s) final {
+    UpdateTypeMap(op->buffer_var.get(), op->value.type());
+    return IRMutator::Mutate_(op, s);
+  }
+  Expr Mutate_(const Call* op, const Expr& e) final {
+    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+      Type dtype = op->args[0].type();
+      const Variable* buffer = op->args[1].as<Variable>();
+      UpdateTypeMap(buffer, dtype);
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Allocate>();
+    const auto& tvec = acc_map_[op->buffer_var.get()];
+    if (tvec.size() == 1 &&
+        tvec[0].element_of() == op->type.element_of() &&
+        tvec[0].lanes() % op->type.lanes() == 0 &&
+        tvec[0].lanes() != op->type.lanes()) {
+      int factor = tvec[0].lanes() / op->type.lanes();
+      Array<Expr> extents = op->extents;
+      arith::ModularEntry me = EvalModular(
+          extents[extents.size() - 1],
+          std::unordered_map<const Variable*, arith::ModularEntry>());
+      if (me.base % factor == 0 && me.coeff % factor == 0) {
+        extents.Set(extents.size() - 1,
+                    extents[extents.size() - 1] / make_const(extents[0].type(), factor));
+        return Allocate::make(
+            op->buffer_var, tvec[0], extents,
+            op->condition, op->body);
+      }
+    }
+    return stmt;
+  }
+
+ private:
+  void UpdateTypeMap(const Variable* buffer, Type t) {
+    auto& tvec = acc_map_[buffer];
+    if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
+      tvec.push_back(t);
+    }
+  }
+  // Internal access map
+  std::unordered_map<const Variable*,
+                     std::vector<Type> > acc_map_;
+};

 Stmt StorageRewrite(Stmt stmt) {
-  return StoragePlanRewriter().Rewrite(stmt);
+  stmt = StoragePlanRewriter().Rewrite(stmt);
+  return VectorAllocRewriter().Mutate(stmt);
 }
 }  // namespace ir
 }  // namespace tvm
} // namespace tvm } // namespace tvm
@@ -16,6 +16,11 @@ namespace ir {
 inline Expr BroadcastTo(Expr e, int lanes) {
   if (e.type().lanes() == lanes) return e;
+  if (const Broadcast* op = e.as<Broadcast>()) {
+    if (lanes % op->lanes == 0) {
+      return Broadcast::make(op->value, lanes);
+    }
+  }
   CHECK_EQ(e.type().lanes(), 1)
       << "Cannot broadcast lane=" << e.type().lanes()
       << " to " << lanes;
@@ -79,6 +84,27 @@ class Vectorizer : public IRMutator {
     return AddSubVec(op, e);
   }
   Expr Mutate_(const Mul* op, const Expr &e) final {
-    return BinaryVec(op, e);
+    Expr a = this->Mutate(op->a);
+    Expr b = this->Mutate(op->b);
+    if (a.same_as(op->a) &&
+        b.same_as(op->b)) {
+      return e;
+    } else {
+      int lanes = std::max(a.type().lanes(), b.type().lanes());
+      if (lanes != 1) {
+        const Ramp* b_ramp = b.as<Ramp>();
+        const Ramp* a_ramp = a.as<Ramp>();
+        if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) {
+          return Ramp::make(
+              a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
+        }
+        if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) {
+          return Ramp::make(
+              b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
+        }
+      }
+      return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+    }
   }
   Expr Mutate_(const Div* op, const Expr &e) final {
@@ -114,6 +140,27 @@ class Vectorizer : public IRMutator {
   Expr Mutate_(const Or* op, const Expr &e) final {
     return BinaryVec(op, e);
   }
+  Expr Mutate_(const Ramp* op, const Expr &e) final {
+    Expr base = this->Mutate(op->base);
+    Expr stride = this->Mutate(op->stride);
+    if (base.type().lanes() > 1 && stride.type().lanes() == 1) {
+      const Ramp* base_ramp = base.as<Ramp>();
+      if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
+        return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
+      }
+    }
+    int lanes = std::max(base.type().lanes(), stride.type().lanes());
+    base = BroadcastTo(base, lanes);
+    stride = BroadcastTo(stride, lanes);
+    Array<Expr> elems;
+    for (int i = 0; i < lanes; ++i) {
+      elems.push_back(
+          Ramp::make(Shuffle::make_extract_element(base, i),
+                     Shuffle::make_extract_element(stride, i),
+                     op->lanes));
+    }
+    return Shuffle::make_concat(elems);
+  }
   Expr Mutate_(const Select *op, const Expr& e) final {
     Expr cond = this->Mutate(op->condition);
     Expr t = this->Mutate(op->true_value);
......
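The new Mul rule keeps ramps affine instead of falling back to Broadcast: Ramp(base, stride, n) * c becomes Ramp(base*c, stride*c, n) when c is a provably positive scalar. A tiny model:

    def mul_ramp_scalar(base, stride, lanes, c):
        """Model of the Vectorizer's Ramp * scalar folding; c > 0 is
        assumed, mirroring the can_prove(b > 0) guard above."""
        return (base * c, stride * c, lanes)

    # Ramp(i*4, 1, 4) * 2 -> Ramp(i*8, 2, 4): still a strided vector index.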
@@ -9,10 +9,8 @@ def test_add_pipeline():
     def extern_generator(ins, outs):
         """Manually write the IR for the extern function, add pipeline"""
         ib = tvm.ir_builder.create()
-        dout = ib.buffer_ptr(outs[0])
-        din = ib.buffer_ptr(ins[0])
-        with ib.for_range(0, n) as i:
-            dout[i] = din[i] + 1
+        with ib.for_range(0, n/2) as i:
+            ib.emit(outs[0].vstore(i*2, ins[0].vload(i*2, "float32x2") + tvm.const(1, "float32x2")))
         return ib.get()
     C = tvm.extern(A.shape, [A], extern_generator, name='C')
......
@@ -88,6 +88,32 @@ def test_llvm_flip_pipeline():
     check_llvm(128, 1)

+def test_llvm_vadd_pipeline():
+    def check_llvm(n, lanes):
+        if not tvm.module.enabled("llvm"):
+            return
+        A = tvm.placeholder((n,), name='A', dtype="float32x%d" % lanes)
+        B = tvm.compute((n,), lambda i: A[i], name='B')
+        C = tvm.compute((n,), lambda i: B[i] + tvm.const(1, A.dtype), name='C')
+        s = tvm.create_schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=2)
+        s[C].parallel(xo)
+        s[C].vectorize(xi)
+        xo, xi = s[B].split(B.op.axis[0], factor=2)
+        s[B].vectorize(xi)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        a = tvm.nd.empty((n,), A.dtype).copyfrom(
+            np.random.uniform(size=(n, lanes)))
+        c = tvm.nd.empty((n,), C.dtype, ctx)
+        f(a, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy() + 1)
+    check_llvm(64, 2)

 def test_llvm_madd_pipeline():
     def check_llvm(nn, base, stride):
         if not tvm.module.enabled("llvm"):
@@ -114,6 +140,7 @@ def test_llvm_madd_pipeline():
     with tvm.build_config(restricted_func=False):
         check_llvm(4, 0, 3)

 def test_llvm_temp_space():
     nn = 1024
     n = tvm.convert(nn)
@@ -172,6 +199,7 @@ def test_multiple_func():

 if __name__ == "__main__":
+    test_llvm_vadd_pipeline()
     test_llvm_add_pipeline()
     test_llvm_intrin()
     test_multiple_func()
......
...@@ -31,6 +31,15 @@ def test_let(): ...@@ -31,6 +31,15 @@ def test_let():
stmt = tvm.make.LetStmt( stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1)); x, 10, tvm.make.Evaluate(x + 1));
def test_cast():
x = tvm.var('x', dtype="float32")
y = x.astype("int32")
z = x.astype("float32x4")
assert isinstance(y, tvm.expr.Cast)
assert isinstance(z, tvm.expr.Broadcast)
assert z.lanes == 4
def test_attr(): def test_attr():
x = tvm.var('x') x = tvm.var('x')
y = tvm.var('y') y = tvm.var('y')
...@@ -116,6 +125,7 @@ def test_all(): ...@@ -116,6 +125,7 @@ def test_all():
if __name__ == "__main__": if __name__ == "__main__":
test_cast()
test_attr() test_attr()
test_const() test_const()
test_make() test_make()
......
...@@ -16,7 +16,6 @@ def test_tensor(): ...@@ -16,7 +16,6 @@ def test_tensor():
assert(T.op.output(0).__hash__() == T.__hash__()) assert(T.op.output(0).__hash__() == T.__hash__())
d = {T.op.output(0) : 1} d = {T.op.output(0) : 1}
assert(d[T] == 1) assert(d[T] == 1)
assert(tvm.cast('float16', T[0][0][0]).dtype == 'float16')
assert(T[0][0][0].astype('float16').dtype == 'float16') assert(T[0][0][0].astype('float16').dtype == 'float16')
......
@@ -7,8 +7,9 @@ def test_vectorize_loop():
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, n) as i:
         with ib.for_range(0, 4, for_type="vectorize") as j:
-            A[j + 1] = A[i] + 1
+            A[j] = tvm.const(1, A.dtype)
     stmt = ib.get()

     assert isinstance(stmt.body, tvm.stmt.For)
     stmt = tvm.ir_pass.VectorizeLoop(stmt)
     assert isinstance(stmt, tvm.stmt.For)
@@ -16,6 +17,23 @@ def test_vectorize_loop():
     assert isinstance(stmt.body.index, tvm.expr.Ramp)
     assert isinstance(stmt.body.value, tvm.expr.Broadcast)

+def test_vectorize_vector():
+    dtype = 'int64'
+    n = tvm.var('n')
+    ib = tvm.ir_builder.create()
+    A = ib.pointer("float32x4", name="A")
+    with ib.for_range(0, n) as i:
+        with ib.for_range(0, 4, for_type="vectorize") as j:
+            A[j] = tvm.const(1, A.dtype)
+    stmt = ib.get()
+    assert isinstance(stmt.body, tvm.stmt.For)
+    stmt = tvm.ir_pass.VectorizeLoop(stmt)
+    assert isinstance(stmt, tvm.stmt.For)
+    assert not isinstance(stmt.body, tvm.stmt.For)
+    assert isinstance(stmt.body.index, tvm.expr.Ramp)
+    assert isinstance(stmt.body.value, tvm.expr.Broadcast)

 def test_vectorize_with_if():
     n = tvm.var('n')
     x = tvm.var('x')
@@ -36,5 +54,6 @@ def test_vectorize_with_if():
     assert isinstance(stmt.else_case, tvm.stmt.For)

 if __name__ == "__main__":
+    test_vectorize_vector()
     test_vectorize_with_if()
     test_vectorize_loop()