Commit 3ada7c0e authored Jul 23, 2019 by Animesh Jain, committed by Tianqi Chen on Jul 23, 2019
Checking the correct dtypes for choosing the Intel int8 instructions. (#3516)
parent 9e6a8c0d
Showing 7 changed files with 147 additions and 103 deletions

tests/python/relay/test_op_level2.py                      +60  -0
topi/python/topi/nn/conv2d.py                             +1   -24
topi/python/topi/x86/check_targets.py                     +0   -28
topi/python/topi/x86/conv2d.py                            +67  -18
topi/python/topi/x86/conv2d_avx_1x1.py                    +6   -15
topi/python/topi/x86/conv2d_avx_common.py                 +1   -14
topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py    +12  -4
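In short, the change only selects the Intel fast-int8 path when the dtype pair matches what the pmaddubs-based intrinsic expects (uint8 activations, int8 weights), when LLVM knows the intrinsic, and when the target is skylake-avx512. A standalone sketch of that rule (illustrative names only, not the TVM API; the real check is _is_int8_hw_support in topi/python/topi/x86/conv2d.py below):

def uses_intel_fast_int8(data_dtype, kernel_dtype, llvm_has_intrin, target_is_skylake_avx512):
    # Illustrative only: summarizes the condition the diff below implements.
    correct_dtypes = data_dtype == "uint8" and kernel_dtype == "int8"
    return correct_dtypes and llvm_has_intrin and target_is_skylake_avx512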
tests/python/relay/test_op_level2.py
@@ -517,6 +517,65 @@ def test_upsampling():
    _test_upsampling("NHWC", "BILINEAR")


def test_conv2d_int8_intrinsics():
    def _compile(input_dtype, weight_dtype, output_dtype, target):
        n, ic, h, w, oc, ch, cw = 1, 16, 224, 224, 32, 3, 3
        x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
        w = relay.var("w", relay.TensorType((oc, ic, ch, cw), weight_dtype))
        y = relay.nn.conv2d(x, w,
                            kernel_size=(ch, cw),
                            channels=oc,
                            padding=(1, 1),
                            dilation=(1, 1),
                            out_dtype=output_dtype)
        func = relay.Function([x, w], y)
        wdata = np.random.rand(oc, ic, ch, cw) * 10
        parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(func, target, params=parameters)
        assembly = lib.get_source("asm")
        return assembly

    # compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions
    target = "llvm -mcpu=skylake-avx512"
    name = "llvm.x86.avx512.pmaddubs.w.512"
    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
    if llvm_id != 0:
        # Intel int8 instructions need uint8 data and an int8 kernel
        asm = _compile(input_dtype="uint8", weight_dtype="int8",
                       output_dtype="int32", target=target)
        # Check that the intrinsic is present in the assembly.
        assert "pmaddubs" in asm

        # Ensure that code is still generated when the datatypes are not HW supported.
        asm = _compile(input_dtype="int8", weight_dtype="int8",
                       output_dtype="int32", target=target)
        # Check that the intrinsic is not present in the assembly.
        assert "pmaddubs" not in asm

        # Ensure that code is still generated when the datatypes are not HW supported.
        asm = _compile(input_dtype="uint8", weight_dtype="uint8",
                       output_dtype="int32", target=target)
        # Check that the intrinsic is not present in the assembly.
        assert "pmaddubs" not in asm

    # Check that a vectorized instruction is generated for older Intel
    # generations, because we default to NCHWc layout.
    target = "llvm -mcpu=core-avx2"
    asm = _compile(input_dtype="int8", weight_dtype="int8",
                   output_dtype="int32", target=target)
    # Check that vector int mult and add instructions are generated.
    assert "vpmulld" in asm and "vpadd" in asm


if __name__ == "__main__":
    test_pool2d()
    test_avg_pool2d_no_count_pad()

@@ -532,3 +591,4 @@ if __name__ == "__main__":
    test_conv2d_run()
    test_batch_flatten()
    test_upsampling()
    test_conv2d_int8_intrinsics()
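Why the test above expects the fast intrinsic only for uint8 data with an int8 kernel: the pmaddubs family multiplies unsigned 8-bit values from one operand with signed 8-bit values from the other, accumulating into wider integers. A small numpy illustration (not generated code) of one 4-element uint8 x int8 dot product accumulated in int32:

import numpy as np

data = np.array([200, 17, 3, 250], dtype=np.uint8)      # activations: unsigned 8-bit
weights = np.array([-5, 7, 127, -128], dtype=np.int8)   # kernel: signed 8-bit
acc = np.int32(0)
for d, k in zip(data, weights):
    acc += np.int32(d) * np.int32(k)   # widen before multiply-accumulate
print(acc)  # one int32 lane of the vectorized dot product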
topi/python/topi/nn/conv2d.py
@@ -391,10 +391,7 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    in_channel = ic_chunk * ic_bn
    if data.dtype == 'uint8':
        oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
    else:
        oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
    oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
    num_filter = oc_chunk * oc_bn

    # output shape

@@ -413,26 +410,6 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    if data.dtype == 'uint8':
        assert out_dtype == "int32", \
            "INT8 convolution requires input dtype = uint8 and output dtype=int32"
        # Intel performs a dot product of two vectors of 4 int8 values
        # Current implementation requires ic_bn to be a multiple of 4
        n_elems = 4
        assert ic_bn % n_elems == 0

        ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
        ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
        ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
        return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                           tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
                                            ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
                                   kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                          oc_block, ic_s_inner].astype(out_dtype),
                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                           name='conv2d_NCHWc_int8',
                           tag="conv2d_NCHWc_int8")
    # else: fp implementation
    return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
                                        ic%ic_bn].astype(out_dtype) *
                       ...
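For reference, the reduction axes in the int8 compute above split the input channel as ic = ic_outer * ic_bn + ic_f_inner * n_elems + ic_s_inner, which is why ic_bn must be a multiple of n_elems = 4. A tiny check (illustrative, not TVM code):

in_channel, ic_bn, n_elems = 64, 16, 4   # example sizes with ic_bn % n_elems == 0
flat = [ic_outer * ic_bn + ic_f_inner * n_elems + ic_s_inner
        for ic_outer in range(in_channel // ic_bn)
        for ic_f_inner in range(ic_bn // n_elems)
        for ic_s_inner in range(n_elems)]
# The three axes cover every input channel exactly once, in groups of 4 consecutive
# values - the unit consumed by one pmaddubs-style dot product.
assert flat == list(range(in_channel))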
topi/python/topi/x86/check_targets.py
deleted 100644 → 0
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
"""Checks different x86 targets for target specific schedules"""

def check_skylake(target):
    """
    Checks if the target is skylake
    """
    for opt in target.options:
        if opt == '-mcpu=skylake-avx512':
            return True
    return False
topi/python/topi/x86/conv2d.py
@@ -37,6 +37,29 @@ from . import conv2d_avx_1x1, conv2d_avx_common

logger = logging.getLogger('topi')


def _is_int8_hw_support(data_dtype, kernel_dtype, target):
    """
    Checks to ensure that we can use Intel DLBoost instructions
    1) The datatypes are correct.
    2) LLVM version has support for the instructions.
    3) Target is skylake and above.
    """
    # 1) Check datatypes
    is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'

    # 2) Check LLVM support
    llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
    is_llvm_support = llvm_id != 0

    # 3) Check target
    is_target_support = False
    for opt in target.options:
        if opt == '-mcpu=skylake-avx512':
            is_target_support = True

    return is_dtype_support and is_llvm_support and is_target_support


def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
                        layout='NCHW'):
    """
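A minimal sketch of exercising the new helper (not part of the diff; assumes tvm.target.create and this module path are available in this TVM version):

import tvm
from topi.x86.conv2d import _is_int8_hw_support

skylake = tvm.target.create("llvm -mcpu=skylake-avx512")
# True only with a uint8 x int8 dtype pair, an LLVM that knows the intrinsic,
# and the skylake-avx512 target option.
print(_is_int8_hw_support("uint8", "int8", skylake))
print(_is_int8_hw_support("int8", "int8", skylake))   # False: data dtype must be uint8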
@@ -68,7 +91,8 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
        kh, kw, oc, _ = kshape
    elif pat.match(layout) is not None:
        n, ic_chunk, h, w, ic_bn = dshape
        if data.dtype == 'uint8':
        target = tvm.target.current_target(allow_none=False)
        if _is_int8_hw_support(data.dtype, kernel.dtype, target):
            oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
            ic = ic_chunk * ic_bn
            assert ic == k_ic * k_ic_f * kic_s

@@ -276,7 +300,6 @@ def schedule_conv2d_nhwc_pack(cfg, outs):
        args = [s, cfg, data_vec, conv_out, outs[0]]
        if data.dtype == 'uint8':
            # int8 conv kernel is 7-dim
            kh, kw, _, _, _ = get_const_tuple(kernel.shape)
            if kh == 1 and kw == 1:
                conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
@@ -453,19 +476,42 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
        new_workload = autotvm.task.args_to_workload(
            [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
             new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
        dispatch_ctx.update(target, new_workload, cfg)
    else:
        out_channel, _, kh, kw = get_const_tuple(kernel.shape)
        # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
        # Store altered operator's config
        new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
                                      kh, kw, ic_bn, oc_bn), dtype=kernel.dtype)
        new_workload = autotvm.task.args_to_workload(
            [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
             new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
        dispatch_ctx.update(target, new_workload, cfg)
        if _is_int8_hw_support(data.dtype, kernel.dtype, target):
            # Convert kernel data layout from 4D to 7D
            n_elems = 4
            out_channel, _, kh, kw = get_const_tuple(kernel.shape)
            data_expr, kernel_expr = inputs
            kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0))
            kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw,
                                                   out_channel//oc_bn, oc_bn))
            kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
            kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
                                                     in_channel//ic_bn, ic_bn))
            kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
                                                       in_channel//ic_bn, ic_bn//n_elems,
                                                       n_elems))
            kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
            copy_inputs = [data_expr, kernel_OIHWioe]
            # Store altered operator's config
            new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn,
                                          in_channel//ic_bn, ic_bn//n_elems, n_elems))
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
                 new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
            dispatch_ctx.update(target, new_workload, cfg)
        else:
            out_channel, _, kh, kw = get_const_tuple(kernel.shape)
            # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
            new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
            # Store altered operator's config
            new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
                                          kh, kw, ic_bn, oc_bn), dtype=kernel.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
                 new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
            dispatch_ctx.update(target, new_workload, cfg)

    if is_depthwise:
        if F.__name__ == 'nnvm.symbol':
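The transpose/reshape chain in _alter_conv2d_layout above turns the OIHW kernel into a 7D layout whose last axis groups n_elems = 4 input channels. A numpy sketch (illustrative, mirroring the F.transpose/F.reshape calls) that checks the index mapping:

import numpy as np

oc, ic, kh, kw = 32, 16, 3, 3             # example 4D OIHW kernel
oc_bn, ic_bn, n_elems = 16, 8, 4           # example blocking factors
w4 = np.arange(oc * ic * kh * kw).reshape(oc, ic, kh, kw)

w_IHWO = w4.transpose(1, 2, 3, 0)
w_IHWOo = w_IHWO.reshape(ic, kh, kw, oc // oc_bn, oc_bn)
w_OHWoI = w_IHWOo.transpose(3, 1, 2, 4, 0)
w_OHWoIi = w_OHWoI.reshape(oc // oc_bn, kh, kw, oc_bn, ic // ic_bn, ic_bn)
w_OHWoIie = w_OHWoIi.reshape(oc // oc_bn, kh, kw, oc_bn, ic // ic_bn, ic_bn // n_elems, n_elems)
w7 = w_OHWoIie.transpose(0, 4, 1, 2, 5, 3, 6)

# Element (O, I, h, w, i_f, o, i_s) of the 7D kernel is element
# (O*oc_bn + o, I*ic_bn + i_f*n_elems + i_s, h, w) of the original OIHW kernel.
assert w7[1, 1, 2, 0, 1, 3, 2] == w4[1 * oc_bn + 3, 1 * ic_bn + 1 * n_elems + 2, 2, 0]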
@@ -505,7 +551,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    in_channel = ic_chunk * ic_bn
    if data.dtype == 'uint8':
    target = tvm.target.current_target(allow_none=False)
    if _is_int8_hw_support(data.dtype, kernel.dtype, target):
        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
            get_const_tuple(kernel.shape)
    else:

@@ -539,7 +586,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    if data.dtype == 'uint8' and groups == 1:
    if _is_int8_hw_support(data.dtype, kernel.dtype, target) and groups == 1:
        assert out_dtype == "int32", \
            "INT8 convolution requires input dtype = uint8 and output dtype=int32"
        # Intel performs a dot product of two vectors of 4 int8 values

@@ -559,7 +606,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                                          oc_block, ic_s_inner].astype(out_dtype),
                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                           name='conv2d_NCHWc_int8',
                           tag="conv2d_NCHWc_int8")
    if data.dtype == 'uint8':
    if _is_int8_hw_support(data.dtype, kernel.dtype, target):
        # for int8 group conv support
        n_elems = 4
        ic_chunk = in_channel // ic_bn

@@ -615,7 +663,8 @@ def _schedule_conv2d_NCHWc(cfg, outs):
            data = data_pad.op.input_tensors[0]

        args = [s, cfg, data_vec, conv_out, outs[0]]
        if data.dtype == 'uint8':
        target = tvm.target.current_target(allow_none=False)
        if _is_int8_hw_support(data.dtype, kernel.dtype, target):
            # int8 conv kernel is 7-dim
            _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
            if kh == 1 and kw == 1:
topi/python/topi/x86/conv2d_avx_1x1.py
@@ -24,7 +24,6 @@ from ..nn.pad import pad
from ..nn.util import infer_pad, get_pad_tuple
from ..util import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len

def _fallback_schedule(cfg, wkl):

@@ -187,13 +186,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
    More details - https://software.intel.com/en-us/articles/
    lower-numerical-precision-deep-learning-inference-and-training
    """
    target = tvm.target.current_target(allow_none=False)
    int32_lanes = -1
    if check_skylake(target):
        int32_lanes = 16
    else:
        return s
    assert int32_lanes != -1
    int32_lanes = 16

    oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
    _, _, _, _, ic_bn = get_const_tuple(data.shape)

@@ -310,13 +303,11 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
    packing of weight to make the address access be friendly to int8
    intrinsic
    """
    target = tvm.target.current_target(allow_none=False)
    int32_lanes = -1
    if check_skylake(target):
        int32_lanes = 16
    else:
        return s
    assert int32_lanes != -1
    # FIXME - https://github.com/dmlc/tvm/issues/3598
    # pylint: disable=unreachable
    return s
    int32_lanes = 16
    # assertion to fail the unhandled case
    _, _, _, ic_num = get_const_tuple(data.shape)
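For context, the hard-coded int32_lanes = 16 in the schedules above matches the AVX-512 register geometry that dot_16x1x16_int8_int8_int32 targets; a quick arithmetic check (assumption: 512-bit vector registers):

register_bits = 512
int32_lanes = register_bits // 32    # 16 int32 accumulators per vector register
int8_per_lane = 32 // 8              # 4 int8 products reduced into each accumulator
assert (int32_lanes, int8_per_lane) == (16, 4)   # matches dot_16x1x16_int8_int8_int32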
topi/python/topi/x86/conv2d_avx_common.py
@@ -23,7 +23,6 @@ from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.util import infer_pad
from ..util import get_const_tuple
from .tensor_intrin import dot_16x1x16_int8_int8_int32
from .check_targets import check_skylake
from .util import get_fp32_len

def _fallback_schedule(cfg, wkl):

@@ -186,19 +185,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
    More details - https://software.intel.com/en-us/articles/
    lower-numerical-precision-deep-learning-inference-and-training
    """
    # Currently INT8 operations are supported for only Skylake
    # In future the _intrin_reduce4int8 will be updated for VNNI instructions
    # In case of unsupported target, the schedule will go to the original
    # compute
    target = tvm.target.current_target(allow_none=False)
    int32_lanes = -1
    if check_skylake(target):
        int32_lanes = 16
    else:
        return s
    assert int32_lanes != -1
    int32_lanes = 16

    reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
    _, _, _, _, ic_bn = get_const_tuple(data.shape)
topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
@@ -24,6 +24,7 @@ import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from nose.tools import nottest

from common import get_all_backend

@@ -97,15 +98,22 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
        func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)

    # for device in ["llvm -mcpu=skylake-avx512"]:
    for device in ["llvm"]:
    # for device in ["llvm"]:
    for device in ["llvm -mcpu=skylake-avx512"]:
        with autotvm.tophub.context(device):
            # load tophub pre-tuned parameters
            check_device(device)


@nottest
def test_conv2d_NCHWc():
    # ResNet50 workloads
    verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3)


if __name__ == "__main__":
    test_conv2d_NCHWc()
    # The test requires Skylake and newer Intel machines to generate the correct
    # instruction. This test directly calls the topi operator, so it requires the
    # correct kernel shape. For older generations of Intel machines, the kernel
    # needs to be 6D. This test uses a 7D kernel, which only works on Skylake+
    # machines. So, disabling the test.
    # test_conv2d_NCHWc()
    pass