Commit 11a3a777 by shoubhik Committed by Zhi

Fix infer type of kernel in dense. (#4125)

* Fix infer type of kernel in dense.

* - Moving the check of weight being nullptr up as it is needed in both the branches now.
- Adding test case for validating that data dtype and kernel dtypes can be different.

* - Fix the dtype check for weight. If the weight is not present then we will use the data dtype.
parent 4ee534ba
......@@ -49,7 +49,11 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// validate the weight shape is proper if defined
// Assign weight type
Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
// It is possible for weight to be nullptr in which case we will use
// data dtype as the weight dtype. However if weight dtype is explicitly
// present we will use that.
auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
oshape.Set((oshape.size() - 1), param->units);
} else {
if (weight == nullptr) return false;
......
......@@ -382,6 +382,21 @@ def test_dense():
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
def test_dense_dtype():
data_dtype = 'uint8'
weight_dtype = 'int8'
out_dtype = 'uint8'
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), data_dtype))
w = relay.var("w", relay.TensorType((2, w), weight_dtype))
y = relay.nn.dense(x, w, units=2, out_dtype=out_dtype)
assert "units=2" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, 2), out_dtype)
assert run_infer_type(yy.args[0]).checked_type.dtype == 'uint8'
assert run_infer_type(yy.args[1]).checked_type.dtype == 'int8'
def test_bitserial_dense():
m, k = tvm.var("m"), tvm.var("k")
x = relay.var("x", relay.TensorType((m, k), "int16"))
......@@ -405,3 +420,4 @@ if __name__ == "__main__":
test_batch_norm()
test_dense()
test_bitserial_dense()
test_dense_dtype()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment