Commit 6f9d028b authored by Zhi, committed by Haichen Shen

[Relay][QNN] Add unit test for int8 (#4159)

* [bugfix][codegen] fix casting bug in llvm codegen

* update example

* retrigger ci

* check llvm version
parent e0d286a1
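The renames in the diff below switch every case to pytest's `test_` prefix so the new int8 coverage is picked up by automated test discovery. For the "check llvm version" item in the commit message, the usual pattern is to skip int8 codegen tests on old toolchains rather than fail; a minimal sketch, assuming the `tvm.codegen.llvm_version_major()` helper from this era of TVM and an illustrative version cutoff:

    import tvm

    # The int8 conv2d codegen path needs a reasonably recent LLVM;
    # skip rather than fail when the toolchain is too old.
    # (llvm_version_major() and the cutoff of 8 are assumptions here.)
    if tvm.codegen.llvm_version_major() < 8:
        print("skip because LLVM version is too old for int8 codegen")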
@@ -160,7 +160,7 @@ def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
     qnn_output = get_output(qnn_func, golden_inputs)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def no_zero_point_test():
+def test_no_zero_point():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def kernel_zero_point_test():
+def test_kernel_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def kernel_zero_point_test():
            kernel_shape, kernel_dtype)
 
-def input_zero_point_test():
+def test_input_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def both_zero_point_test():
+def test_both_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def both_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def layout_test():
+def test_layout():
     # uint8 input
     data_shape = (2, 2, 4, 4) # NHWC
     data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def layout_test():
-def padding_test():
+def test_padding():
     # uint8 input
     data_shape = (1, 4, 2, 2)
     data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def padding_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def dilation_test():
+def test_dilation():
     # uint8 input
     data_shape = (2, 4, 4, 4)
     data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
            kernel_shape, kernel_dtype)
 
-def const_folding_test():
+def test_const_folding():
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
     kernel_shape = (3, 4, 2, 2)
@@ -470,7 +470,7 @@ def const_folding_test():
     folded_func = folded_mod["main"]
     assert "reshape" not in folded_func.astext()
 
-def kernel_size_1x1_test():
+def test_kernel_size_1x1():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
            kernel_shape, kernel_dtype)
 
-def tflite_large_irregular_test():
+def test_tflite_large_irregular():
     # uint8 input
     data_shape = (1, 1024, 1, 1)
     data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
     golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def broadcast_layout_test():
+def test_broadcast_layout():
     # Test broadcast support for NHWC layout.
     data_shape = (1, 229, 229, 3) # NHWC
     data_dtype = 'uint8'
@@ -640,17 +640,52 @@ def broadcast_layout_test():
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
+
+def test_conv2d_int8():
+    target = "llvm -mcpu=core-avx2"
+    if not tvm.module.enabled(target):
+        print("skip because %s is not enabled..." % target)
+        return
+
+    data = relay.var("data", shape=(1, 28, 28, 128), dtype='uint8')
+    kernel = relay.var("w", shape=(3, 3, 128, 256), dtype='int8')
+    conv = relay.nn.conv2d(
+        data,
+        kernel,
+        kernel_size=(3, 3),
+        out_dtype='int32',
+        data_layout='NHWC',
+        kernel_layout='HWIO')
+    func = relay.Function([data, kernel], conv)
+
+    with relay.build_config(opt_level=0):
+        params = {"w": np.zeros((3, 3, 128, 256)).astype("int8")}
+        # -mcpu should be specified to avoid the llvm jitting error here:
+        # https://discuss.tvm.ai/t/segfault-in-llvm/3567
+        # To use VNNI, we need to specify the micro-architecture that supports
+        # it, e.g. cascadelake.
+        graph, lib, params = relay.build(func, target, params=params)
+
+    mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+    mod.set_input("data", np.zeros((1, 28, 28, 128)).astype("uint8"))
+    mod.set_input(**params)
+    mod.run()
+    qnn_output = mod.get_output(0).asnumpy()
+    golden_output = np.zeros((1, 26, 26, 256)).astype("int32")
+    np.testing.assert_equal(qnn_output, golden_output)
 
 if __name__ == "__main__":
-    no_zero_point_test()
-    input_zero_point_test()
-    kernel_zero_point_test()
-    both_zero_point_test()
-    layout_test()
-    padding_test()
-    dilation_test()
-    const_folding_test()
-    kernel_size_1x1_test()
-    tflite_large_irregular_test()
-    tflite_output_multiplier_greater_than_one()
-    tflite_anistropic_strides()
-    broadcast_layout_test()
+    test_no_zero_point()
+    test_input_zero_point()
+    test_kernel_zero_point()
+    test_both_zero_point()
+    test_layout()
+    test_padding()
+    test_dilation()
+    test_const_folding()
+    test_kernel_size_1x1()
+    test_tflite_large_irregular()
+    test_tflite_output_multiplier_greater_than_one()
+    test_tflite_anistropic_strides()
+    test_broadcast_layout()
+    test_conv2d_int8()
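As the comment inside test_conv2d_int8 notes, `-mcpu=core-avx2` is specified to avoid the LLVM JIT error discussed at https://discuss.tvm.ai/t/segfault-in-llvm/3567, but it does not emit VNNI instructions; for that the target must name a VNNI-capable micro-architecture. A hedged variant of the build call, assuming such hardware is available:

    # Hypothetical target for VNNI-capable CPUs (e.g. Cascade Lake):
    target = "llvm -mcpu=cascadelake"
    graph, lib, params = relay.build(func, target, params=params)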