Commit 3486e2c2 by Animesh Jain Committed by Zhi

[QNN][Legalize] Specialize for Platforms without any fast Int8 arithmetic units. (#4307)

parent 8cd5ccea
...@@ -23,6 +23,14 @@ from tvm.contrib import graph_runtime ...@@ -23,6 +23,14 @@ from tvm.contrib import graph_runtime
from tvm.relay.qnn.op import register_qnn_legalize from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis from tvm.relay import transform, analysis
def alpha_equal(x, y):
    """
    Compare the 'main' entries of x and y for alpha equality and, when
    they are alpha-equal, additionally require matching structural hashes
    so that the hash function is checked to respect equality.
    """
    lhs_main, rhs_main = x['main'], y['main']
    is_alpha_equal = analysis.alpha_equal(lhs_main, rhs_main)
    if not is_alpha_equal:
        # Hand back the falsy comparison result as-is; hashes are not computed.
        return is_alpha_equal
    return analysis.structural_hash(lhs_main) == analysis.structural_hash(rhs_main)
def run_opt_pass(expr, passes): def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes] passes = passes if isinstance(passes, list) else [passes]
...@@ -82,11 +90,11 @@ def test_qnn_legalize(): ...@@ -82,11 +90,11 @@ def test_qnn_legalize():
b = run_opt_pass(expected(), transform.InferType()) b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_qnn_legalize_qnn_conv2d(): def test_qnn_legalize_qnn_conv2d():
data_shape = (1, 64, 256, 256) def _get_mod(data_dtype, kernel_dtype):
kernel_shape = (128, 64, 3, 3) data_shape = (1, 64, 256, 256)
for dtype in ['uint8', 'int8']: kernel_shape = (128, 64, 3, 3)
data_dtype = kernel_dtype = dtype
data = relay.var("data", shape=data_shape, data = relay.var("data", shape=data_shape,
dtype=data_dtype) dtype=data_dtype)
kernel = relay.var("kernel", shape=kernel_shape, kernel = relay.var("kernel", shape=kernel_shape,
...@@ -104,12 +112,145 @@ def test_qnn_legalize_qnn_conv2d(): ...@@ -104,12 +112,145 @@ def test_qnn_legalize_qnn_conv2d():
mod = relay.Function(relay.analysis.free_vars(func), func) mod = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(mod) mod = relay.Module.from_expr(mod)
return mod
# Check uint8 x uint8 and int8 x int8 transformation
for dtype in ('uint8', 'int8'):
mod = _get_mod(dtype, dtype)
#############################################################
# Check transformations for platforms with fast Int8 support.
#############################################################
# Check that Intel VNNI gets picked up.
with tvm.target.create('llvm -mcpu=skylake-avx512'): with tvm.target.create('llvm -mcpu=skylake-avx512'):
mod = relay.qnn.transform.Legalize()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
# Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
################################################################
# Check transformations for platforms without fast Int8 support.
################################################################
# Older Intel versions.
with tvm.target.create('llvm'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
# Older ARM versions.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
# Check uint8 x int8 transformation
mod = _get_mod('uint8', 'int8')
#############################################################
# Check transformations for platforms with fast Int8 support.
#############################################################
# Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert alpha_equal(mod, legalized_mod)
# ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
################################################################
# Check transformations for platforms without fast Int8 support.
################################################################
# Older Intel versions.
with tvm.target.create('llvm'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
# Older ARM versions.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
def test_qnn_legalize_qnn_dense():
    """
    Run QNN Legalize on a qnn.dense module under four target strings and
    check the legalized output for same-dtype (uint8 x uint8, int8 x int8)
    and mixed-dtype (uint8 x int8) inputs.
    """
    # Targets described by the test as having fast Int8 support.
    intel_vnni = 'llvm -mcpu=skylake-avx512'
    arm_dotprod = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'
    # Targets described by the test as lacking fast Int8 support
    # (older Intel and older ARM versions).
    intel_plain = 'llvm'
    arm_plain = 'llvm -device=arm_cpu -target=aarch64-linux-gnu'

    def _get_mod(data_dtype, kernel_dtype):
        # Build a module whose body is a single qnn.dense call.
        data = relay.var("data", shape=(10, 3), dtype=data_dtype)
        kernel = relay.var("kernel", shape=(20, 3), dtype=kernel_dtype)
        dense = relay.qnn.op.dense(
            data, kernel,
            input_zero_point=1,
            kernel_zero_point=1,
            out_dtype='int32')
        wrapped = relay.Function(relay.analysis.free_vars(dense), dense)
        return relay.Module.from_expr(wrapped)

    def _assert_unchanged(target_str, module):
        # Legalization under this target must leave the module as it was.
        with tvm.target.create(target_str):
            legalized = relay.qnn.transform.Legalize()(module)
            assert alpha_equal(module, legalized)

    def _assert_cast_with_qnn_dense(target_str, module):
        # A cast is introduced while the op remains a qnn.dense.
        with tvm.target.create(target_str):
            text = relay.qnn.transform.Legalize()(module).astext()
            assert 'cast' in text and "qnn.dense" in text

    def _assert_cast_without_qnn(target_str, module):
        # A cast is introduced and no qnn op is left in the printed module.
        with tvm.target.create(target_str):
            text = relay.qnn.transform.Legalize()(module).astext()
            assert 'cast' in text and "qnn" not in text

    # Check uint8 x uint8 and int8 x int8 transformation
    for dtype in ('uint8', 'int8'):
        mod = _get_mod(dtype, dtype)
        # Intel VNNI gets picked up.
        _assert_cast_with_qnn_dense(intel_vnni, mod)
        # Same dtype on ARM dotprod: no transformation expected.
        _assert_unchanged(arm_dotprod, mod)
        # Platforms without fast Int8 support.
        _assert_cast_without_qnn(intel_plain, mod)
        _assert_cast_without_qnn(arm_plain, mod)

    # Check uint8 x int8 transformation
    mod = _get_mod('uint8', 'int8')
    # No transformation for Intel VNNI.
    _assert_unchanged(intel_vnni, mod)
    # ARM dotprod: the transformation must have happened.
    _assert_cast_with_qnn_dense(arm_dotprod, mod)
    # Platforms without fast Int8 support.
    _assert_cast_without_qnn(intel_plain, mod)
    _assert_cast_without_qnn(arm_plain, mod)
assert 'cast' in mod.astext()
if __name__ == "__main__": if __name__ == "__main__":
test_qnn_legalize() test_qnn_legalize()
test_qnn_legalize_qnn_conv2d() test_qnn_legalize_qnn_conv2d()
test_qnn_legalize_qnn_dense()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment