Commit ba9d96bc by Zhi, committed by Yao Wang

[relay][op] Add shape func to tile (#4441)

* [relay][op] Add shape func to tile

* retrigger ci

* check dynamic axes

* retrigger ci

parent 6ef2418b
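For context: in Relay, an op whose inputs carry `Any` (dynamic) dimensions cannot get a concrete output shape at compile time, so the op registers a shape function that the VM evaluates at runtime. For `tile`, the output shape follows the NumPy rule: right-align the input shape with `reps`, pad the shorter sequence with ones, and multiply elementwise. A minimal NumPy sketch of that rule (illustration only; `tile_out_shape` is a hypothetical helper, not part of this patch):

import numpy as np

def tile_out_shape(data_shape, reps):
    # Right-align the two sequences, pad the shorter with 1s,
    # then multiply elementwise (np.tile's shape rule).
    tndim = max(len(data_shape), len(reps))
    d = (1,) * (tndim - len(data_shape)) + tuple(data_shape)
    r = (1,) * (tndim - len(reps)) + tuple(reps)
    return tuple(x * y for x, y in zip(d, r))

# Rank-mismatched example: a (2, 3) input tiled by (3, 2, 1) yields (3, 4, 3).
assert tile_out_shape((2, 3), (3, 2, 1)) == np.tile(np.empty((2, 3)), (3, 2, 1)).shape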
@@ -501,3 +501,35 @@ def reshape_like_shape_func(attrs, inputs, _):
     Shape function for reshape_like op.
     """
     return [_reshape_like_shape_func(inputs[1])]
+
+@script
+def _tile_shape_func(data, reps, ndim, tndim, rndim):
+    out = output_tensor((tndim,), "int64")
+
+    if ndim == rndim:
+        for i in const_range(tndim):
+            out[i] = data[i] * int64(reps[i])
+    elif ndim > rndim:
+        ngap = ndim - rndim
+        for i in const_range(ndim):
+            if i < ngap:
+                out[i] = data[i]
+            else:
+                out[i] = data[i] * int64(reps[i - ngap])
+    else:
+        rgap = rndim - ndim
+        for i in const_range(rndim):
+            if i < rgap:
+                out[i] = int64(reps[i])
+            else:
+                out[i] = int64(reps[i]) * data[i - rgap]
+    return out
+
+@_reg.register_shape_func("tile", False)
+def tile_shape_func(attrs, inputs, _):
+    reps = get_const_tuple(attrs.reps)
+    ndim = inputs[0].shape[0].value
+    rndim = len(reps)
+    tndim = ndim if ndim > rndim else rndim
+    return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
+                             convert(tndim), convert(rndim))]
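Two details in the hunk above are easy to miss. First, the `False` flag in `register_shape_func("tile", False)` marks the shape function as data-independent, so at runtime `data` is the input's shape tensor (one int64 per dimension), not the input itself; `inputs[0].shape[0].value` is therefore the input's rank. Second, `reps` is a static attribute folded by `get_const_tuple`, which makes `ndim`, `rndim`, and `tndim` compile-time constants and lets the hybrid script iterate with `const_range`. A plain-Python walk-through of the `ndim > rndim` branch, with values assumed for illustration:

# tile(x, reps=(1, 2)) where x has runtime shape (2, 3, 4):
data = [2, 3, 4]        # runtime shape tensor passed to the shape func
reps = (1, 2)           # static attribute, folded at compile time
ndim, rndim = 3, 2      # both known statically
ngap = ndim - rndim     # leading dims of data have no matching rep
out = [data[i] if i < ngap else data[i] * reps[i - ngap] for i in range(ndim)]
assert out == [2, 3, 8]  # matches np.tile(np.empty((2, 3, 4)), (1, 2)).shape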
@@ -1393,28 +1393,39 @@ bool TileRel(const Array<Type>& types,
   reps_shape.reserve(tndim);
   if (ndim == rndim) {
     for (size_t i = 0; i < tndim; ++i) {
-      data_shape.emplace_back(data->shape[i]);
-      reps_shape.emplace_back(reps[i]);
+      data_shape.emplace_back(data->shape[i]);
+      reps_shape.emplace_back(reps[i]);
     }
   } else if (ndim > rndim) {
-    for (size_t i = 0; i < ndim; ++i)
-      data_shape.emplace_back(data->shape[i]);
-    for (size_t i = 0; i < (ndim - rndim); ++i)
-      reps_shape.emplace_back(1);
-    for (size_t i = 0; i < rndim; ++i)
-      reps_shape.emplace_back(reps[i]);
+    for (size_t i = 0; i < ndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+    }
+    for (size_t i = 0; i < (ndim - rndim); ++i) {
+      reps_shape.emplace_back(1);
+    }
+    for (size_t i = 0; i < rndim; ++i) {
+      reps_shape.emplace_back(reps[i]);
+    }
   } else {
-    for (size_t i = 0; i < rndim; ++i)
-      reps_shape.emplace_back(reps[i]);
-    for (size_t i = 0; i < (rndim - ndim); ++i)
-      data_shape.emplace_back(1);
-    for (size_t i = 0; i < ndim; ++i)
-      data_shape.emplace_back(data->shape[i]);
+    for (size_t i = 0; i < rndim; ++i) {
+      reps_shape.emplace_back(reps[i]);
+    }
+    for (size_t i = 0; i < (rndim - ndim); ++i) {
+      data_shape.emplace_back(1);
+    }
+    for (size_t i = 0; i < ndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+    }
   }
   std::vector<IndexExpr> oshape;
   oshape.reserve(tndim);
   for (size_t i = 0; i < tndim; ++i) {
-    oshape.emplace_back(data_shape[i] * reps_shape[i]);
+    // Save Any if it is dynamic shape
+    if (!data_shape[i].as<IntImm>()) {
+      oshape.emplace_back(Any::make());
+    } else {
+      oshape.emplace_back(data_shape[i] * reps_shape[i]);
+    }
   }
   reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
   return true;
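The C++ hunk is the compile-time half of the fix. `TileRel` is the type relation for `tile`, and it previously always emitted `data_shape[i] * reps_shape[i]`, which is not a constant when the input dimension is dynamic. The change checks whether the dimension is a static `IntImm` and otherwise assigns `Any`, deferring the concrete value to the shape function above. A conceptual mirror in Python (a sketch, not TVM API; `ANY` stands in for Relay's `Any`):

ANY = object()  # placeholder for a dynamic dimension

def out_dim(d, r):
    # A dynamic input dimension keeps the output dimension dynamic.
    return ANY if d is ANY else d * r

assert out_dim(ANY, 3) is ANY
assert out_dim(4, 3) == 12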
@@ -193,6 +193,25 @@ def test_any_take():
     verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
     verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
 
+
+def verify_any_tile(dshape, reps, np_dshape, np_reps):
+    mod = relay.Module()
+    x = relay.var("x", shape=dshape, dtype="float32")
+    y = relay.tile(x, reps=reps)
+    mod["main"] = relay.Function([x], y)
+    x_data = np.random.uniform(size=np_dshape).astype("float32")
+    ref_res = np.tile(x_data, reps=np_reps)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        res = ex.evaluate()(x_data)
+        tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5)
+
+def test_any_tile():
+    verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1))
+    verify_any_tile(any_dims(3), (1, 2), (2, 3, 4), (1, 2))
+    verify_any_tile(any_dims(2), (3, 2, 1), (2, 3), (3, 2, 1))
+    verify_any_tile(any_dims(3), (1,), (2, 3, 4), (1,))
+
 def test_any_shape_of():
     x = relay.var('x', shape=any_dims(2), dtype='float32')
     y = relay.shape_of(x)
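The four cases in `test_any_tile` exercise all three branches of the shape function: equal ranks (first case), input rank greater than reps rank (second and fourth), and reps rank greater than input rank (third). A further case of the last kind could look like this (hypothetical, not part of the patch):

# Rank-1 input, rank-2 reps; np.tile gives output shape (2, 12).
verify_any_tile(any_dims(1), (2, 3), (4,), (2, 3))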
@@ -586,6 +605,7 @@ if __name__ == "__main__":
     test_any_concat()
     test_any_reshape()
     test_any_take()
+    test_any_tile()
     test_any_shape_of()
     test_any_reduce()
     test_any_layout_transform()