Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b0ddcff6
Commit
b0ddcff6
authored
Sep 18, 2019
by
Animesh Jain
Committed by
Yizhi Liu
Sep 19, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Legalize and AlterOpLayout for Int8 Intel. (#3961)
parent
92439166
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
164 additions
and
177 deletions
+164
-177
tests/python/relay/test_op_level2.py
+61
-22
topi/python/topi/nn/conv2d.py
+1
-1
topi/python/topi/x86/__init__.py
+1
-0
topi/python/topi/x86/conv2d.py
+3
-154
topi/python/topi/x86/conv2d_alter_op.py
+0
-0
topi/python/topi/x86/conv2d_avx_1x1.py
+30
-0
topi/python/topi/x86/conv2d_avx_common.py
+28
-0
topi/python/topi/x86/conv2d_int8.py
+40
-0
No files found.
tests/python/relay/test_op_level2.py
View file @
b0ddcff6
...
...
@@ -541,18 +541,35 @@ def test_upsampling():
def
test_conv2d_int8_intrinsics
():
def
_compile
(
input_dtype
,
weight_dtype
,
output_dtype
,
target
):
n
,
ic
,
h
,
w
,
oc
,
ch
,
cw
=
1
,
16
,
224
,
224
,
32
,
3
,
3
def
_compile
(
ic
,
oc
,
target
,
data_layout
,
kernel_layout
,
dtypes
):
input_dtype
,
weight_dtype
,
output_dtype
=
dtypes
n
,
h
,
w
,
ch
,
cw
=
1
,
64
,
64
,
3
,
3
if
data_layout
==
'NCHW'
:
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
n
,
ic
,
h
,
w
),
input_dtype
))
w
=
relay
.
var
(
"w"
,
relay
.
TensorType
((
oc
,
ic
,
ch
,
cw
),
weight_dtype
))
elif
data_layout
==
'NHWC'
:
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
n
,
h
,
w
,
ic
),
input_dtype
))
else
:
raise
ValueError
(
'Not supported'
)
if
kernel_layout
==
'OIHW'
:
kernel_shape
=
(
oc
,
ic
,
ch
,
cw
)
elif
kernel_layout
==
'HWIO'
:
kernel_shape
=
(
ch
,
cw
,
ic
,
oc
)
else
:
raise
ValueError
(
'Not supported'
)
w
=
relay
.
var
(
"w"
,
relay
.
TensorType
(
kernel_shape
,
weight_dtype
))
y
=
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
ch
,
cw
),
channels
=
oc
,
padding
=
(
1
,
1
),
dilation
=
(
1
,
1
),
data_layout
=
data_layout
,
kernel_layout
=
kernel_layout
,
out_dtype
=
output_dtype
)
func
=
relay
.
Function
([
x
,
w
],
y
)
wdata
=
np
.
random
.
rand
(
oc
,
ic
,
ch
,
cw
)
*
10
wdata
=
np
.
random
.
rand
(
*
kernel_shape
)
*
10
parameters
=
{
"w"
:
tvm
.
nd
.
array
(
wdata
.
astype
(
weight_dtype
))}
with
relay
.
build_config
(
opt_level
=
3
):
graph
,
lib
,
params
=
relay
.
build
(
func
,
target
,
params
=
parameters
)
...
...
@@ -564,37 +581,59 @@ def test_conv2d_int8_intrinsics():
name
=
"llvm.x86.avx512.pmaddubs.w.512"
llvm_id
=
tvm
.
codegen
.
llvm_lookup_intrinsic_id
(
name
)
if
llvm_id
!=
0
:
# Intel Int8 instruction need uint8 data and int8 kernel
asm
=
_compile
(
input_dtype
=
"uint8"
,
weight_dtype
=
"int8"
,
output_dtype
=
"int32"
,
target
=
target
)
# Check that intrinisic is present in the assembly.
fast_int8_dtypes
=
(
'uint8'
,
'int8'
,
'int32'
)
# Sweep the input channels to check int8 robustness
for
ic
in
range
(
1
,
24
):
asm
=
_compile
(
ic
=
ic
,
oc
=
32
,
target
=
target
,
data_layout
=
"NCHW"
,
kernel_layout
=
'OIHW'
,
dtypes
=
fast_int8_dtypes
)
assert
"pmaddubs"
in
asm
for
ic
in
range
(
1
,
24
):
asm
=
_compile
(
ic
=
ic
,
oc
=
32
,
target
=
target
,
data_layout
=
"NHWC"
,
kernel_layout
=
'HWIO'
,
dtypes
=
fast_int8_dtypes
)
assert
"pmaddubs"
in
asm
# Sweep the output channels to check int8 robustness
for
oc
in
range
(
2
,
24
):
asm
=
_compile
(
ic
=
16
,
oc
=
oc
,
target
=
target
,
data_layout
=
"NCHW"
,
kernel_layout
=
'OIHW'
,
dtypes
=
fast_int8_dtypes
)
assert
"pmaddubs"
in
asm
for
oc
in
range
(
2
,
24
):
asm
=
_compile
(
ic
=
16
,
oc
=
oc
,
target
=
target
,
data_layout
=
"NHWC"
,
kernel_layout
=
'HWIO'
,
dtypes
=
fast_int8_dtypes
)
assert
"pmaddubs"
in
asm
# Check that both non-divisible oc and ic work
asm
=
_compile
(
ic
=
17
,
oc
=
29
,
target
=
target
,
data_layout
=
"NCHW"
,
kernel_layout
=
'OIHW'
,
dtypes
=
fast_int8_dtypes
)
assert
"pmaddubs"
in
asm
asm
=
_compile
(
ic
=
17
,
oc
=
29
,
target
=
target
,
data_layout
=
"NHWC"
,
kernel_layout
=
'HWIO'
,
dtypes
=
fast_int8_dtypes
)
assert
"pmaddubs"
in
asm
# Ensure that code is generated when datatypes are not HW supported.
asm
=
_compile
(
input_dtype
=
"int8"
,
weight_dtype
=
"int8"
,
output_dtype
=
"int32"
,
target
=
target
)
dtypes
=
(
'int8'
,
'int8'
,
'int32'
)
asm
=
_compile
(
ic
=
16
,
oc
=
32
,
target
=
target
,
data_layout
=
"NHWC"
,
kernel_layout
=
'HWIO'
,
dtypes
=
dtypes
)
# Check that intrinisic is not present in the assembly.
assert
"pmaddubs"
not
in
asm
# Ensure that code is generated when datatypes are not HW supported.
asm
=
_compile
(
input_dtype
=
"uint8"
,
weight_dtype
=
"uint8"
,
output_dtype
=
"int32"
,
target
=
target
)
dtypes
=
(
'uint8'
,
'uint8'
,
'int32'
)
asm
=
_compile
(
ic
=
16
,
oc
=
32
,
target
=
target
,
data_layout
=
"NHWC"
,
kernel_layout
=
'HWIO'
,
dtypes
=
dtypes
)
# Check that intrinisic is not present in the assembly.
assert
"pmaddubs"
not
in
asm
# Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout.
target
=
"llvm -mcpu=core-avx2"
asm
=
_compile
(
input_dtype
=
"int8"
,
weight_dtype
=
"int8"
,
output_dtype
=
"int32"
,
target
=
target
)
fast_int8_dtypes
=
(
'uint8'
,
'int8'
,
'int32'
)
asm
=
_compile
(
ic
=
16
,
oc
=
32
,
target
=
target
,
data_layout
=
"NCHW"
,
kernel_layout
=
'OIHW'
,
dtypes
=
fast_int8_dtypes
)
# Check that vector int mult and add instructions are generated.
assert
"vpmulld"
in
asm
and
"vpadd"
in
asm
...
...
topi/python/topi/nn/conv2d.py
View file @
b0ddcff6
...
...
@@ -151,7 +151,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
if
data_layout
==
'NCHW'
:
CO
,
CIG
,
KH
,
KW
=
[
x
.
value
for
x
in
kernel
.
shape
]
else
:
KH
,
KW
,
C
O
,
CIG
=
[
x
.
value
for
x
in
kernel
.
shape
]
KH
,
KW
,
C
IG
,
CO
=
[
x
.
value
for
x
in
kernel
.
shape
]
HPAD
,
WPAD
,
_
,
_
=
get_pad_tuple
(
padding
,
kernel
)
GRPS
=
CI
//
CIG
...
...
topi/python/topi/x86/__init__.py
View file @
b0ddcff6
...
...
@@ -17,3 +17,4 @@ from .batch_matmul import schedule_batch_matmul
from
.roi_align
import
roi_align_nchw
from
.conv2d_transpose
import
_schedule_conv2d_transpose_nchw
from
.sparse
import
*
from
.conv2d_alter_op
import
*
topi/python/topi/x86/conv2d.py
View file @
b0ddcff6
...
...
@@ -26,40 +26,16 @@ from tvm.autotvm.task.topi_integration import deserialize_args
from
tvm.autotvm.task
import
get_config
from
..
import
generic
,
tag
from
..
import
nn
from
..util
import
get_const_tuple
,
get_shape
from
..nn.conv2d
import
conv2d
,
conv2d_NCHWc
,
conv2d_NCHWc_int8
,
\
conv2d_alter_layout
,
conv2d_infer_layout
,
_get_workload
as
_get_conv2d_workload
from
..nn.conv2d
import
conv2d
,
conv2d_NCHWc
,
\
conv2d_infer_layout
,
_get_workload
as
_get_conv2d_workload
from
..nn.depthwise_conv2d
import
_get_workload
as
_get_depthwise_conv2d_workload
from
..nn.depthwise_conv2d
import
depthwise_conv2d_NCHWc
,
depthwise_conv2d_nchw
from
..nn.pad
import
pad
from
..util
import
get_const_tuple
from
.
import
conv2d_avx_1x1
,
conv2d_avx_common
logger
=
logging
.
getLogger
(
'topi'
)
def
_is_int8_hw_support
(
data_dtype
,
kernel_dtype
,
target
):
"""
Checks to ensure that we can use Intel DLBoost instructions
1) The datatypes are correct.
2) LLVM version has support for the instructions.
3) Target is skylake and above.
"""
# 1) Check datatypes
is_dtype_support
=
data_dtype
==
'uint8'
and
kernel_dtype
==
'int8'
# 2) Check LLVM support
llvm_intrin_fast_int8
=
"llvm.x86.avx512.pmaddubs.w.512"
llvm_id
=
tvm
.
codegen
.
llvm_lookup_intrinsic_id
(
llvm_intrin_fast_int8
)
is_llvm_support
=
llvm_id
!=
0
# 3) Check target
is_target_support
=
False
for
opt
in
target
.
options
:
if
opt
==
'-mcpu=skylake-avx512'
:
is_target_support
=
True
return
is_dtype_support
and
is_llvm_support
and
is_target_support
def
_get_default_config
(
cfg
,
data
,
kernel
,
strides
,
padding
,
out_dtype
,
is_depthwise
=
False
,
layout
=
'NCHW'
):
"""
...
...
@@ -353,133 +329,6 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
return
s
,
[
new_data
,
new_kernel
,
C
]
@conv2d_alter_layout.register
(
"cpu"
)
def
_alter_conv2d_layout
(
attrs
,
inputs
,
tinfo
,
F
):
copy_inputs
=
[
s
for
s
in
inputs
]
new_attrs
=
{
k
:
attrs
[
k
]
for
k
in
attrs
.
keys
()}
if
F
.
__name__
==
'tvm.relay.op'
:
# Derive channels for frontends (e.g ONNX) that miss "channel" field.
new_attrs
[
"channels"
]
=
inputs
[
1
]
.
checked_type
.
shape
[
attrs
[
'kernel_layout'
]
.
index
(
'O'
)]
data
,
kernel
=
tinfo
[
0
],
tinfo
[
1
]
batch_size
,
in_channel
,
height
,
width
=
get_const_tuple
(
data
.
shape
)
groups
=
attrs
.
get_int
(
"groups"
)
out_channel
=
attrs
.
get_int
(
"channels"
)
\
if
F
.
__name__
==
'nnvm.symbol'
else
new_attrs
[
"channels"
]
padding
=
attrs
.
get_int_tuple
(
"padding"
)
strides
=
attrs
.
get_int_tuple
(
"strides"
)
dilation
=
attrs
.
get_int_tuple
(
"dilation"
)
out_dtype
=
attrs
[
"out_dtype"
]
layout_name
=
'layout'
if
F
.
__name__
==
'nnvm.symbol'
else
'data_layout'
layout
=
attrs
[
layout_name
]
kh
,
kw
=
attrs
.
get_int_tuple
(
"kernel_size"
)
dtype
=
data
.
dtype
out_dtype
=
dtype
if
out_dtype
in
(
"same"
,
""
)
else
out_dtype
kshape
=
get_shape
(
kernel
.
shape
,
attrs
[
"kernel_layout"
],
"OIHW"
)
is_depthwise
=
groups
==
kshape
[
0
]
and
kshape
[
1
]
==
1
# only optimize for NCHW
if
layout
!=
'NCHW'
or
attrs
[
"kernel_layout"
]
!=
"OIHW"
:
return
None
if
groups
!=
1
and
not
is_depthwise
:
return
None
dispatch_ctx
=
autotvm
.
task
.
DispatchContext
.
current
target
=
tvm
.
target
.
current_target
()
# query schedule and fallback if necessary
workload
=
autotvm
.
task
.
args_to_workload
(
[
data
,
kernel
,
strides
,
padding
,
dilation
,
out_dtype
],
depthwise_conv2d_nchw
)
\
if
is_depthwise
else
\
autotvm
.
task
.
args_to_workload
(
[
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_dtype
],
conv2d
)
cfg
=
dispatch_ctx
.
query
(
target
,
workload
)
if
cfg
.
is_fallback
:
_get_default_config
(
cfg
,
data
,
kernel
,
strides
,
padding
,
out_dtype
,
is_depthwise
)
ic_bn
,
oc_bn
=
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
]
new_attrs
[
layout_name
]
=
'NCHW
%
dc'
%
ic_bn
new_attrs
[
'out_layout'
]
=
'NCHW
%
dc'
%
oc_bn
new_data
=
tvm
.
placeholder
((
batch_size
,
in_channel
//
ic_bn
,
height
,
width
,
ic_bn
),
dtype
=
data
.
dtype
)
if
is_depthwise
:
new_attrs
[
'kernel_layout'
]
=
'OIHW1i
%
do'
%
oc_bn
# Store altered operator's config
new_kernel
=
tvm
.
placeholder
((
out_channel
//
oc_bn
,
1
,
kh
,
kw
,
1
,
oc_bn
),
dtype
=
kernel
.
dtype
)
new_workload
=
autotvm
.
task
.
args_to_workload
(
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
depthwise_conv2d_NCHWc
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
if
F
.
__name__
==
'nnvm.symbol'
:
logging
.
warning
(
"Use native layout for depthwise convolution on NNVM."
)
return
None
return
F
.
nn
.
contrib_depthwise_conv2d_nchwc
(
*
copy_inputs
,
**
new_attrs
)
if
_is_int8_hw_support
(
data
.
dtype
,
kernel
.
dtype
,
target
):
# Convert kernel data layout from 4D to 7D
n_elems
=
4
out_channel
,
_
,
kh
,
kw
=
get_const_tuple
(
kernel
.
shape
)
data_expr
,
kernel_expr
=
inputs
kernel_IHWO
=
F
.
transpose
(
kernel_expr
,
axes
=
(
1
,
2
,
3
,
0
))
kernel_IHWOo
=
F
.
reshape
(
kernel_IHWO
,
(
in_channel
,
kh
,
kw
,
out_channel
//
oc_bn
,
oc_bn
))
kernel_OHWoI
=
F
.
transpose
(
kernel_IHWOo
,
axes
=
(
3
,
1
,
2
,
4
,
0
))
kernel_OHWoIi
=
F
.
reshape
(
kernel_OHWoI
,
(
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
))
kernel_OHWoIie
=
F
.
reshape
(
kernel_OHWoIi
,
(
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
//
n_elems
,
n_elems
))
kernel_OIHWioe
=
F
.
transpose
(
kernel_OHWoIie
,
axes
=
(
0
,
4
,
1
,
2
,
5
,
3
,
6
))
copy_inputs
=
[
data_expr
,
kernel_OIHWioe
]
# Store altered operator's config. New kernel layout OIHWio4
new_kernel
=
tvm
.
placeholder
((
out_channel
//
oc_bn
,
in_channel
//
ic_bn
,
kh
,
kw
,
ic_bn
//
n_elems
,
oc_bn
,
n_elems
),
dtype
=
kernel
.
dtype
)
new_workload
=
autotvm
.
task
.
args_to_workload
([
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
conv2d_NCHWc_int8
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
if
F
.
__name__
==
'nnvm.symbol'
:
logging
.
warning
(
"Use native layout for int8 convolution on NNVM."
)
return
None
return
F
.
nn
.
contrib_conv2d_nchwc_int8
(
*
copy_inputs
,
**
new_attrs
)
out_channel
,
_
,
kh
,
kw
=
get_const_tuple
(
kernel
.
shape
)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs
[
'kernel_layout'
]
=
'OIHW
%
di
%
do'
%
(
ic_bn
,
oc_bn
)
# Store altered operator's config
new_kernel
=
tvm
.
placeholder
((
out_channel
//
oc_bn
,
in_channel
//
ic_bn
,
kh
,
kw
,
ic_bn
,
oc_bn
),
dtype
=
kernel
.
dtype
)
new_workload
=
autotvm
.
task
.
args_to_workload
(
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
conv2d_NCHWc
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
if
F
.
__name__
==
'nnvm.symbol'
:
return
F
.
contrib
.
conv2d_NCHWc
(
*
copy_inputs
,
**
new_attrs
)
return
F
.
nn
.
contrib_conv2d_nchwc
(
*
copy_inputs
,
**
new_attrs
)
@conv2d_infer_layout.register
(
"cpu"
)
def
_conv2d_infer_layout
(
workload
,
cfg
):
_
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
dtype
=
workload
...
...
topi/python/topi/x86/conv2d_alter_op.py
0 → 100644
View file @
b0ddcff6
This diff is collapsed.
Click to expand it.
topi/python/topi/x86/conv2d_avx_1x1.py
View file @
b0ddcff6
...
...
@@ -57,6 +57,36 @@ def _fallback_schedule(cfg, wkl):
raise
ValueError
(
"cannot decide default schedule for workload: {}"
.
format
(
wkl
))
def
_fallback_schedule_int8
(
cfg
,
wkl
):
simd_width
=
get_fp32_len
()
HPAD
,
WPAD
=
wkl
.
hpad
,
wkl
.
wpad
HSTR
,
WSTR
=
wkl
.
hstride
,
wkl
.
wstride
out_height
=
(
wkl
.
height
+
2
*
HPAD
-
wkl
.
hkernel
)
//
HSTR
+
1
out_width
=
(
wkl
.
width
+
2
*
WPAD
-
wkl
.
wkernel
)
//
WSTR
+
1
oc_bn
=
16
assert
wkl
.
out_filter
%
oc_bn
==
0
ic_bn
=
1
for
bn
in
range
(
oc_bn
,
0
,
-
4
):
if
wkl
.
in_filter
%
bn
==
0
:
ic_bn
=
bn
break
assert
wkl
.
in_filter
%
4
==
0
for
ow_factor
in
range
(
out_width
,
0
,
-
1
):
if
out_width
%
ow_factor
==
0
:
for
oh_factor
in
range
(
out_height
,
0
,
-
1
):
if
out_height
%
oh_factor
==
0
and
ow_factor
*
oh_factor
<
32
:
cfg
[
"tile_ic"
]
=
SplitEntity
([
wkl
.
in_filter
//
ic_bn
,
ic_bn
])
cfg
[
"tile_oc"
]
=
SplitEntity
([
wkl
.
out_filter
//
oc_bn
,
oc_bn
])
cfg
[
"tile_oh"
]
=
OtherOptionEntity
(
oh_factor
)
cfg
[
"tile_ow"
]
=
SplitEntity
([
out_width
//
ow_factor
,
ow_factor
])
return
raise
ValueError
(
"cannot decide default schedule for workload: {}"
.
format
(
wkl
))
def
_schedule_conv
(
s
,
cfg
,
data
,
data_pad
,
data_vec
,
kernel_vec
,
conv_out
,
output
,
last
):
# fetch schedule
ic_bn
,
oc_bn
,
oh_factor
,
ow_factor
=
(
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
],
...
...
topi/python/topi/x86/conv2d_avx_common.py
View file @
b0ddcff6
...
...
@@ -55,6 +55,34 @@ def _fallback_schedule(cfg, wkl):
cfg
[
"unroll_kw"
]
=
OtherOptionEntity
(
False
)
def
_fallback_schedule_int8
(
cfg
,
wkl
):
simd_width
=
get_fp32_len
()
HPAD
,
WPAD
=
wkl
.
hpad
,
wkl
.
wpad
HSTR
,
WSTR
=
wkl
.
hstride
,
wkl
.
wstride
out_width
=
(
wkl
.
width
+
2
*
WPAD
-
wkl
.
wkernel
)
//
WSTR
+
1
oc_bn
=
16
assert
wkl
.
out_filter
%
oc_bn
==
0
ic_bn
=
1
for
bn
in
range
(
oc_bn
,
0
,
-
4
):
if
wkl
.
in_filter
%
bn
==
0
:
ic_bn
=
bn
break
assert
wkl
.
in_filter
%
4
==
0
reg_n
=
1
for
n
in
range
(
31
,
0
,
-
1
):
if
out_width
%
n
==
0
:
reg_n
=
n
break
cfg
[
"tile_ic"
]
=
SplitEntity
([
wkl
.
in_filter
//
ic_bn
,
ic_bn
])
cfg
[
"tile_oc"
]
=
SplitEntity
([
wkl
.
out_filter
//
oc_bn
,
oc_bn
])
cfg
[
"tile_ow"
]
=
SplitEntity
([
out_width
//
reg_n
,
reg_n
])
cfg
[
"unroll_kw"
]
=
OtherOptionEntity
(
False
)
def
_schedule_conv
(
s
,
cfg
,
data
,
data_pad
,
data_vec
,
kernel_vec
,
conv_out
,
output
,
last
):
# fetch schedule
ic_bn
,
oc_bn
,
reg_n
,
unroll_kw
=
(
cfg
[
"tile_ic"
]
.
size
[
-
1
],
cfg
[
"tile_oc"
]
.
size
[
-
1
],
...
...
topi/python/topi/x86/conv2d_int8.py
View file @
b0ddcff6
...
...
@@ -22,12 +22,52 @@ import tvm
from
tvm
import
autotvm
from
tvm.autotvm.task
import
get_config
from
tvm.autotvm.task.topi_integration
import
deserialize_args
from
..nn.conv2d
import
_get_workload
as
_get_conv2d_workload
from
..
import
generic
,
tag
from
..util
import
get_const_tuple
from
..nn.conv2d
import
conv2d_NCHWc_int8
from
..
import
nn
from
.
import
conv2d_avx_1x1
,
conv2d_avx_common
def
_get_default_config_int8
(
cfg
,
data
,
kernel
,
strides
,
padding
,
out_dtype
,
is_depthwise
=
False
,
layout
=
'NCHW'
):
"""
Get default schedule config for the workload
"""
assert
not
is_depthwise
,
"Depthwise Int8 not supported"
wkl
=
_get_conv2d_workload
(
data
,
kernel
,
strides
,
padding
,
out_dtype
,
layout
)
is_kernel_1x1
=
wkl
.
hkernel
==
1
and
wkl
.
wkernel
==
1
if
is_kernel_1x1
:
conv2d_avx_1x1
.
_fallback_schedule_int8
(
cfg
,
wkl
)
else
:
conv2d_avx_common
.
_fallback_schedule_int8
(
cfg
,
wkl
)
def
_is_int8_hw_support
(
data_dtype
,
kernel_dtype
):
"""
Checks to ensure that we can use Intel DLBoost instructions
1) The datatypes are correct.
2) LLVM version has support for the instructions.
3) Target is skylake and above.
"""
# 1) Check datatypes
is_dtype_support
=
data_dtype
==
'uint8'
and
kernel_dtype
==
'int8'
# 2) Check LLVM support
llvm_intrin_fast_int8
=
"llvm.x86.avx512.pmaddubs.w.512"
llvm_id
=
tvm
.
codegen
.
llvm_lookup_intrinsic_id
(
llvm_intrin_fast_int8
)
is_llvm_support
=
llvm_id
!=
0
# 3) Check target
target
=
tvm
.
target
.
current_target
()
is_target_support
=
False
for
opt
in
target
.
options
:
if
opt
==
'-mcpu=skylake-avx512'
:
is_target_support
=
True
return
is_dtype_support
and
is_llvm_support
and
is_target_support
def
_create_tuning_space_int8
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
):
"""Create schedule configuration from input arguments"""
dshape
=
get_const_tuple
(
data
.
shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment