Commit f34dea41 authored by Animesh Jain, committed by Zhi

[QNN] Use Int16 upcast in Fallback Conv2D. Fix test names. (#4329)

parent fed79b3a
@@ -106,8 +106,6 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
  * \brief Fallback to simpler lowering for dilation or depthwise conv.
  * \param data The input expr.
  * \param weight The weight expr.
- * \param zp_data The data zero point expr.
- * \param zp_kernel The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \return The fallback lowered sequence of Relay expr.
  * \note In case of dilation, normal lowering would require a dilated pool.
@@ -115,16 +113,19 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
  * Relay operations. This will potentially lead to performance degradation
  * as the convolution is called on int32 tensors instead of int8 tensors.
  */
-Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& zp_data,
-                    const Expr& zp_kernel, const QnnConv2DAttrs* param) {
-  auto shifted_data = data;
+Expr Conv2DFallBack(const Expr& data, const Expr& weight, const QnnConv2DAttrs* param) {
+  // Upcast the zero point to Int16.
+  auto zp_data = MakeConstantScalar(Int(16), param->input_zero_point);
+  auto zp_kernel = MakeConstantScalar(Int(16), param->kernel_zero_point);
+  auto shifted_data = Cast(data, Int(16));
   if (param->input_zero_point != 0) {
-    shifted_data = Subtract(Cast(data, Int(32)), zp_data);
+    shifted_data = Subtract(Cast(data, Int(16)), zp_data);
   }
-  auto shifted_kernel = weight;
+  auto shifted_kernel = Cast(weight, Int(16));
   if (param->kernel_zero_point != 0) {
-    shifted_kernel = Subtract(Cast(weight, Int(32)), zp_kernel);
+    shifted_kernel = Subtract(Cast(weight, Int(16)), zp_kernel);
   }
   return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
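A note on why Int16 is wide enough here (my reading of the change, not part of the diff): with (u)int8 tensors and zero points that fit in uint8, the shifted operands stay within [-255, 255], so int16 holds them exactly, and the convolution itself still accumulates into the wider out_dtype. A quick numpy sanity check of that range argument:

import numpy as np

data = np.arange(0, 256, dtype=np.int32)        # every possible uint8 value
zero_points = np.arange(0, 256)                 # every possible uint8 zero point
diffs = data[:, None] - zero_points[None, :]    # all (value - zero_point) combinations
assert diffs.min() >= np.iinfo(np.int16).min and diffs.max() <= np.iinfo(np.int16).max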
@@ -186,7 +187,6 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
 /*
  * \brief Calculates the second term in the qnn.conv2d lowering sequence.
  * \param padded_data The padded data expr.
- * \param zp_kernel The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
@@ -200,8 +200,11 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
  * followed by a reduce on the C axis. Using avg_pool2d also gives an
  * opportunity to reuse alter_op_layout infrastructure.
  */
-Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnConv2DAttrs* param,
-                      int kernel_h, int kernel_w, int out_channels) {
+Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
+                      int kernel_w, int out_channels) {
+  // Constant Expr for the kernel zero point.
+  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);
   auto casted_t2 = Cast(padded_data, Int(32));
   // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
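The avg_pool2d trick mentioned in the comment above rests on a simple identity: the sum over each kernel window equals the window size times the average over that window. A small numpy illustration on one channel (2x2 window, stride 1, no padding; illustrative only):

import numpy as np

# One 4x4 channel; compare a sliding 2x2 window sum against (kh * kw) times the window average.
x = np.arange(16, dtype=np.int32).reshape(4, 4)
kh = kw = 2
win_sum = np.zeros((3, 3), dtype=np.int32)
win_avg = np.zeros((3, 3))
for i in range(3):
    for j in range(3):
        window = x[i:i + kh, j:j + kw]
        win_sum[i, j] = window.sum()
        win_avg[i, j] = window.mean()
assert np.array_equal(win_sum, (kh * kw) * win_avg)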
@@ -241,7 +244,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
 /*
  * \brief Calculates the third term in the qnn.conv2d lowering sequence.
  * \param weight The weight expr.
- * \param zp_data The data zero point expr.
  * \param param The qnn conv2d attributes.
  * \param batch_size The batch size.
  * \param out_channels The number of output channels.
@@ -254,8 +256,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
  * a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
  * format.
  */
-Expr Conv2DThirdTerm(const Expr& weight, const Expr& zp_data, const QnnConv2DAttrs* param,
-                     int batch_size, int out_channels) {
+Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size,
+                     int out_channels) {
+  // Constant expr for input zero point.
+  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
   // Find which dimensions are C, R, S.
   Array<Integer> axes_t3;
   if (param->kernel_layout == "OIHW") {
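For the third term, the per-output-channel constant is the input zero point times the sum of each filter over its C, R, S axes, reshaped so it broadcasts against the NCHW/NHWC output. A minimal numpy sketch (OIHW weights assumed; sizes and names are made up for illustration):

import numpy as np

weight = np.random.randint(0, 256, size=(8, 4, 3, 3)).astype(np.int32)  # OIHW
input_zero_point = 3
# Reduce over C, R, S, leaving one scalar per output channel, then reshape for NCHW broadcast.
term3 = input_zero_point * weight.sum(axis=(1, 2, 3))
term3_nchw = term3.reshape(1, 8, 1, 1)   # (batch, out_channels, 1, 1)
print(term3_nchw.shape)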
@@ -415,21 +420,19 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   int batch_size, in_channels, out_channels, kernel_h, kernel_w;
   std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
       GetWorkload(arg_types, param);
-  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
-  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);
 
   // Fallback to int32 conv if there is dilation or depthwise conv2d
   CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
   if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
-    return Conv2DFallBack(data, weight, zp_data, zp_kernel, param);
+    return Conv2DFallBack(data, weight, param);
   }
   auto padded_data = Conv2DPadInput(data, param);
   auto term1 = Conv2DFirstTerm(padded_data, weight, param);
-  auto term2 = Conv2DSecondTerm(padded_data, zp_kernel, param, kernel_h, kernel_w, out_channels);
-  auto term3 = Conv2DThirdTerm(weight, zp_data, param, batch_size, out_channels);
+  auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
+  auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
   auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
   return Conv2DCombineTerms(term1, term2, term3, term4, param);
 }
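Putting the pieces together, the canonicalization above relies on expanding (data - zp_data) * (weight - zp_kernel) inside the convolution into the four terms combined at the end. A toy numpy check of that identity with a 1x1 kernel, so the convolution collapses to a matrix product over channels (sizes and zero points are made up for illustration):

import numpy as np

rng = np.random.RandomState(0)
C, O = 4, 3                                              # input / output channels
D = rng.randint(0, 256, size=(C,)).astype(np.int32)     # data at one spatial position
W = rng.randint(0, 256, size=(O, C)).astype(np.int32)   # 1x1 kernel, OIHW collapsed to (O, C)
zp_d, zp_w = 5, 3                                        # data / kernel zero points

lhs = (W - zp_w) @ (D - zp_d)     # what qnn.conv2d means
term1 = W @ D                     # conv on the raw quantized tensors
term2 = zp_w * D.sum()            # kernel zero point * window sum of data
term3 = zp_d * W.sum(axis=1)      # data zero point * per-channel weight sum
term4 = zp_d * zp_w * C           # both zero points * reduction size
assert np.array_equal(lhs, term1 - term2 - term3 + term4)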
...
@@ -160,7 +160,7 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
     qnn_output = get_output(qnn_func, golden_inputs)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def no_zero_point_test():
+def test_no_zero_point():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def kernel_zero_point_test():
+def test_kernel_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def kernel_zero_point_test():
            kernel_shape, kernel_dtype)
 
-def input_zero_point_test():
+def test_input_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def both_zero_point_test():
+def test_both_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def both_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def layout_test():
+def test_layout():
     # uint8 input
     data_shape = (2, 2, 4, 4) # NHWC
     data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def layout_test():
 
-def padding_test():
+def test_padding():
     # uint8 input
     data_shape = (1, 4, 2, 2)
     data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def padding_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def dilation_test():
+def test_dilation():
     # uint8 input
     data_shape = (2, 4, 4, 4)
     data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
            kernel_shape, kernel_dtype)
 
-def const_folding_test():
+def test_const_folding():
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
     kernel_shape = (3, 4, 2, 2)
@@ -470,7 +470,7 @@ def const_folding_test():
     folded_func = folded_mod["main"]
     assert "reshape" not in folded_func.astext()
 
-def kernel_size_1x1_test():
+def test_kernel_size_1x1():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def tflite_large_irregular_test():
+def test_tflite_large_irregular():
     # uint8 input
     data_shape = (1, 1024, 1, 1)
     data_dtype = 'uint8'
@@ -526,7 +526,7 @@ def tflite_large_irregular_test():
     golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
     np.testing.assert_equal(qnn_output, golden_output)
 
-def tflite_output_multiplier_greater_than_one():
+def test_tflite_output_multiplier_greater_than_one():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -570,7 +570,7 @@ def tflite_output_multiplier_greater_than_one():
                               0, 0)).reshape(2, 3, 1, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def tflite_anistropic_strides():
+def test_tflite_anistropic_strides():
     # uint8 input
     data_shape = (1, 1, 3, 6)
     data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
     golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def broadcast_layout_test():
+def test_broadcast_layout():
     # Test broadcast support for NHWC layout.
     data_shape = (1, 229, 229, 3) # NHWC
     data_dtype = 'uint8'
@@ -641,16 +641,16 @@ def broadcast_layout_test():
     graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
 
 if __name__ == "__main__":
-    no_zero_point_test()
-    input_zero_point_test()
-    kernel_zero_point_test()
-    both_zero_point_test()
-    layout_test()
-    padding_test()
-    dilation_test()
-    const_folding_test()
-    kernel_size_1x1_test()
-    tflite_large_irregular_test()
-    tflite_output_multiplier_greater_than_one()
-    tflite_anistropic_strides()
-    broadcast_layout_test()
+    test_no_zero_point()
+    test_input_zero_point()
+    test_kernel_zero_point()
+    test_both_zero_point()
+    test_layout()
+    test_padding()
+    test_dilation()
+    test_const_folding()
+    test_kernel_size_1x1()
+    test_tflite_large_irregular()
+    test_broadcast_layout()
+    test_tflite_output_multiplier_greater_than_one()
+    test_tflite_anistropic_strides()
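For context on why the renames above matter beyond style: pytest's default collection only picks up functions whose names start with "test", so the old *_test helpers were likely skipped by automated test discovery and only ran via __main__. A tiny illustration (hypothetical file name and bodies):

# test_example.py (hypothetical)
def no_zero_point_test():      # not collected under pytest's default naming rule
    assert True

def test_no_zero_point():      # collected and run by `pytest test_example.py`
    assert True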
...
@@ -22,230 +22,227 @@ from tvm.contrib import graph_runtime
 roundings = ["UPWARD", "TONEAREST"]
 
-def test_requantize():
-    def verify(mod, goldens):
-        with relay.build_config(opt_level=3):
-            graph, lib, params = relay.build(mod, "llvm", params=None)
-        golden_data, golden_output = goldens
-        rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
-        rt_mod.set_input("quantized_data",golden_data)
-        rt_mod.set_input(**params)
-        rt_mod.run()
-        res = rt_mod.get_output(0).asnumpy()
-        np.testing.assert_equal(res, golden_output)
+def verify(mod, goldens):
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "llvm", params=None)
+    golden_data, golden_output = goldens
+    rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+    rt_mod.set_input("quantized_data",golden_data)
+    rt_mod.set_input(**params)
+    rt_mod.run()
+    res = rt_mod.get_output(0).asnumpy()
+    np.testing.assert_equal(res, golden_output)
 
-    def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
-                input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
-        quantized_data = relay.var("quantized_data", shape=data_shape,
-                dtype=data_dtype)
-        mod = relay.qnn.op.requantize(
-                quantized_data,
-                input_scale=input_scale,
-                input_zero_point=input_zero_point,
-                output_scale=output_scale,
-                output_zero_point=output_zero_point,
-                rounding=rounding,
-                out_dtype=out_dtype)
-
-        mod = relay.Function(relay.analysis.free_vars(mod), mod)
-        mod = relay.Module.from_expr(mod)
-        return mod
+def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
+            input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
+    quantized_data = relay.var("quantized_data", shape=data_shape,
+            dtype=data_dtype)
+    mod = relay.qnn.op.requantize(
+            quantized_data,
+            input_scale=input_scale,
+            input_zero_point=input_zero_point,
+            output_scale=output_scale,
+            output_zero_point=output_zero_point,
+            rounding=rounding,
+            out_dtype=out_dtype)
+
+    mod = relay.Function(relay.analysis.free_vars(mod), mod)
+    mod = relay.Module.from_expr(mod)
+    return mod
 
-    def same_scale_test():
-        # Have same scales, everything within range
-        golden_data = np.arange(-100, 100, 1).astype('int32')
-        golden_output = golden_data
-
-        for rounding in roundings:
-            mod = get_mod(data_shape=(200, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=0.5,
-                          output_scale=0.5,
-                          rounding=rounding)
-            assert 'right_shift' not in mod.astext()
-            verify(mod, (golden_data, golden_output))
+def test_same_scale():
+    # Have same scales, everything within range
+    golden_data = np.arange(-100, 100, 1).astype('int32')
+    golden_output = golden_data
+
+    for rounding in roundings:
+        mod = get_mod(data_shape=(200, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=0.5,
+                      output_scale=0.5,
+                      rounding=rounding)
+        assert 'right_shift' not in mod.astext()
+        verify(mod, (golden_data, golden_output))
 
-    def downscale_test():
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='int8',
-                          input_scale=1,
-                          output_scale=16,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-            else:
-                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-            # Try a different scale
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=1,
-                          output_scale=4,
-                          rounding=rounding)
-
-            # Try positive values
-            # 2I corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
-                                      [2, 4, 4, 4, 4, 4, 4, 4, 2])
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
-                                          [3, 4, 4, 4, 4, 4, 4, 4, 1])
-            else:
-                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
-                                          [2, 4, 4, 4, 4, 4, 4, 4, 2])
-            verify(mod, (golden_data, golden_output))
-
-            # Try uint8 out_dtype
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='uint8',
-                          input_scale=1,
-                          output_scale=16,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-            # Try uint8 in_dtyope and uint8 out_dtype
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='uint8',
-                          out_dtype='uint8',
-                          input_scale=1,
-                          output_scale=16,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
+def test_downscale():
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='int8',
+                      input_scale=1,
+                      output_scale=16,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # -8 corresponds to -0.5. For UPWARD, this is 0
+        golden_data = np.arange(0, -32, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+        else:
+            golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
+
+        # Try a different scale
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=1,
+                      output_scale=4,
+                      rounding=rounding)
+
+        # Try positive values
+        # 2I corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
+                                  [2, 4, 4, 4, 4, 4, 4, 4, 2])
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # -8 corresponds to -0.5. For UPWARD, this is 0
+        golden_data = np.arange(0, -32, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                      [3, 4, 4, 4, 4, 4, 4, 4, 1])
+        else:
+            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                      [2, 4, 4, 4, 4, 4, 4, 4, 2])
+        verify(mod, (golden_data, golden_output))
+
+        # Try uint8 out_dtype
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='uint8',
+                      input_scale=1,
+                      output_scale=16,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
+
+        # Try uint8 in_dtyope and uint8 out_dtype
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='uint8',
+                      out_dtype='uint8',
+                      input_scale=1,
+                      output_scale=16,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
 
-    def upscale_test():
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=2,
-                          output_scale=1,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype('int32')
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
+def test_upscale():
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=2,
+                      output_scale=1,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.multiply(2, golden_data)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # -8 corresponds to -0.5. For UPWARD, this is 0
+        golden_data = np.arange(0, -32, -1).astype('int32')
+        golden_output = np.multiply(2, golden_data)
+        verify(mod, (golden_data, golden_output))
 
-    def saturation_test():
-        for rounding in roundings:
-            mod = get_mod(data_shape=(16, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=0.5,
-                          output_scale=0.5,
-                          rounding=rounding)
-            golden_data = np.arange(0, 16, 1).astype('int32')
-            golden_data = np.add(120, golden_data)
-            output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
-                               127, 127, 127, 127, 127, 127, 127, 127])
-            golden_output = output
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative numbers
-            golden_data = np.arange(0, -16, -1).astype('int32')
-            golden_data = np.add(-120, golden_data)
-            output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
-                               -128, -128, -128, -128, -128, -128, -128, -128])
-            golden_output = output
-            verify(mod, (golden_data, golden_output))
+def test_saturation():
+    for rounding in roundings:
+        mod = get_mod(data_shape=(16, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=0.5,
+                      output_scale=0.5,
+                      rounding=rounding)
+        golden_data = np.arange(0, 16, 1).astype('int32')
+        golden_data = np.add(120, golden_data)
+        output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
+                           127, 127, 127, 127, 127, 127, 127, 127])
+        golden_output = output
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative numbers
+        golden_data = np.arange(0, -16, -1).astype('int32')
+        golden_data = np.add(-120, golden_data)
+        output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
+                           -128, -128, -128, -128, -128, -128, -128, -128])
+        golden_output = output
+        verify(mod, (golden_data, golden_output))
 
-    def zero_point_test():
-        # Output zero point
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='int8',
-                          input_scale=1,
-                          output_scale=16,
-                          output_zero_point=1,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            golden_output = np.add(1, golden_output)
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(-32, -64, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
-            else:
-                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
-            golden_output = np.add(1, golden_output)
-            verify(mod, (golden_data, golden_output))
-
-        # Input zero point
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='int8',
-                          input_scale=1,
-                          output_scale=16,
-                          input_zero_point=16,
-                          rounding=rounding)
-
-            # Try positive values
-            golden_data = np.arange(32, 64, 1).astype('int32')
-            golden_output = np.repeat([2, 3, 4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            golden_data = np.arange(-32, -64, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
-            else:
-                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
-
-    same_scale_test()
-    downscale_test()
-    upscale_test()
-    saturation_test()
-    zero_point_test()
+def test_zero_point():
+    # Output zero point
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='int8',
+                      input_scale=1,
+                      output_scale=16,
+                      output_zero_point=1,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        golden_output = np.add(1, golden_output)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # -8 corresponds to -0.5. For UPWARD, this is 0
+        golden_data = np.arange(-32, -64, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+        else:
+            golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+        golden_output = np.add(1, golden_output)
+        verify(mod, (golden_data, golden_output))
+
+    # Input zero point
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='int8',
+                      input_scale=1,
+                      output_scale=16,
+                      input_zero_point=16,
+                      rounding=rounding)
+
+        # Try positive values
+        golden_data = np.arange(32, 64, 1).astype('int32')
+        golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+        golden_output = np.subtract(golden_output, 1)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        golden_data = np.arange(-32, -64, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+        else:
+            golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+        golden_output = np.subtract(golden_output, 1)
+        verify(mod, (golden_data, golden_output))
 
 if __name__ == "__main__":
-    test_requantize()
+    test_same_scale()
+    test_downscale()
+    test_upscale()
+    test_saturation()
+    test_zero_point()
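To make the golden outputs above easier to follow, here is the arithmetic these tests exercise, written out in numpy as a simplified sketch (not TVM's lowering): out = saturate(round((in - input_zero_point) * input_scale / output_scale) + output_zero_point), with the "TONEAREST" goldens corresponding to rounding halves away from zero.

import numpy as np

def requantize_ref(data, input_scale, output_scale, input_zero_point=0,
                   output_zero_point=0, out_dtype=np.int8):
    # Simplified reference: rescale, round half away from zero, shift zero point, saturate.
    real = (data.astype(np.float64) - input_zero_point) * input_scale / output_scale
    rounded = np.sign(real) * np.floor(np.abs(real) + 0.5)
    info = np.iinfo(out_dtype)
    return np.clip(rounded + output_zero_point, info.min, info.max).astype(out_dtype)

# Downscale by 16: 7 -> 0.4375 -> 0, 8 -> 0.5 -> 1, 24 -> 1.5 -> 2,
# matching the golden np.repeat([0, 1, 2], [8, 16, 8]) pattern above.
print(requantize_ref(np.array([7, 8, 9, 24]), input_scale=1, output_scale=16))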