Commit 3486e2c2
authored Nov 13, 2019 by Animesh Jain, committed by Zhi on Nov 13, 2019
[QNN][Legalize] Specialize for Platforms without any fast Int8 arithmetic units. (#4307)
parent 8cd5ccea
Showing 2 changed files with 306 additions and 37 deletions:

  python/tvm/relay/qnn/op/legalizations.py      +161  -33
  tests/python/relay/test_pass_qnn_legalize.py  +145   -4
python/tvm/relay/qnn/op/legalizations.py
import tvm
from tvm import relay
from .. import op as reg

#################################################
# Register the functions for different operators.
#################################################

# Registering QNN Conv2D legalization function.
@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
    """Legalizes QNN conv2d op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return qnn_conv2d_legalize(attrs, inputs, types)


# Registering QNN dense legalization function.
@reg.register_qnn_legalize("qnn.dense")
def legalize_qnn_dense(attrs, inputs, types):
    """Legalizes QNN dense op."""
    return qnn_dense_legalize(attrs, inputs, types)


# Default to None. If overridden by target, this will not be run.
# Generic QNN Conv2D legalization function.
@tvm.target.generic_func
def qnn_conv2d_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None


# Generic QNN Dense legalization function.
@tvm.target.generic_func
def qnn_dense_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None
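The two @tvm.target.generic_func defaults above are the dispatch points that the per-target .register(...) calls later in this file hook into. A minimal sketch of that mechanism, using a hypothetical function name (my_legalize) that is not part of this commit:

    import tvm

    @tvm.target.generic_func
    def my_legalize(attrs, inputs, types):
        return None                    # generic fallback: no transformation

    @my_legalize.register('cpu')
    def _my_legalize_cpu(attrs, inputs, types):
        return 'specialized'           # chosen when the current target has the 'cpu' key

    with tvm.target.create('llvm'):    # 'llvm' targets carry the 'cpu' key
        assert my_legalize(None, None, None) == 'specialized'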
###################
# Helper functions.
###################

# Helper function for lowering in the absence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
    int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    # Collect the input exprs.
    data, kernel = inputs

    input_zp = attrs['input_zero_point']
    kernel_zp = attrs['kernel_zero_point']

    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
                                relay.const(input_zp, 'int16'))
    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
                                  relay.const(kernel_zp, 'int16'))
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    del new_attrs['kernel_zero_point']
    del new_attrs['input_zero_point']
    return relay_op(shift_data, shift_kernel, **new_attrs)
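The int16 upcast above can be sanity-checked in isolation. A small numpy sketch with hypothetical values (not from this commit): any int8/uint8 value minus any int8/uint8 zero point fits in int16, so the subtraction is exact:

    import numpy as np

    q = np.array([-128, -1, 0, 127], dtype=np.int8)  # hypothetical int8 tensor
    input_zp = -100                                  # hypothetical zero point

    # Mirrors relay.subtract(relay.cast(data, 'int16'), relay.const(input_zp, 'int16')).
    shifted = q.astype(np.int16) - np.int16(input_zp)
    print(shifted)  # [-28  99 100 227] -- exact, no int8 wraparound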
# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
    """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the
    dtypes are already good, we don't transform. Else, we shift the tensor values and zero points
    to change the dtype.

    Converting from int8 to uint8 can be done in following manner:

        scale * (QA - zp_a)
        = scale * (QA + 128 - 128 - zp_a)
        = scale * ((QA + 128) - (zp_a + 128))

    Replacing QA + 128 with QA' and zp_a + 128 with zp_a' gives the same values as a uint8
    tensor: scale * (QA' - zp_a'). Similarly, a uint8 tensor is converted to int8 by
    subtracting 128 from both the tensor and its zero point.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    def _shift(data, zero_point, out_dtype):
        """Shifts (adds/subtracts) the qnn tensor and its zero point by +/-128."""
        if out_dtype == 'uint8':
            shift = 128
        elif out_dtype == 'int8':
            shift = -128
        else:
            raise ValueError("Unsupported out dtype.")
        data_modified = relay.cast(data, 'int32')
        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
        data_modified = relay.cast(data_modified, out_dtype)
        return (data_modified, zero_point + shift)

    # Collect the dtypes.
    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype

    # Collect the input exprs.
    data, kernel = inputs

    # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
    if data_dtype == 'uint8' and kernel_dtype == 'int8':
        return None

    # Shift input if necessary.
    input_zp = attrs['input_zero_point']
    if data_dtype == 'int8':
        # Compute (QA + 128) and (zp_a + 128)
        data, input_zp = _shift(data, input_zp, 'uint8')

    # Shift kernel if necessary.
    kernel_zp = attrs['kernel_zero_point']
    if kernel_dtype == 'uint8':
        # Compute (QA - 128) and (zp_a - 128)
        kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')

    # Call relay_op with modified inputs and zero points.
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    new_attrs['kernel_zero_point'] = kernel_zp
    return relay_op(data, kernel, **new_attrs)
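A numpy sketch (hypothetical values, not from this commit) of the invariant _shift relies on: the real value of a quantized tensor is scale * (Q - zero_point), so adding 128 to both the tensor and the zero point changes the storage dtype from int8 to uint8 without changing any represented value:

    import numpy as np

    scale = 0.05
    qa = np.array([-128, -5, 0, 127], dtype=np.int8)  # hypothetical int8 tensor
    zp_a = 3                                          # hypothetical zero point

    qa_new = (qa.astype(np.int32) + 128).astype(np.uint8)  # QA' = QA + 128
    zp_new = zp_a + 128                                    # zp_a' = zp_a + 128

    # Dequantized values are identical before and after the shift.
    assert np.allclose(scale * (qa.astype(np.int32) - zp_a),
                       scale * (qa_new.astype(np.int32) - zp_new))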
# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
    """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
    conv2d/dense such that both the dtypes are same.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    def _shift(data, zero_point, out_dtype):
        """Shifts (adds/subtracts) the qnn tensor and its zero point by 128."""
        if out_dtype == 'uint8':
            shift = 128
        elif out_dtype == 'int8':
            shift = -128
        else:
            raise ValueError("Unsupported out dtype.")
        data_modified = relay.cast(data, 'int32')
        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
        data_modified = relay.cast(data_modified, out_dtype)
        return (data_modified, zero_point + shift)

    # Collect the dtypes.
    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype
    if data_dtype == kernel_dtype:
        return None

    # Collect the input exprs.
    data, kernel = inputs

    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
        "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"

    # Shift input if necessary.
    input_zp = attrs['input_zero_point']
    data, input_zp = _shift(data, input_zp, kernel_dtype)

    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    return relay_op(data, kernel, **new_attrs)
def is_fast_int8_on_intel():
    """Checks whether the hardware has support for fast Int8 arithmetic operations."""
    target = tvm.target.current_target(allow_none=False)
    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
    return intel_supported_arches.intersection(set(target.options))


def is_fast_int8_on_arm():
    """Checks whether the hardware has support for fast Int8 arithmetic operations."""
    target = tvm.target.current_target(allow_none=False)
    return '+v8.2a,+dotprod' in ' '.join(target.options)
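A usage sketch for the two predicates, assuming they are in scope; the target strings are the ones exercised by the tests in this commit, and each predicate inspects whichever target is current in the enclosing with scope:

    import tvm

    with tvm.target.create('llvm -mcpu=skylake-avx512'):
        assert is_fast_int8_on_intel()       # VNNI-capable Intel CPU

    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu '
                           '-mattr=+v8.2a,+dotprod'):
        assert is_fast_int8_on_arm()         # ARMv8.2a with dot-product support

    with tvm.target.create('llvm'):
        assert not is_fast_int8_on_intel()   # plain LLVM: no fast Int8 units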
########################
# ARM CPU legalizations.
########################

@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
    # ARM prefers the dtypes to be same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)


@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
    # ARM prefers the dtypes to be same.
    if is_fast_int8_on_arm():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)


##########################
# Intel CPU legalizations.
##########################

@qnn_conv2d_legalize.register('cpu')
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)


@qnn_dense_legalize.register('cpu')
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_on_intel():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
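Putting the registrations together, a small end-to-end sketch; the module construction is borrowed from the dense test below, and the printed IR differs per target because a different legalization path fires in each case:

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(10, 3), dtype='uint8')
    kernel = relay.var("kernel", shape=(20, 3), dtype='int8')
    func = relay.qnn.op.dense(data, kernel, input_zero_point=1,
                              kernel_zero_point=1, out_dtype='int32')
    mod = relay.Module.from_expr(
        relay.Function(relay.analysis.free_vars(func), func))

    # uint8 x int8 already suits VNNI: the module is left untouched.
    with tvm.target.create('llvm -mcpu=skylake-avx512'):
        print(relay.qnn.transform.Legalize()(mod).astext())

    # Plain LLVM has no fast Int8 units: lowered to an int16 nn.dense.
    with tvm.target.create('llvm'):
        print(relay.qnn.transform.Legalize()(mod).astext())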
tests/python/relay/test_pass_qnn_legalize.py
from tvm.contrib import graph_runtime
from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis


def alpha_equal(x, y):
    """
    Wrapper around alpha equality which ensures that
    the hash function respects equality.
    """
    x = x['main']
    y = y['main']
    return analysis.alpha_equal(x, y) and \
        analysis.structural_hash(x) == analysis.structural_hash(y)


def run_opt_pass(expr, passes):
    passes = passes if isinstance(passes, list) else [passes]
    ...


def test_qnn_legalize():
    ...
    b = run_opt_pass(expected(), transform.InferType())
    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

def test_qnn_legalize_qnn_conv2d():
    def _get_mod(data_dtype, kernel_dtype):
        data_shape = (1, 64, 256, 256)
        kernel_shape = (128, 64, 3, 3)
        data = relay.var("data", shape=data_shape,
                         dtype=data_dtype)
        kernel = relay.var("kernel", shape=kernel_shape,
                           dtype=kernel_dtype)
        ...
        mod = relay.Function(relay.analysis.free_vars(func), func)
        mod = relay.Module.from_expr(mod)
        return mod

    # Check uint8 x uint8 and int8 x int8 transformation
    for dtype in ('uint8', 'int8'):
        mod = _get_mod(dtype, dtype)

        #############################################################
        # Check transformations for platforms with fast Int8 support.
        #############################################################
        # Check that Intel VNNI gets picked up.
        with tvm.target.create('llvm -mcpu=skylake-avx512'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()

        # Since the dtypes are already the same, there should not be any transformation.
        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert alpha_equal(mod, legalized_mod)

        ################################################################
        # Check transformations for platforms without fast Int8 support.
        ################################################################
        # Older Intel versions.
        with tvm.target.create('llvm'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

        # Older ARM versions.
        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

    # Check uint8 x int8 transformation
    mod = _get_mod('uint8', 'int8')

    #############################################################
    # Check transformations for platforms with fast Int8 support.
    #############################################################
    # Check no transformation for Intel VNNI.
    with tvm.target.create('llvm -mcpu=skylake-avx512'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert alpha_equal(mod, legalized_mod)

    # ARM prefers same dtypes, so check that the transformation has happened.
    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()

    ################################################################
    # Check transformations for platforms without fast Int8 support.
    ################################################################
    # Older Intel versions.
    with tvm.target.create('llvm'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

    # Older ARM versions.
    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

def test_qnn_legalize_qnn_dense():
    def _get_mod(data_dtype, kernel_dtype):
        data_shape = (10, 3)
        kernel_shape = (20, 3)
        data = relay.var("data", shape=data_shape, dtype=data_dtype)
        kernel = relay.var("kernel", shape=kernel_shape, dtype=kernel_dtype)
        func = relay.qnn.op.dense(
            data,
            kernel,
            input_zero_point=1,
            kernel_zero_point=1,
            out_dtype='int32')

        mod = relay.Function(relay.analysis.free_vars(func), func)
        mod = relay.Module.from_expr(mod)
        return mod

    # Check uint8 x uint8 and int8 x int8 transformation
    for dtype in ('uint8', 'int8'):
        mod = _get_mod(dtype, dtype)

        #############################################################
        # Check transformations for platforms with fast Int8 support.
        #############################################################
        # Check that Intel VNNI gets picked up.
        with tvm.target.create('llvm -mcpu=skylake-avx512'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()

        # Since the dtypes are already the same, there should not be any transformation.
        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert alpha_equal(mod, legalized_mod)

        ################################################################
        # Check transformations for platforms without fast Int8 support.
        ################################################################
        # Older Intel versions.
        with tvm.target.create('llvm'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

        # Older ARM versions.
        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
            legalized_mod = relay.qnn.transform.Legalize()(mod)
            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

    # Check uint8 x int8 transformation
    mod = _get_mod('uint8', 'int8')

    #############################################################
    # Check transformations for platforms with fast Int8 support.
    #############################################################
    # Check no transformation for Intel VNNI.
    with tvm.target.create('llvm -mcpu=skylake-avx512'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert alpha_equal(mod, legalized_mod)

    # ARM prefers same dtypes, so check that the transformation has happened.
    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()

    ################################################################
    # Check transformations for platforms without fast Int8 support.
    ################################################################
    # Older Intel versions.
    with tvm.target.create('llvm'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

    # Older ARM versions.
    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
        legalized_mod = relay.qnn.transform.Legalize()(mod)
        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()

if __name__ == "__main__":
    test_qnn_legalize()
    test_qnn_legalize_qnn_conv2d()
    test_qnn_legalize_qnn_dense()