Commit 19194e97 by Leyuan Wang Committed by Yizhi Liu

[Relay/TOPI][Frontend] Add tile and repeat operators in Relay and TOPI (#2720)

* tile and repeat operator added in rely

* fix pylint

* fix make warnings

* comments addressed

* fix lint error

* comment addressed
parent 801068f3
...@@ -73,6 +73,8 @@ List of operators ...@@ -73,6 +73,8 @@ List of operators
topi.logical_not topi.logical_not
topi.arange topi.arange
topi.stack topi.stack
topi.repeat
topi.tile
topi.layout_transform topi.layout_transform
topi.image.resize topi.image.resize
...@@ -132,6 +134,8 @@ topi ...@@ -132,6 +134,8 @@ topi
.. autofunction:: topi.less .. autofunction:: topi.less
.. autofunction:: topi.arange .. autofunction:: topi.arange
.. autofunction:: topi.stack .. autofunction:: topi.stack
.. autofunction:: topi.repeat
.. autofunction:: topi.tile
.. autofunction:: topi.layout_transform .. autofunction:: topi.layout_transform
topi.nn topi.nn
......
...@@ -97,6 +97,8 @@ This level enables additional math and transform operators. ...@@ -97,6 +97,8 @@ This level enables additional math and transform operators.
tvm.relay.split tvm.relay.split
tvm.relay.arange tvm.relay.arange
tvm.relay.stack tvm.relay.stack
tvm.relay.repeat
tvm.relay.tile
**Level 4: Broadcast and Reductions** **Level 4: Broadcast and Reductions**
...@@ -225,6 +227,8 @@ Level 3 Definitions ...@@ -225,6 +227,8 @@ Level 3 Definitions
.. autofunction:: tvm.relay.split .. autofunction:: tvm.relay.split
.. autofunction:: tvm.relay.arange .. autofunction:: tvm.relay.arange
.. autofunction:: tvm.relay.stack .. autofunction:: tvm.relay.stack
.. autofunction:: tvm.relay.repeat
.. autofunction:: tvm.relay.tile
Level 4 Definitions Level 4 Definitions
......
...@@ -124,6 +124,28 @@ struct StackAttrs : public tvm::AttrsNode<StackAttrs> { ...@@ -124,6 +124,28 @@ struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
} }
}; // struct StackAttrs }; // struct StackAttrs
/*! \brief Attributes used in repeat operators */
struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
Integer repeats;
Integer axis;
TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
TVM_ATTR_FIELD(repeats)
.describe("The number of repetitions for each element.");
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe(" The axis along which to repeat values.");
}
}; // struct RepeatAttrs
/*! \brief Attributes used in tile operators */
struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
Array<Integer> reps;
TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
TVM_ATTR_FIELD(reps)
.describe("The number of times for repeating the tensor a."
"Each dim sizeof reps must be a positive integer.");
}
}; // struct TileAttrs
/*! \brief Attributes used in squeeze operators */ /*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> { struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible. // use axis to make the name numpy compatible.
......
...@@ -166,6 +166,10 @@ def _mx_dropout(inputs, attrs): ...@@ -166,6 +166,10 @@ def _mx_dropout(inputs, attrs):
return _op.nn.dropout(inputs[0], rate=rate) return _op.nn.dropout(inputs[0], rate=rate)
def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument
return inputs
def _mx_batch_norm(inputs, attrs): def _mx_batch_norm(inputs, attrs):
if attrs.get_bool("output_mean_var", False): if attrs.get_bool("output_mean_var", False):
raise RuntimeError("batch_norm do not support output_mean_var") raise RuntimeError("batch_norm do not support output_mean_var")
...@@ -357,6 +361,21 @@ def _mx_arange(inputs, attrs): ...@@ -357,6 +361,21 @@ def _mx_arange(inputs, attrs):
return _op.arange(**new_attrs) return _op.arange(**new_attrs)
def _mx_repeat(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["repeats"] = attrs.get_int("repeats")
new_attrs["axis"] = attrs.get_int("axis", 0)
return _op.repeat(inputs[0], **new_attrs)
def _mx_tile(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["reps"] = attrs.get_int_tuple("reps")
return _op.tile(inputs[0], **new_attrs)
def _mx_roi_align(inputs, attrs): def _mx_roi_align(inputs, attrs):
new_attrs = {} new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size") new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
...@@ -490,6 +509,9 @@ _convert_map = { ...@@ -490,6 +509,9 @@ _convert_map = {
"batch_dot" : _mx_batch_dot, "batch_dot" : _mx_batch_dot,
"LeakyReLU" : _mx_leaky_relu, "LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange, "_arange" : _mx_arange,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"BlockGrad" : _mx_BlockGrad,
"SoftmaxOutput" : _mx_softmax_output, "SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation, "SoftmaxActivation" : _mx_softmax_activation,
# vision # vision
......
...@@ -19,6 +19,8 @@ _reg.register_schedule("reshape_like", schedule_injective) ...@@ -19,6 +19,8 @@ _reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective) _reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective) _reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective) _reg.register_schedule("arange", schedule_injective)
_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective) _reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective) _reg.register_schedule("slice_like", schedule_injective)
......
...@@ -316,6 +316,75 @@ def stack(data, axis): ...@@ -316,6 +316,75 @@ def stack(data, axis):
return _make.stack(data, axis) return _make.stack(data, axis)
def repeat(data, repeats, axis):
"""Repeats elements of an array.
By default, repeat flattens the input array into 1-D and then repeats the elements.
repeats : int
The number of repetitions for each element.
axis: int
The axis along which to repeat values. The negative numbers are interpreted
counting from the backward. By default, use the flattened input array, and
return a flat output array.
Returns
-------
ret : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
x = [[1, 2], [3, 4]]
relay.repeat(x, repeats=2) = [1., 1., 2., 2., 3., 3., 4., 4.]
relay.repeat(x, repeats=2, axis=1) = [[1., 1., 2., 2.],
[3., 3., 4., 4.]]
"""
return _make.repeat(data, repeats, axis)
def tile(data, reps):
"""Repeats the whole array multiple times.
Parameters
----------
data : relay.Expr
The input data to the operator.
reps : tuple of int
The number of times repeating the tensor data.
.. note::
Each dim size of reps must be a positive integer. If reps has length d,
the result will have dimension of max(d, data.ndim); If data.ndim < d,
data is promoted to be d-dimensional by prepending new axes.
If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it.
Returns
-------
ret : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
x = [[1, 2], [3, 4]]
relay.tile(x, reps=(2,3)) = [[1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4.],
[1., 2., 1., 2., 1., 2.],
[3., 4., 3., 4., 3., 4.]]
relay.tile(x, reps=(2,)) = [[1., 2., 1., 2.],
[3., 4., 3., 4.]]
"""
return _make.tile(data, reps)
def where(condition, x, y): def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the """Selecting elements from either x or y depending on the value of the
condition. condition.
......
...@@ -1035,6 +1035,175 @@ RELAY_REGISTER_OP("arange") ...@@ -1035,6 +1035,175 @@ RELAY_REGISTER_OP("arange")
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute) .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
// repeat operator
TVM_REGISTER_NODE_TYPE(RepeatAttrs);
bool RepeatRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "repeat: expect input type to be TensorType but get "
<< types[0];
return false;
}
const auto* param = attrs.as<RepeatAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const int repeats = param->repeats;
const int axis = param->axis;
CHECK(repeats >= 1)
<< "repeat only accepts `repeats >= 1`"
<< ", but got repeats = " << repeats;
CHECK(-ndim - 1 <= axis && axis <= ndim)
<< "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
<< ", but got axis = " << axis
<< ", and data.ndim = " << ndim;
const int pivot = axis < 0 ? ndim + axis : axis;
std::vector<IndexExpr> oshape;
oshape.reserve(ndim + repeats);
for (int i = 0; i < pivot; ++i) {
oshape.emplace_back(data->shape[i]);
}
oshape.emplace_back(data->shape[pivot] * repeats);
for (int i = pivot + 1; i < ndim; ++i) {
oshape.emplace_back(data->shape[i]);
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Array<Tensor> RepeatCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const RepeatAttrs *param = attrs.as<RepeatAttrs>();
CHECK(param != nullptr);
return { topi::repeat(inputs[0], param->repeats, param->axis) };
}
Expr MakeRepeat(Expr data,
int repeats,
int axis) {
auto attrs = make_node<RepeatAttrs>();
attrs->repeats = repeats;
attrs->axis = axis;
static const Op& op = Op::Get("repeat");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.repeat")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
});
RELAY_REGISTER_OP("repeat")
.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Repeat")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Repeat", RepeatRel)
.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// tile operator
TVM_REGISTER_NODE_TYPE(TileAttrs);
bool TileRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "tile: expect input type to be TensorType but get "
<< types[0];
return false;
}
const auto* param = attrs.as<TileAttrs>();
const size_t ndim = data->shape.size();
const Array<Integer>& reps = param->reps;
// check dimension match
CHECK(!reps.defined())
<< "repetition array is not defined. data.ndim = " << ndim;
const size_t rndim = reps.size();
size_t tndim = (ndim > rndim) ? ndim : rndim;
// re-construct data shape or reps shape
std::vector<IndexExpr> data_shape;
std::vector<IndexExpr> reps_shape;
data_shape.reserve(tndim);
reps_shape.reserve(tndim);
if (ndim == rndim) {
for (size_t i = 0; i < tndim; ++i) {
data_shape.emplace_back(data->shape[i]);
reps_shape.emplace_back(reps[i]);
}
} else if (ndim > rndim) {
for (size_t i = 0; i < ndim; ++i)
data_shape.emplace_back(data->shape[i]);
for (size_t i = 0; i < (ndim - rndim); ++i)
reps_shape.emplace_back(1);
for (size_t i = 0; i < rndim; ++i)
reps_shape.emplace_back(reps[i]);
} else {
for (size_t i = 0; i < rndim; ++i)
reps_shape.emplace_back(reps[i]);
}
std::vector<IndexExpr> oshape;
oshape.reserve(tndim);
for (size_t i = 0; i < tndim; ++i) {
oshape.emplace_back(data_shape[i] * reps_shape[i]);
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Array<Tensor> TileCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const TileAttrs *param = attrs.as<TileAttrs>();
CHECK(param != nullptr);
return { topi::tile(inputs[0], param->reps) };
}
Expr MakeTile(Expr data,
Array<Integer> reps) {
auto attrs = make_node<TileAttrs>();
attrs->reps = reps;
static const Op& op = Op::Get("tile");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.tile")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
});
RELAY_REGISTER_OP("tile")
.describe(R"code(Repeat the whole array multiple times.
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Tile")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Tile", TileRel)
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// where operator // where operator
bool WhereRel(const Array<Type>& types, bool WhereRel(const Array<Type>& types,
int num_inputs, int num_inputs,
......
...@@ -720,6 +720,115 @@ inline Tensor where(const Tensor& condition, ...@@ -720,6 +720,115 @@ inline Tensor where(const Tensor& condition,
} }
/*! /*!
* \brief Creates an operation to repeat elements of an array
*
* \param x The input tensor
* \param repeats The number of repetitions for each element
* \param axis The axis along which to repeat values (allows
* negative indices as offsets from the last dimension)
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the repeat operation
*/
inline Tensor repeat(const Tensor& x,
int repeats,
int axis,
std::string name = "tensor",
std::string tag = kBroadcast) {
int ndim = static_cast<int>(x->shape.size());
CHECK(-ndim - 1 <= axis && axis <= ndim)
<< "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
<< ", but got axis = " << axis
<< ", and data.ndim = " << ndim;
CHECK(repeats >= 1)
<< "repeat only accepts `repeats >= 1`"
<< ", but got repeats = " << repeats;
if (axis < 0) {
// Calculate offset from last dimension
axis += ndim;
}
Array<Expr> new_shape;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
new_shape.push_back(x->shape[i]);
}
new_shape.push_back(repeats * x->shape[axis]);
for (size_t i = axis + 1; i < x->shape.size(); ++i) {
new_shape.push_back(x->shape[i]);
}
return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
idx.push_back(indices[i]);
}
idx.push_back(indices[axis] / repeats);
for (size_t i = axis + 1; i < indices.size(); ++i) {
idx.push_back(indices[i]);
}
return x(idx);
}, name, tag);
}
/*!
* \brief Creates an operation to tile elements of an array
*
* \param x The input tensor
* \param reps The number of times for repeating the tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the tile operation
*/
inline Tensor tile(const Tensor& x,
Array<Integer> reps,
std::string name = "tensor",
std::string tag = kBroadcast) {
size_t ndim = x->shape.size();
size_t rdim = reps.size();
size_t tdim = (ndim > rdim) ? ndim : rdim;
Array<Expr> data_shape;
Array<Expr> reps_shape;
Array<Expr> new_shape;
if (ndim == rdim) {
for (size_t i = 0; i < ndim; ++i) {
data_shape.push_back(x->shape[i]);
reps_shape.push_back(reps[i]);
}
} else if (ndim > rdim) {
for (size_t i = 0; i < ndim; ++i)
data_shape.push_back(x->shape[i]);
for (size_t i = 0; i < (ndim - rdim); ++i)
reps_shape.push_back(1);
for (size_t i = 0; i < rdim; ++i)
reps_shape.push_back(reps[i]);
} else {
for (size_t i = 0; i < (rdim - ndim); ++i)
data_shape.push_back(1);
for (size_t i = 0; i < ndim; ++i)
data_shape.push_back(x->shape[i]);
for (size_t i = 0; i < rdim; ++i)
reps_shape.push_back(reps[i]);
}
for (size_t i = 0; i < tdim; ++i)
new_shape.push_back(data_shape[i] * reps_shape[i]);
return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[i] % x->shape[i]);
} else {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indices[rdim - ndim + i] % x->shape[i]);
}
return x(idx);
}, name, tag);
}
/*!
* \brief Gather elements from a n-dimension array. * \brief Gather elements from a n-dimension array.
* *
* \param data The source array. * \param data The source array.
......
...@@ -339,6 +339,45 @@ def arange(start, stop=None, step=1, dtype="float32"): ...@@ -339,6 +339,45 @@ def arange(start, stop=None, step=1, dtype="float32"):
return cpp.arange(start, stop, step, dtype) return cpp.arange(start, stop, step, dtype)
def repeat(a, repeats, axis):
"""Repeats elements of an array.
Parameters
----------
a : tvm.Tensor
The tensor to be repeated.
repeats: int, required
Number of repetitions for each element
axis: int, optional
The axis along which to repeat values
Returns
-------
ret : tvm.Tensor
"""
return cpp.repeat(a, repeats, axis)
def tile(a, reps):
"""Repeats the whole array multiple times.
Parameters
----------
a : tvm.Tensor
The tensor to be tiled.
reps: tuple of ints, required
The number of times for repeating the tensor
Returns
-------
ret : tvm.Tensor
"""
return cpp.tile(a, reps)
def layout_transform(array, src_layout, dst_layout): def layout_transform(array, src_layout, dst_layout):
"""Transform the layout according to src_layout and dst_layout """Transform the layout according to src_layout and dst_layout
......
...@@ -305,6 +305,16 @@ TVM_REGISTER_GLOBAL("topi.arange") ...@@ -305,6 +305,16 @@ TVM_REGISTER_GLOBAL("topi.arange")
*rv = arange(args[0], args[1], args[2], args[3]); *rv = arange(args[0], args[1], args[2], args[3]);
}); });
TVM_REGISTER_GLOBAL("topi.repeat")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = repeat(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.tile")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.gather_nd") TVM_REGISTER_GLOBAL("topi.gather_nd")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = gather_nd(args[0], args[1]); *rv = gather_nd(args[0], args[1]);
......
...@@ -359,6 +359,50 @@ def verify_arange(start, stop, step): ...@@ -359,6 +359,50 @@ def verify_arange(start, stop, step):
for device in get_all_backend(): for device in get_all_backend():
check_device(device) check_device(device)
def verify_repeat(in_shape, repeats, axis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.repeat(A, repeats, axis)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
foo = tvm.build(s, [A, B], device, name="repeat")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.repeat(data_npy, repeats, axis)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
foo(data_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in get_all_backend():
check_device(device)
def verify_tile(in_shape, reps):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.tile(A, reps)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
foo = tvm.build(s, [A, B], device, name="tile")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.tile(data_npy, reps)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
foo(data_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in get_all_backend():
check_device(device)
def test_strided_slice(): def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
...@@ -481,6 +525,16 @@ def test_arange(): ...@@ -481,6 +525,16 @@ def test_arange():
verify_arange(20, 1, -1) verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5) verify_arange(20, 1, -1.5)
def test_repeat():
verify_repeat((2,), 1, 0)
verify_repeat((3, 2), 2, 0)
verify_repeat((3, 2, 4), 3, 1)
verify_repeat((1, 3, 2, 4), 4, -1)
def test_tile():
verify_tile((3, 2), (2, 3))
verify_tile((3, 2, 5), (2,))
verify_tile((3, ), (2, 3, 3))
def test_layout_transform(): def test_layout_transform():
in_shape = (1, 32, 8, 8) in_shape = (1, 32, 8, 8)
...@@ -525,3 +579,5 @@ if __name__ == "__main__": ...@@ -525,3 +579,5 @@ if __name__ == "__main__":
test_gather_nd() test_gather_nd()
test_arange() test_arange()
test_layout_transform() test_layout_transform()
test_repeat()
test_tile()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment