Commit 5f89a50e by Leyuan Wang Committed by Tianqi Chen

[Bugfix] Repeat and tile bug fixed, relay tests added (#2804)

parent 046e4ff0
...@@ -363,7 +363,7 @@ RELAY_REGISTER_OP("stack") ...@@ -363,7 +363,7 @@ RELAY_REGISTER_OP("stack")
.set_attrs_type_key("relay.attrs.StackAttrs") .set_attrs_type_key("relay.attrs.StackAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.") .add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1) .set_support_level(3)
.add_type_rel("Stack", StackRel) .add_type_rel("Stack", StackRel)
.set_attr<FTVMCompute>("FTVMCompute", StackCompute) .set_attr<FTVMCompute>("FTVMCompute", StackCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
...@@ -1109,7 +1109,7 @@ RELAY_REGISTER_OP("repeat") ...@@ -1109,7 +1109,7 @@ RELAY_REGISTER_OP("repeat")
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Repeat") .set_attrs_type_key("relay.attrs.Repeat")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(3)
.add_type_rel("Repeat", RepeatRel) .add_type_rel("Repeat", RepeatRel)
.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute) .set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast); .set_attr<TOpPattern>("TOpPattern", kBroadcast);
...@@ -1134,9 +1134,15 @@ bool TileRel(const Array<Type>& types, ...@@ -1134,9 +1134,15 @@ bool TileRel(const Array<Type>& types,
const size_t ndim = data->shape.size(); const size_t ndim = data->shape.size();
const Array<Integer>& reps = param->reps; const Array<Integer>& reps = param->reps;
// check dimension match // check dimension match
CHECK(!reps.defined()) CHECK(reps.defined())
<< "repetition array is not defined. data.ndim = " << ndim; << "repetition array is not defined. data.ndim = " << ndim;
const size_t rndim = reps.size(); const size_t rndim = reps.size();
for (size_t i = 0; i < rndim; ++i) {
if (const tvm::ir::IntImm* val = reps[i].as<tvm::ir::IntImm>()) {
CHECK_GT(val->value, 0)
<< "Tile reps value should always be larger than 0, but get: " << val->value;
}
}
size_t tndim = (ndim > rndim) ? ndim : rndim; size_t tndim = (ndim > rndim) ? ndim : rndim;
// re-construct data shape or reps shape // re-construct data shape or reps shape
std::vector<IndexExpr> data_shape; std::vector<IndexExpr> data_shape;
...@@ -1158,6 +1164,10 @@ bool TileRel(const Array<Type>& types, ...@@ -1158,6 +1164,10 @@ bool TileRel(const Array<Type>& types,
} else { } else {
for (size_t i = 0; i < rndim; ++i) for (size_t i = 0; i < rndim; ++i)
reps_shape.emplace_back(reps[i]); reps_shape.emplace_back(reps[i]);
for (size_t i = 0; i < (rndim - ndim); ++i)
data_shape.emplace_back(1);
for (size_t i = 0; i < ndim; ++i)
data_shape.emplace_back(data->shape[i]);
} }
std::vector<IndexExpr> oshape; std::vector<IndexExpr> oshape;
oshape.reserve(tndim); oshape.reserve(tndim);
...@@ -1199,7 +1209,7 @@ RELAY_REGISTER_OP("tile") ...@@ -1199,7 +1209,7 @@ RELAY_REGISTER_OP("tile")
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Tile") .set_attrs_type_key("relay.attrs.Tile")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(3)
.add_type_rel("Tile", TileRel) .add_type_rel("Tile", TileRel)
.set_attr<FTVMCompute>("FTVMCompute", TileCompute) .set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast); .set_attr<TOpPattern>("TOpPattern", kBroadcast);
......
...@@ -491,6 +491,62 @@ def test_arange(): ...@@ -491,6 +491,62 @@ def test_arange():
verify_arange(20, 1, -1) verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5) verify_arange(20, 1, -1.5)
def test_tile():
def verify_tile(dshape, reps):
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.tile(x, reps=reps)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.tile(x_data, reps=reps)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_tile((2, 3, 4), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2))
verify_tile((2, 3), (3, 2, 1))
def test_repeat():
def verify_repeat(dshape, repeats, axis):
x = relay.Var("x", relay.TensorType(dshape, "float32"))
func = relay.Function([x], relay.repeat(x, repeats, axis))
data = np.random.uniform(size=dshape).astype("float32")
ref_res = np.repeat(data, repeats, axis)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_repeat((3,), 2, 0)
verify_repeat((3, 10), 2, -1)
verify_repeat((3, 2, 4), 3, 1)
def test_stack():
def verify_stack(dshapes, axis):
y = []
for shape in dshapes:
y.append(relay.var("input", relay.TensorType(shape, "float32")))
x = relay.Tuple(y)
z = relay.stack(x, axis=axis)
func = relay.Function(y, z)
x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
ref_res = np.stack(x_data, axis=axis)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(*x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_stack([(2,), (2,), (2,)], -1)
verify_stack([(2,), (2,), (2,)], 0)
verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
def test_reverse(): def test_reverse():
def verify_reverse(dshape, axis): def verify_reverse(dshape, axis):
...@@ -536,3 +592,6 @@ if __name__ == "__main__": ...@@ -536,3 +592,6 @@ if __name__ == "__main__":
test_split_infer_type() test_split_infer_type()
test_arange() test_arange()
test_reverse() test_reverse()
test_stack()
test_tile()
test_repeat()
...@@ -29,7 +29,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): ...@@ -29,7 +29,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
check_device(device) check_device(device)
def verify_tranpose(in_shape, axes): def verify_transpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes) B = topi.transpose(A, axes)
def check_device(device): def check_device(device):
...@@ -40,7 +40,7 @@ def verify_tranpose(in_shape, axes): ...@@ -40,7 +40,7 @@ def verify_tranpose(in_shape, axes):
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
s = topi.generic.schedule_injective(B) s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device, name="tranpose") foo = tvm.build(s, [A, B], device, name="transpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype) data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
out_npy = data_npy.transpose(axes) out_npy = data_npy.transpose(axes)
data_nd = tvm.nd.array(data_npy, ctx) data_nd = tvm.nd.array(data_npy, ctx)
...@@ -416,10 +416,10 @@ def test_expand_dims(): ...@@ -416,10 +416,10 @@ def test_expand_dims():
verify_expand_dims((3, 10), (1, 3, 10), -3, 1) verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
def test_tranpose(): def test_transpose():
verify_tranpose((3, 10, 2), (1, 0, 2)) verify_transpose((3, 10, 2), (1, 0, 2))
verify_tranpose((3, 10, 5), (2, 0, 1)) verify_transpose((3, 10, 5), (2, 0, 1))
verify_tranpose((3, 10), None) verify_transpose((3, 10), None)
def test_reshape(): def test_reshape():
...@@ -595,7 +595,7 @@ if __name__ == "__main__": ...@@ -595,7 +595,7 @@ if __name__ == "__main__":
test_strided_slice() test_strided_slice()
test_concatenate() test_concatenate()
test_stack() test_stack()
test_tranpose() test_transpose()
test_expand_dims() test_expand_dims()
test_reshape() test_reshape()
test_squeeze() test_squeeze()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment