Commit cdbf4d85 by Ligeng Zhu Committed by Zhi

[Relay] Add shape check for ConcatenateRel and StackRel (#3699)

* [Relay] add shape check for concat

* [Relay] add shape check for stack

* add test case for shape mismatch

* [typo] add the missing assert

* fix lint errors.

* replace int with size_t.

* statically cast param->axis to size_t.

* switch to run_infer_type.

* fix checking for negative index

* add static_cast for param->axis

* merge to latest tvm

* fix lint error

* Fix an error with negative index.

* Update transform.h

* Update transform.cc
parent f3abb3d8
...@@ -358,8 +358,17 @@ bool StackRel(const Array<Type>& types, ...@@ -358,8 +358,17 @@ bool StackRel(const Array<Type>& types,
} }
const auto* param = attrs.as<StackAttrs>(); const auto* param = attrs.as<StackAttrs>();
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
// Sanity check: ndim and dtype.
const int ndim = static_cast<int>(first->shape.size()); const int ndim = static_cast<int>(first->shape.size());
// Sanity check: axis
int axis = param->axis;
CHECK(-ndim <= axis && axis < ndim)
<< "stack only accepts `axis` in [-ndim, ndim)"
<< ", but got axis = " << axis
<< ", and ndim = " << ndim;
axis = axis < 0 ? ndim + axis + 1: axis;
// Sanity check: ndim and dtype.
const DataType dtype = first->dtype; const DataType dtype = first->dtype;
for (const Type& ele : tensor_tuple->fields) { for (const Type& ele : tensor_tuple->fields) {
const auto& e = Downcast<TensorType>(ele); const auto& e = Downcast<TensorType>(ele);
...@@ -367,14 +376,14 @@ bool StackRel(const Array<Type>& types, ...@@ -367,14 +376,14 @@ bool StackRel(const Array<Type>& types,
const DataType& e_dtype = e->dtype; const DataType& e_dtype = e->dtype;
CHECK_EQ(e_ndim, ndim) << "relay.stack requires all tensors have the same ndim"; CHECK_EQ(e_ndim, ndim) << "relay.stack requires all tensors have the same ndim";
CHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype"; CHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype";
for (size_t j = 0; j < first->shape.size(); ++j) {
if (j == static_cast<size_t>(axis)) continue;
if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
throw relay::Error("relay.stack requires all tensors have the same shape "
"on non-stacking axes");
} }
// Sanity check: axis }
int axis = param->axis;
CHECK(-ndim <= axis && axis < ndim)
<< "stack only accepts `axis` in [-ndim, ndim)"
<< ", but got axis = " << axis
<< ", and ndim = " << ndim;
axis = axis < 0 ? ndim + axis + 1 : axis;
// Calculate shape // Calculate shape
std::vector<IndexExpr> oshape; std::vector<IndexExpr> oshape;
oshape.reserve(ndim + 1); oshape.reserve(ndim + 1);
......
...@@ -65,6 +65,16 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -65,6 +65,16 @@ bool ConcatenateRel(const Array<Type>& types,
const int ndim = static_cast<int>(first->shape.size()); const int ndim = static_cast<int>(first->shape.size());
const DataType dtype = first->dtype; const DataType dtype = first->dtype;
// Sanity check: axis
int axis = param->axis;
if (!(-ndim <= axis && axis < ndim)) {
throw relay::Error(RELAY_ERROR(
"concatenate only accepts `axis` in [-ndim, ndim)" <<
", but got axis = " << axis <<
", and ndim = " << ndim));
}
axis = axis < 0 ? ndim + axis : axis;
for (const Type& ele : tensor_tuple->fields) { for (const Type& ele : tensor_tuple->fields) {
if (ele.as<IncompleteTypeNode>()) { if (ele.as<IncompleteTypeNode>()) {
return false; return false;
...@@ -80,16 +90,14 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -80,16 +90,14 @@ bool ConcatenateRel(const Array<Type>& types,
if (e_dtype != dtype) { if (e_dtype != dtype) {
throw relay::Error("relay.concatenate requires all tensors have the same dtype"); throw relay::Error("relay.concatenate requires all tensors have the same dtype");
} }
for (size_t j = 0; j < first->shape.size(); ++j) {
if (j == static_cast<size_t>(axis)) continue;
if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
throw relay::Error("relay.concatenate requires all tensors have the same shape "
"on non-concatenating axes");
} }
// Sanity check: axis
int axis = param->axis;
if (!(-ndim <= axis && axis < ndim)) {
throw relay::Error(RELAY_ERROR(
"concatenate only accepts `axis` in [-ndim, ndim)" <<
", but got axis = " << axis <<
", and ndim = " << ndim));
} }
axis = axis < 0 ? ndim + axis : axis;
// Calculate shape // Calculate shape
std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end()); std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
IndexExpr &concat_dim = oshape[axis]; IndexExpr &concat_dim = oshape[axis];
......
...@@ -219,6 +219,18 @@ def test_concatenate(): ...@@ -219,6 +219,18 @@ def test_concatenate():
zz = run_infer_type(z) zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((n, t + t, 100)) assert zz.checked_type == relay.TensorType((n, t + t, 100))
# check shape mismatches (the following case is expected to raise tvm._ffi.base.TVMError.
try:
x = relay.var('p1', shape=(2, 5))
y = relay.var('p2', shape=(2, 3))
c = relay.concatenate([x, y], axis=0)
func = relay.Function([x, y], c)
zz = run_infer_type(func)
except tvm._ffi.base.TVMError:
pass
else:
assert False
x = relay.var("x", shape=(10, 5)) x = relay.var("x", shape=(10, 5))
y = relay.var("y", shape=(10, 5)) y = relay.var("y", shape=(10, 5))
t = relay.var("z", shape=()) t = relay.var("z", shape=())
...@@ -301,7 +313,7 @@ def test_dense(): ...@@ -301,7 +313,7 @@ def test_dense():
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
w = relay.var("w", relay.TensorType((2, w), "float32")) w = relay.var("w", relay.TensorType((2, w), "float32"))
y = relay.nn.dense(x, w, units=2) y = relay.nn.dense(x, w, units=2)
"units=2" in y.astext() assert "units=2" in y.astext()
yy = run_infer_type(y) yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment