Commit c1069108 by shoubhik Committed by Zhi

Adding support for dequantizing from int32 to float32. (#4130)

parent 46fa6eeb
......@@ -43,8 +43,9 @@ bool DequantizeRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
const auto input_dtype = data->dtype;
CHECK(input_dtype == Int(8) || input_dtype == UInt(8))
<< "Input type should be one of the quantized types [unit8, int8] but was " << input_dtype;
CHECK(input_dtype == Int(8) || input_dtype == UInt(8) || input_dtype == Int(32))
<< "Input type should be one of the quantized types [unit8, int8, int32] but was "
<< input_dtype;
const Array<tvm::Expr> oshape = data->shape;
// assign output type, output will always be float 32.
reporter->Assign(types[1], TensorTypeNode::make(oshape, Float(32)));
......
......@@ -44,10 +44,10 @@ def test_dequantize_op():
def test_uint8_to_float32():
data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \
.astype('uint8') \
.reshape((2,5))
.reshape((2, 5))
output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
.astype('float32') \
.reshape((2,5))
.reshape((2, 5))
quant_args = {"in_zero_point":127, "in_scale":0.5}
quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data,
verify_output_data=output)
......@@ -55,16 +55,24 @@ def test_dequantize_op():
def test_int8_to_float32():
data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \
.astype('int8') \
.reshape((2,5))
.reshape((2, 5))
output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \
.astype('float32') \
.reshape((2,5))
quant_args = {"in_zero_point":-1, "in_scale":0.5}
.reshape((2, 5))
quant_args = {"in_zero_point": -1, "in_scale": 0.5}
quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data,
verify_output_data=output)
def test_int32_to_float32():
data = np.array([113, 29, -1052]).astype('int32')
output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32')
quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604}
quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data,
verify_output_data=output)
test_uint8_to_float32()
test_int8_to_float32()
test_int32_to_float32()
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment