Commit 1146495f by Tianqi Chen, committed by GitHub

[RUNTIME][PASS] Allow declare vector type array (#302)

* [RUNTIME][PASS] Allow declare vector type array

* fix bcast

* [BUFFER] Enable vload/store function in buffer

* ok
parent 1e48b02f
@@ -33,19 +33,6 @@ class Buffer : public NodeRef {
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief Generate a load expression loading the index location of buffer.
* \param index The index to the buffer.
* \return The load expression.
*/
Expr MakeLoad(Array<Expr> index) const;
/*!
* \brief Generate a store statement.
* \param index The index to the buffer.
* \param value The value to be stored.
* \return The load expression.
*/
Stmt MakeStore(Array<Expr> index, Expr value) const;
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
* \return The strided version of the buffer.
@@ -67,6 +54,18 @@ class Buffer : public NodeRef {
*/
Expr access_ptr(int access_mask, Type ptr_type = Handle()) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index.
* \param dtype The data type to be loaded.
* \return The load expression.
*/
Expr vload(Array<Expr> begin, Type dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index.
* \param value The value to be stored.
* \return The store statement.
*/
Stmt vstore(Array<Expr> begin, Expr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
@@ -142,17 +142,22 @@ class NDArrayBase(_NDArrayBase):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)):
self._sync_copyfrom(value)
self.copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))
def _sync_copyfrom(self, source_array):
def copyfrom(self, source_array):
"""Peform an synchronize copy from the array.
Parameters
----------
source_array : array_like
The data source we would like to copy from.
Returns
-------
arr : NDArray
Reference to self.
"""
if not isinstance(source_array, np.ndarray):
try:
@@ -160,13 +165,21 @@ class NDArrayBase(_NDArrayBase):
except:
raise TypeError('array must be an array_like data, ' +
'type %s is not supported' % str(type(source_array)))
source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
if source_array.shape != self.shape:
raise ValueError('array shape do not match the shape of NDArray')
t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
source_array = np.ascontiguousarray(source_array, dtype=dtype)
if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape))
assert source_array.flags['C_CONTIGUOUS']
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize)
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self
def asnumpy(self):
"""Convert this array to numpy array
@@ -176,7 +189,13 @@ class NDArrayBase(_NDArrayBase):
np_arr : numpy.ndarray
The corresponding numpy array.
"""
np_arr = np.empty(self.shape, dtype=self.dtype)
t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS']
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize)
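With these two changes, an NDArray whose dtype carries lanes > 1 exchanges data with numpy through an extra trailing axis of size lanes. A minimal sketch of the round trip (assuming a CPU build; shapes are illustrative):

import numpy as np
import tvm

# A float32x2 array of length 8 maps to numpy shape (8, 2):
# the lanes are peeled off into a trailing axis on both paths.
a = tvm.nd.empty((8,), "float32x2")
a.copyfrom(np.random.uniform(size=(8, 2)))
assert a.asnumpy().shape == (8, 2)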
@@ -188,7 +207,7 @@ class NDArrayBase(_NDArrayBase):
Parameters
----------
target : tvm.NDArray
target : NDArray
The target array to be copied to; must have the same shape as this array.
"""
if isinstance(target, TVMContext):
@@ -41,30 +41,36 @@ class TVMType(ctypes.Structure):
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
def __init__(self, type_str):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
arr = type_str.split("x")
head = arr[0]
self.lanes = int(arr[1]) if len(arr) > 1 else 1
bits = 32
if head.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
head = head[3:]
elif head.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
head = head[4:]
elif head.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
head = head[5:]
elif head.startswith("handle"):
self.type_code = 4
bits = 64
head = ""
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
bits = int(head) if head else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
@@ -10,6 +10,7 @@ from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
from . import expr as _expr
@@ -546,22 +547,6 @@ def reduce_axis(dom, name="rv"):
"""
return _IterVar(dom, name, 2)
def cast(dtype, expr):
"""Cast an expression to other type
Parameters
----------
dtype : str, optional
The type of new expression
expr : Expr
The expression
Returns
-------
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, expr)
def select(cond, t, f):
"""Construct a select branch
@@ -97,10 +97,11 @@ class ExprOp(object):
return _make.EQ(self, other)
def astype(self, dtype):
"""Cast the expression to other type
"""Cast the expression to other type.
Parameters
----------
dtype : str, optional
dtype : str
The type of new expression
Returns
@@ -108,7 +109,7 @@ class ExprOp(object):
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, self)
return _make.static_cast(dtype, self)
class Expr(NodeBase, ExprOp):
@@ -9,6 +9,7 @@ from . import ir_pass as _pass
from . import container as _container
from ._ffi.base import string_types
from ._ffi.node import NodeGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call
class WithScope(object):
@@ -56,7 +57,14 @@ class BufferVar(NodeGeneric):
def asnode(self):
return self._buffer_var
@property
def dtype(self):
return self._content_type
def __getitem__(self, index):
t = TVMType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
return _make.Load(self._content_type, self._buffer_var, index)
def __setitem__(self, index, value):
@@ -65,6 +73,9 @@ class BufferVar(NodeGeneric):
raise ValueError(
"data type does not match content type %s vs %s" % (
value.dtype, self._content_type))
t = TVMType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_make.Store(self._buffer_var, value, index))
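So a BufferVar whose content type is a vector now addresses whole vectors: index i is scaled by the lane count and expanded into a Ramp before the Load/Store is built. A sketch with ir_builder (illustrative, mirroring the tests later in this commit):

import tvm

ib = tvm.ir_builder.create()
A = ib.pointer("float32x4", name="A")
with ib.for_range(0, 4) as i:
    # __setitem__ rewrites index i into Ramp(i * 4, 1, 4),
    # so each store covers scalar lanes 4*i .. 4*i+3.
    A[i] = tvm.const(1, "float32x4")
stmt = ib.get()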
@@ -7,6 +7,7 @@ Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
"""
from ._ffi.function import _init_api
from ._ffi.runtime_ctypes import TVMType
from . import stmt as _stmt
def range_by_min_extent(min_value, extent):
@@ -30,6 +31,34 @@ def range_by_min_extent(min_value, extent):
return _range_by_min_extent(min_value, extent)
def static_cast(dtype, expr):
"""Cast expr to dtype.
If expr is scalar and dtype is a vector type with the same
element type, a Broadcast is generated; otherwise a Cast.
Parameters
----------
dtype : str
The target data type.
expr : Expr
The expression to be cast.
Returns
-------
casted : Expr
The cast expression.
"""
target_type = TVMType(dtype)
src_type = TVMType(expr.dtype)
if target_type.type_code == src_type.type_code\
and src_type.lanes == 1\
and target_type.lanes > 1:
return Broadcast(expr, target_type.lanes)
return Cast(dtype, expr)
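This is the helper that expr.astype now routes through (see the ExprOp hunk above). A quick illustration:

import tvm

x = tvm.var("x", dtype="float32")
# Same element type, scalar -> vector: becomes a Broadcast.
assert isinstance(tvm.make.static_cast("float32x4", x), tvm.expr.Broadcast)
# Ordinary type conversion: stays a Cast.
assert isinstance(tvm.make.static_cast("int32", x), tvm.expr.Cast)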
def node(type_key, **kwargs):
"""Make a new DSL node by its type key and fields
@@ -126,8 +126,6 @@ def array(arr, ctx=cpu(0)):
"""
if not isinstance(arr, _np.ndarray):
arr = _np.array(arr)
ret = empty(arr.shape, arr.dtype, ctx)
ret[:] = arr
return ret
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
_set_class_ndarray(NDArray)
@@ -65,6 +65,46 @@ class Buffer(NodeBase):
access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type)
def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index.
Parameters
----------
begin : Array of Expr
The beginning index, in units of Buffer.dtype.
dtype : str
The data type to be loaded; it can be a vector type whose
lane count is a multiple of Buffer.dtype's lanes.
Returns
-------
load : Expr
The corresponding load expression.
"""
begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
dtype = dtype if dtype else self.dtype
return _api_internal._BufferVLoad(self, begin, dtype)
def vstore(self, begin, value):
"""Generate a Stmt that store value into begin index.
Parameters
----------
begin : Array of Expr
The beginning index, in units of Buffer.dtype.
value : Expr
The value to be stored.
Returns
-------
store : Stmt
The corresponding store stmt.
"""
begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
return _api_internal._BufferVStore(self, begin, value)
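Usage sketch for the two wrappers, assuming a scalar-typed buffer made with tvm.decl_buffer (names are illustrative):

import tvm

buf = tvm.decl_buffer((32,), "float32", name="buf")
v = buf.vload(8, "float32x8")   # one Load expression covering buf[8:16]
st = buf.vstore(8, v)           # and the matching Store statement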
@register_node
class Split(NodeBase):
@@ -162,6 +162,18 @@ TVM_REGISTER_API("_BufferAccessPtr")
.access_ptr(args[1], args[2]);
});
TVM_REGISTER_API("_BufferVLoad")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.vload(args[1], args[2]);
});
TVM_REGISTER_API("_BufferVStore")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.vstore(args[1], args[2]);
});
TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
@@ -61,27 +61,38 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
return base;
}
// Buffer access offset.
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
Expr offset = ElemOffset(n, index);
if (n->dtype.lanes() != 1) {
offset = offset * make_const(offset.type(), n->dtype.lanes());
offset = offset * make_const(offset.type(), dtype.lanes());
}
if (dtype.lanes() != 1) {
return ir::Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
} else {
return offset;
}
}
Expr Buffer::MakeLoad(Array<Expr> index) const {
Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
const BufferNode* n = operator->();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
return ir::Load::make(
n->dtype, n->data, BufferOffset(n, index),
const_true(n->dtype.lanes()));
dtype, n->data, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
const BufferNode* n = operator->();
CHECK_EQ(value.type(), n->dtype);
return ir::Store::make(n->data, value, BufferOffset(n, index),
const_true(n->dtype.lanes()));
Type dtype = value.type();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
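The index math in BufferOffset: ElemOffset counts in elements of the buffer's own dtype; when that dtype is itself a vector, the offset is rescaled into scalar lanes and then spread into a Ramp. A plain-Python model of the computation (hypothetical helper, not TVM API):

def buffer_offset(elem_offset, buffer_lanes, load_lanes):
    # Rescale into scalar lanes, mirroring BufferOffset above.
    offset = elem_offset * (load_lanes if buffer_lanes != 1 else 1)
    if load_lanes != 1:
        # Ramp(offset, 1, load_lanes), enumerated as a list.
        return [offset + k for k in range(load_lanes)]
    return offset

# A float32x2 buffer accessed as float32x2 at element 3 touches
# scalar lanes 6 and 7.
assert buffer_offset(3, 2, 2) == [6, 7]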
Buffer Buffer::MakeStrideView() const {
@@ -75,7 +75,7 @@ class StorageFlattener : public IRMutator {
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
return e.buffer.vstore(e.RelIndex(op->args), op->value);
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
@@ -165,7 +165,7 @@ class StorageFlattener : public IRMutator {
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
return e.buffer.MakeLoad(e.RelIndex(op->args));
return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
} else {
return expr;
}
@@ -216,7 +216,7 @@ class StorageFlattener : public IRMutator {
stmt = For::make(
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
} else {
Expr load = e.buffer.MakeLoad(e.RelIndex(args));
Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
stmt = Evaluate::make(prefetch);
@@ -594,8 +594,70 @@ class StoragePlanRewriter : public IRMutator {
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
};
// Turn alloc into vector alloc
// if all its access is the same vector type.
class VectorAllocRewriter : public IRMutator {
public:
Expr Mutate_(const Load* op, const Expr& e) final {
UpdateTypeMap(op->buffer_var.get(), op->type);
return IRMutator::Mutate_(op, e);
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
UpdateTypeMap(op->buffer_var.get(), op->value.type());
return IRMutator::Mutate_(op, s);
}
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
Type dtype = op->args[0].type();
const Variable* buffer = op->args[1].as<Variable>();
UpdateTypeMap(buffer, dtype);
}
return IRMutator::Mutate_(op, e);
}
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
const auto& tvec = acc_map_[op->buffer_var.get()];
if (tvec.size() == 1 &&
tvec[0].element_of() == op->type.element_of() &&
tvec[0].lanes() % op->type.lanes() == 0 &&
tvec[0].lanes() != op->type.lanes()) {
int factor = tvec[0].lanes() / op->type.lanes();
Array<Expr> extents = op->extents;
arith::ModularEntry me = EvalModular(
extents[extents.size() - 1],
std::unordered_map<const Variable*, arith::ModularEntry>());
if (me.base % factor == 0 && me.coeff % factor == 0) {
extents.Set(extents.size() - 1,
extents[extents.size() - 1] / make_const(extents[0].type(), factor));
return Allocate::make(
op->buffer_var, tvec[0], extents,
op->condition, op->body);
}
}
return stmt;
}
private:
void UpdateTypeMap(const Variable* buffer, Type t) {
auto& tvec = acc_map_[buffer];
if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
tvec.push_back(t);
}
}
// Internal access map
std::unordered_map<const Variable*,
std::vector<Type> > acc_map_;
};
Stmt StorageRewrite(Stmt stmt) {
return StoragePlanRewriter().Rewrite(stmt);
stmt = StoragePlanRewriter().Rewrite(stmt);
return VectorAllocRewriter().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
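In short: an allocation whose accesses all use one wider vector type is re-declared with that type, and its innermost extent is divided by the lane ratio, guarded by a modular-arithmetic proof of divisibility. A plain-Python model of the guard (hypothetical, not TVM API):

def widen_alloc(alloc_lanes, access_lanes, last_extent):
    # Mirrors the guard in VectorAllocRewriter::Mutate_(Allocate).
    if access_lanes % alloc_lanes == 0 and access_lanes != alloc_lanes:
        factor = access_lanes // alloc_lanes
        if last_extent % factor == 0:   # stands in for EvalModular
            return access_lanes, last_extent // factor
    return alloc_lanes, last_extent

# A float32 scratch of extent 16 that is only read/written as
# float32x4 becomes a float32x4 alloc of extent 4.
assert widen_alloc(1, 4, 16) == (4, 4)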
@@ -16,6 +16,11 @@ namespace ir {
inline Expr BroadcastTo(Expr e, int lanes) {
if (e.type().lanes() == lanes) return e;
if (const Broadcast* op = e.as<Broadcast>()) {
if (lanes % op->lanes == 0) {
return Broadcast::make(op->value, lanes);
}
}
CHECK_EQ(e.type().lanes(), 1)
<< "Cannot broadcast lane=" << e.type().lanes()
<< " to " << lanes;
@@ -79,6 +84,27 @@ class Vectorizer : public IRMutator {
return AddSubVec(op, e);
}
Expr Mutate_(const Mul* op, const Expr &e) final {
Expr a = this->Mutate(op->a);
Expr b = this->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
} else {
int lanes = std::max(a.type().lanes(), b.type().lanes());
if (lanes != 1) {
const Ramp* b_ramp = b.as<Ramp>();
const Ramp* a_ramp = a.as<Ramp>();
if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) {
return Ramp::make(
a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
}
if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) {
return Ramp::make(
b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
}
}
return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
return BinaryVec(op, e);
}
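The ramp special case keeps strided indices as ramps: multiplying a ramp by a provably positive scalar scales base and stride instead of falling back to a lane-wise broadcast multiply. A sketch in the style of the vectorize tests below (the assertion is the expected outcome, not taken from this commit):

import tvm

ib = tvm.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, 4, for_type="vectorize") as j:
    A[j * 2] = tvm.const(1, "float32")
stmt = tvm.ir_pass.VectorizeLoop(ib.get())
# j becomes Ramp(0, 1, 4); the Mul rule folds (j * 2) into
# Ramp(0, 2, 4) rather than Broadcast(j) * Broadcast(2).
assert isinstance(stmt.index, tvm.expr.Ramp)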
Expr Mutate_(const Div* op, const Expr &e) final {
@@ -114,6 +140,27 @@ class Vectorizer : public IRMutator {
Expr Mutate_(const Or* op, const Expr &e) final {
return BinaryVec(op, e);
}
Expr Mutate_(const Ramp* op, const Expr &e) final {
Expr base = this->Mutate(op->base);
Expr stride = this->Mutate(op->stride);
if (base.type().lanes() > 1 && stride.type().lanes() == 1) {
const Ramp* base_ramp = base.as<Ramp>();
if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
}
}
int lanes = std::max(base.type().lanes(), stride.type().lanes());
base = BroadcastTo(base, lanes);
stride = BroadcastTo(stride, lanes);
Array<Expr> elems;
for (int i = 0; i < lanes; ++i) {
elems.push_back(
Ramp::make(Shuffle::make_extract_element(base, i),
Shuffle::make_extract_element(stride, i),
op->lanes));
}
return Shuffle::make_concat(elems);
}
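The fused case fires when the mutated base is itself a ramp whose stride equals stride * lanes: the two nest into one contiguous ramp. Otherwise each lane gets its own ramp, stitched together with Shuffle::make_concat. The fused identity in plain index arithmetic (hypothetical model):

def nested_ramp(b, s, m, n):
    # Ramp(base=Ramp(b, s*n, m), stride=s, lanes=n), enumerated.
    bases = [b + i * (s * n) for i in range(m)]
    return [base + k * s for base in bases for k in range(n)]

# Fuses into the single ramp Ramp(b, s, m*n):
assert nested_ramp(0, 1, 2, 4) == list(range(8))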
Expr Mutate_(const Select *op, const Expr& e) final {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
@@ -9,10 +9,8 @@ def test_add_pipeline():
def extern_generator(ins, outs):
"""Manually write the IR for the extern function, add pipeline"""
ib = tvm.ir_builder.create()
dout = ib.buffer_ptr(outs[0])
din = ib.buffer_ptr(ins[0])
with ib.for_range(0, n) as i:
dout[i] = din[i] + 1
with ib.for_range(0, n/2) as i:
ib.emit(outs[0].vstore(i*2, ins[0].vload(i*2, "float32x2") + tvm.const(1, "float32x2")))
return ib.get()
C = tvm.extern(A.shape, [A], extern_generator, name='C')
@@ -88,6 +88,32 @@ def test_llvm_flip_pipeline():
check_llvm(128, 1)
def test_llvm_vadd_pipeline():
def check_llvm(n, lanes):
if not tvm.module.enabled("llvm"):
return
A = tvm.placeholder((n,), name='A', dtype="float32x%d" % lanes)
B = tvm.compute((n,), lambda i: A[i], name='B')
C = tvm.compute((n,), lambda i: B[i] + tvm.const(1, A.dtype), name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=2)
s[C].parallel(xo)
s[C].vectorize(xi)
xo, xi = s[B].split(B.op.axis[0], factor=2)
s[B].vectorize(xi)
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.empty((n,), A.dtype).copyfrom(
np.random.uniform(size=(n, lanes)))
c = tvm.nd.empty((n,), C.dtype, ctx)
f(a, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + 1)
check_llvm(64, 2)
def test_llvm_madd_pipeline():
def check_llvm(nn, base, stride):
if not tvm.module.enabled("llvm"):
@@ -114,6 +140,7 @@ def test_llvm_madd_pipeline():
with tvm.build_config(restricted_func=False):
check_llvm(4, 0, 3)
def test_llvm_temp_space():
nn = 1024
n = tvm.convert(nn)
@@ -172,6 +199,7 @@ def test_multiple_func():
if __name__ == "__main__":
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
test_multiple_func()
@@ -31,6 +31,15 @@ def test_let():
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1))
def test_cast():
x = tvm.var('x', dtype="float32")
y = x.astype("int32")
z = x.astype("float32x4")
assert isinstance(y, tvm.expr.Cast)
assert isinstance(z, tvm.expr.Broadcast)
assert z.lanes == 4
def test_attr():
x = tvm.var('x')
y = tvm.var('y')
@@ -116,6 +125,7 @@ def test_all():
if __name__ == "__main__":
test_cast()
test_attr()
test_const()
test_make()
@@ -16,7 +16,6 @@ def test_tensor():
assert(T.op.output(0).__hash__() == T.__hash__())
d = {T.op.output(0) : 1}
assert(d[T] == 1)
assert(tvm.cast('float16', T[0][0][0]).dtype == 'float16')
assert(T[0][0][0].astype('float16').dtype == 'float16')
@@ -7,8 +7,9 @@ def test_vectorize_loop():
A = ib.pointer("float32", name="A")
with ib.for_range(0, n) as i:
with ib.for_range(0, 4, for_type="vectorize") as j:
A[j + 1] = A[i] + 1
A[j] = tvm.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.stmt.For)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
assert isinstance(stmt, tvm.stmt.For)
@@ -16,6 +17,23 @@ def test_vectorize_loop():
assert isinstance(stmt.body.index, tvm.expr.Ramp)
assert isinstance(stmt.body.value, tvm.expr.Broadcast)
def test_vectorize_vector():
dtype = 'int64'
n = tvm.var('n')
ib = tvm.ir_builder.create()
A = ib.pointer("float32x4", name="A")
with ib.for_range(0, n) as i:
with ib.for_range(0, 4, for_type="vectorize") as j:
A[j] = tvm.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.stmt.For)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
assert isinstance(stmt, tvm.stmt.For)
assert not isinstance(stmt.body, tvm.stmt.For)
assert isinstance(stmt.body.index, tvm.expr.Ramp)
assert isinstance(stmt.body.value, tvm.expr.Broadcast)
def test_vectorize_with_if():
n = tvm.var('n')
x = tvm.var('x')
@@ -36,5 +54,6 @@ def test_vectorize_with_if():
assert isinstance(stmt.else_case, tvm.stmt.For)
if __name__ == "__main__":
test_vectorize_vector()
test_vectorize_with_if()
test_vectorize_loop()