Commit dd13c2c2 by Wei Chen Committed by Tianqi Chen

[Relay][Op] Add type check to dense (#4724)

parent 13ffd989
......@@ -57,6 +57,11 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
} else {
if (weight == nullptr) return false;
Array<tvm::PrimExpr> wshape = weight->shape;
CHECK(static_cast<int>(weight->shape.size()) == 2);
CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1],
weight->shape[1]))
<< "DenseRel: input dimension doesn't match,"
<< " data shape=" << data->shape << ", weight shape=" << weight->shape;
oshape.Set((oshape.size() - 1), wshape[0]);
}
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
import tvm
import scipy
from tvm import relay
......@@ -336,6 +337,16 @@ def test_batch_norm():
relay.ty.TensorType((3,), dtype)
]))
@pytest.mark.xfail
def test_dense_type_check():
dtype = 'float16'
n, c , h, w = 2, 2 , 2 ,2
x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
# it should fail since it does not match with m(2)
mismatch_w = 3
w = relay.var("w", relay.TensorType((2, mismatch_w), dtype))
y = relay.nn.dense(x, w)
yy = run_infer_type(y)
def test_dense():
for dtype in ['float16', 'float32']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment