Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
eb220d92
Commit
eb220d92
authored
Sep 13, 2019
by
Animesh Jain
Committed by
Yao Wang
Sep 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactoring x86 conv2d_NCHWc (#3944)
parent
195973c0
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
606 additions
and
198 deletions
+606
-198
python/tvm/relay/op/nn/_nn.py
+28
-0
python/tvm/relay/op/nn/nn.py
+66
-0
src/relay/op/nn/convolution.cc
+48
-0
topi/python/topi/generic/nn.py
+19
-0
topi/python/topi/nn/conv2d.py
+245
-16
topi/python/topi/x86/__init__.py
+1
-0
topi/python/topi/x86/conv2d.py
+56
-182
topi/python/topi/x86/conv2d_int8.py
+143
-0
No files found.
python/tvm/relay/op/nn/_nn.py
View file @
eb220d92
...
@@ -548,6 +548,34 @@ reg.register_pattern("nn.contrib_conv2d_NCHWc",
...
@@ -548,6 +548,34 @@ reg.register_pattern("nn.contrib_conv2d_NCHWc",
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
@reg.register_compute
(
"nn.contrib_conv2d_NCHWc_int8"
)
def
compute_contrib_conv2d_NCHWc_int8
(
attrs
,
inputs
,
out_dtype
,
target
):
"""Compute definition of conv2d NCHWc"""
# pylint: disable=assignment-from-no-return
padding
=
attrs
.
get_int_tuple
(
"padding"
)
strides
=
attrs
.
get_int_tuple
(
"strides"
)
dilation
=
attrs
.
get_int_tuple
(
"dilation"
)
data_layout
=
attrs
.
get_str
(
"data_layout"
)
out_layout
=
attrs
.
get_str
(
"out_layout"
)
out_dtype
=
attrs
.
get_str
(
"out_dtype"
)
out_dtype
=
inputs
[
0
]
.
dtype
if
out_dtype
==
""
else
out_dtype
out
=
topi
.
nn
.
conv2d_NCHWc_int8
(
inputs
[
0
],
inputs
[
1
],
strides
,
padding
,
dilation
,
data_layout
,
out_layout
,
out_dtype
)
return
[
out
]
@reg.register_schedule
(
"nn.contrib_conv2d_NCHWc_int8"
)
def
schedule_contrib_conv2d_NCHWc_int8
(
attrs
,
outs
,
target
):
"""Schedule definition of contrib_conv2d_NCHWc_int8"""
with
target
:
return
topi
.
generic
.
schedule_conv2d_NCHWc_int8
(
outs
)
reg
.
register_pattern
(
"nn.contrib_conv2d_NCHWc_int8"
,
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
@reg.register_compute
(
"nn.contrib_depthwise_conv2d_NCHWc"
)
@reg.register_compute
(
"nn.contrib_depthwise_conv2d_NCHWc"
)
def
compute_contrib_depthwise_conv2d_NCHWc
(
attrs
,
inputs
,
out_dtype
,
target
):
def
compute_contrib_depthwise_conv2d_NCHWc
(
attrs
,
inputs
,
out_dtype
,
target
):
"""Compute definition of depthwise conv2d NCHWc"""
"""Compute definition of depthwise conv2d NCHWc"""
...
...
python/tvm/relay/op/nn/nn.py
View file @
eb220d92
...
@@ -1340,6 +1340,72 @@ def contrib_depthwise_conv2d_nchwc(data,
...
@@ -1340,6 +1340,72 @@ def contrib_depthwise_conv2d_nchwc(data,
groups
,
channels
,
kernel_size
,
data_layout
,
groups
,
channels
,
kernel_size
,
data_layout
,
kernel_layout
,
out_layout
,
out_dtype
)
kernel_layout
,
out_layout
,
out_dtype
)
def
contrib_conv2d_nchwc_int8
(
data
,
kernel
,
strides
=
(
1
,
1
),
padding
=
(
0
,
0
),
dilation
=
(
1
,
1
),
groups
=
1
,
channels
=
None
,
kernel_size
=
None
,
data_layout
=
"NCHW8c"
,
kernel_layout
=
"OIHW"
,
out_layout
=
""
,
out_dtype
=
""
):
r"""Variant of 2D convolution. It deals with only int8 inputs.
This operator takes the weight as the convolution kernel
and convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return
_make
.
contrib_conv2d_NCHWc_int8
(
data
,
kernel
,
strides
,
padding
,
dilation
,
groups
,
channels
,
kernel_size
,
data_layout
,
kernel_layout
,
out_layout
,
out_dtype
)
def
contrib_conv2d_winograd_weight_transform
(
weight
,
def
contrib_conv2d_winograd_weight_transform
(
weight
,
tile_size
):
tile_size
):
r"""Weight Transformation part for 2D convolution with winograd algorithm.
r"""Weight Transformation part for 2D convolution with winograd algorithm.
...
...
src/relay/op/nn/convolution.cc
View file @
eb220d92
...
@@ -570,6 +570,54 @@ weight transformation in advance.
...
@@ -570,6 +570,54 @@ weight transformation in advance.
.
set_support_level
(
10
)
.
set_support_level
(
10
)
.
add_type_rel
(
"Conv2DWinogradNNPACKWeightTransform"
,
Conv2DWinogradNNPACKWeightTransformRel
);
.
add_type_rel
(
"Conv2DWinogradNNPACKWeightTransform"
,
Conv2DWinogradNNPACKWeightTransformRel
);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
Expr
MakeConv2DNCHWcInt8
(
Expr
data
,
Expr
kernel
,
Array
<
IndexExpr
>
strides
,
Array
<
IndexExpr
>
padding
,
Array
<
IndexExpr
>
dilation
,
int
groups
,
IndexExpr
channels
,
Array
<
IndexExpr
>
kernel_size
,
std
::
string
data_layout
,
std
::
string
kernel_layout
,
std
::
string
out_layout
,
DataType
out_dtype
)
{
auto
attrs
=
make_node
<
Conv2DAttrs
>
();
attrs
->
strides
=
std
::
move
(
strides
);
attrs
->
padding
=
std
::
move
(
padding
);
attrs
->
dilation
=
std
::
move
(
dilation
);
attrs
->
groups
=
groups
;
attrs
->
channels
=
channels
;
attrs
->
kernel_size
=
std
::
move
(
kernel_size
);
attrs
->
data_layout
=
std
::
move
(
data_layout
);
attrs
->
kernel_layout
=
std
::
move
(
kernel_layout
);
attrs
->
out_layout
=
std
::
move
(
out_layout
);
attrs
->
out_dtype
=
std
::
move
(
out_dtype
);
static
const
Op
&
op
=
Op
::
Get
(
"nn.contrib_conv2d_NCHWc_int8"
);
return
CallNode
::
make
(
op
,
{
data
,
kernel
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op.nn._make.contrib_conv2d_NCHWc_int8"
)
.
set_body_typed
(
MakeConv2DNCHWcInt8
);
RELAY_REGISTER_OP
(
"nn.contrib_conv2d_NCHWc_int8"
)
.
describe
(
R"code(Compute conv2d with NCHWc data layout with int8 inputs.
- **data**: Input is 5D packed tensor.
- **weight**: 7D packed tensor.
- **out**: Output is 5D packed tensor
)code"
TVM_ADD_FILELINE
)
.
set_attrs_type_key
(
"relay.attrs.Conv2D"
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"weight"
,
"Tensor"
,
"The weight tensor."
)
.
set_support_level
(
10
)
.
add_type_rel
(
"Conv2DNCHWcInt8"
,
Conv2DWinogradRel
<
Conv2DAttrs
>
)
.
set_attr
<
FInferCorrectLayout
>
(
"FInferCorrectLayout"
,
Conv2DInferCorrectLayout
<
Conv2DAttrs
>
);
// Positional relay function to create conv2d NCHWc operator
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
// used by frontend FFI.
...
...
topi/python/topi/generic/nn.py
View file @
eb220d92
...
@@ -108,6 +108,25 @@ def schedule_conv2d_NCHWc(outs):
...
@@ -108,6 +108,25 @@ def schedule_conv2d_NCHWc(outs):
@tvm.target.generic_func
@tvm.target.generic_func
def
schedule_conv2d_NCHWc_int8
(
outs
):
"""Schedule for conv2d_NCHW[x]c_int8
Parameters
----------
outs : Array of Tensor
The computation graph description of conv2d_NCHWc_int8
in the format of an array of tensors.
The number of filter, i.e., the output channel.
Returns
-------
sch : Schedule
The computation schedule for the op.
"""
return
_default_schedule
(
outs
,
False
)
@tvm.target.generic_func
def
schedule_conv2d_winograd_weight_transform
(
outs
):
def
schedule_conv2d_winograd_weight_transform
(
outs
):
"""Schedule for weight transformation of winograd
"""Schedule for weight transformation of winograd
...
...
topi/python/topi/nn/conv2d.py
View file @
eb220d92
...
@@ -398,27 +398,75 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
...
@@ -398,27 +398,75 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
output : tvm.Tensor
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
"""
# search platform specific declaration first
# default declaration
return
conv2d_NCHWc_compute
(
data
,
kernel
,
stride
,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
)
def
conv2d_NCHWc_compute
(
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
):
"""Conv2D operator compute for nChw[x]c layout.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
6-D with shape
[num_filter_chunk, in_channel_chunk, filter_height, filter_width,
in_channel_block, num_filter_block]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
# layout and out_layout are not used here,
# layout and out_layout are not used here,
# we keep them for debug convenience when dumping autotvm workload
# we keep them for debug convenience when dumping autotvm workload
pad_top
,
pad_left
,
pad_down
,
pad_right
=
get_pad_tuple
(
padding
,
HPAD
,
WPAD
=
padding
if
isinstance
(
padding
,
(
tuple
,
list
))
else
(
padding
,
padding
)
(
dilated_kernel_h
,
HSTR
,
WSTR
=
strides
if
isinstance
(
strides
,
(
tuple
,
list
))
else
(
strides
,
strides
)
dilated_kernel_w
))
dilation_h
,
dilation_w
=
dilation
if
isinstance
(
dilation
,
(
tuple
,
list
))
\
HPAD
=
pad_top
+
pad_down
else
(
dilation
,
dilation
)
WPAD
=
pad_left
+
pad_right
HSTR
,
WSTR
=
stride
if
isinstance
(
stride
,
(
tuple
,
list
))
else
(
stride
,
stride
)
dh
,
dw
=
dilation
if
isinstance
(
dilation
,
(
tuple
,
list
))
else
(
dilation
,
dilation
)
assert
(
dh
,
dw
)
==
(
1
,
1
),
"Does not support dilation"
n
,
ic_chunk
,
ih
,
iw
,
ic_bn
=
get_const_tuple
(
data
.
shape
)
n
,
ic_chunk
,
ih
,
iw
,
ic_bn
=
get_const_tuple
(
data
.
shape
)
in_channel
=
ic_chunk
*
ic_bn
in_channel
=
ic_chunk
*
ic_bn
oc_chunk
,
_
,
kernel_height
,
kernel_width
,
_
,
oc_bn
=
get_const_tuple
(
kernel
.
shape
)
target
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
oc_chunk
,
ic_chunk_group
,
kernel_height
,
kernel_width
,
_
,
oc_bn
=
\
get_const_tuple
(
kernel
.
shape
)
num_filter
=
oc_chunk
*
oc_bn
num_filter
=
oc_chunk
*
oc_bn
groups
=
ic_chunk
//
ic_chunk_group
dilated_kernel_h
=
(
kernel_height
-
1
)
*
dilation_h
+
1
dilated_kernel_w
=
(
kernel_width
-
1
)
*
dilation_w
+
1
# output shape
# output shape
out_height
=
(
ih
+
2
*
HPAD
-
kernel_height
)
//
HSTR
+
1
out_height
=
(
ih
+
2
*
HPAD
-
dilated_kernel_h
)
//
HSTR
+
1
out_width
=
(
iw
+
2
*
WPAD
-
kernel_width
)
//
WSTR
+
1
out_width
=
(
iw
+
2
*
WPAD
-
dilated_kernel_w
)
//
WSTR
+
1
oshape
=
(
n
,
oc_chunk
,
out_height
,
out_width
,
oc_bn
)
oshape
=
(
n
,
oc_chunk
,
out_height
,
out_width
,
oc_bn
)
# DOPAD
# DOPAD
...
@@ -433,13 +481,194 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
...
@@ -433,13 +481,194 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
ic
//
ic_bn
,
oh
*
HSTR
+
kh
,
ow
*
WSTR
+
kw
,
tvm
.
sum
(
data_pad
[
n
,
ic
%
ic_bn
]
.
astype
(
out_dtype
)
*
ic
//
ic_bn
,
kernel
[
oc_chunk
,
ic
//
ic_bn
,
kh
,
kw
,
ic
%
ic_bn
,
oc_block
],
oh
*
HSTR
+
kh
*
dilation_h
,
ow
*
WSTR
+
kw
*
dilation_w
,
ic
%
ic_bn
]
.
astype
(
out_dtype
)
*
kernel
[
oc_chunk
,
ic
//
ic_bn
,
kh
,
kw
,
ic
%
ic_bn
,
oc_block
],
axis
=
[
ic
,
kh
,
kw
]),
axis
=
[
ic
,
kh
,
kw
]),
name
=
'conv2d_NCHWc'
,
tag
=
"conv2d_NCHWc"
)
name
=
'conv2d_NCHWc'
,
tag
=
"conv2d_NCHWc"
)
@tvm.target.generic_func
def
conv2d_NCHWc_int8
(
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
=
'int32'
):
"""Conv2D operator for nChw[x]c layout.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
7-D with shape
[num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
num_filter_block, 4]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
return
conv2d_NCHWc_int8_compute
(
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
)
def
conv2d_NCHWc_int8_compute
(
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
=
'int32'
):
"""Conv2D operator for nChw[x]c layout.
Parameters
----------
data : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
kernel : tvm.Tensor
7-D with shape
[num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block/4,
num_filter_block, 4]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
# layout and out_layout are not used here,
# we keep them for debug convenience when dumping autotvm workload
HPAD
,
WPAD
=
padding
if
isinstance
(
padding
,
(
tuple
,
list
))
else
(
padding
,
padding
)
HSTR
,
WSTR
=
strides
if
isinstance
(
strides
,
(
tuple
,
list
))
else
(
strides
,
strides
)
dilation_h
,
dilation_w
=
dilation
if
isinstance
(
dilation
,
(
tuple
,
list
))
\
else
(
dilation
,
dilation
)
n
,
ic_chunk
,
ih
,
iw
,
ic_bn
=
get_const_tuple
(
data
.
shape
)
in_channel
=
ic_chunk
*
ic_bn
target
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
oc_chunk
,
ic_chunk_group
,
kernel_height
,
kernel_width
,
_
,
oc_bn
,
_
=
\
get_const_tuple
(
kernel
.
shape
)
num_filter
=
oc_chunk
*
oc_bn
groups
=
ic_chunk
//
ic_chunk_group
# Since the weight is 7-D and the last element size is 4, we have to
# check ic_bn should be a multiple of 4.
# Similary, oc_bn has to be a multiple of 4.
assert
ic_bn
%
4
==
0
assert
oc_bn
%
16
==
0
dilated_kernel_h
=
(
kernel_height
-
1
)
*
dilation_h
+
1
dilated_kernel_w
=
(
kernel_width
-
1
)
*
dilation_w
+
1
# output shape
out_height
=
(
ih
+
2
*
HPAD
-
dilated_kernel_h
)
//
HSTR
+
1
out_width
=
(
iw
+
2
*
WPAD
-
dilated_kernel_w
)
//
WSTR
+
1
oshape
=
(
n
,
oc_chunk
,
out_height
,
out_width
,
oc_bn
)
# DOPAD
DOPAD
=
(
HPAD
!=
0
or
WPAD
!=
0
)
if
DOPAD
:
data_pad
=
pad
(
data
,
(
0
,
0
,
HPAD
,
WPAD
,
0
),
name
=
"data_pad"
)
else
:
data_pad
=
data
ic
=
tvm
.
reduce_axis
((
0
,
in_channel
),
name
=
'ic'
)
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
if
groups
==
1
:
n_elems
=
4
ic_outer
=
tvm
.
reduce_axis
((
0
,
in_channel
//
ic_bn
),
name
=
'ic_outer'
)
ic_f_inner
=
tvm
.
reduce_axis
((
0
,
ic_bn
//
n_elems
),
name
=
'ic_f_inner'
)
ic_s_inner
=
tvm
.
reduce_axis
((
0
,
n_elems
),
name
=
'ic_s_inner'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
ic_outer
,
oh
*
HSTR
+
kh
*
dilation_h
,
ow
*
WSTR
+
kw
*
dilation_w
,
ic_f_inner
*
n_elems
+
ic_s_inner
]
.
astype
(
out_dtype
)
*
kernel
[
oc_chunk
,
ic_outer
,
kh
,
kw
,
ic_f_inner
,
oc_block
,
ic_s_inner
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
,
ic_outer
,
ic_f_inner
,
ic_s_inner
]),
name
=
'conv2d_NCHWc_int8'
,
tag
=
"conv2d_NCHWc_int8"
)
# for int8 group conv support
n_elems
=
4
ic_chunk
=
in_channel
//
ic_bn
ic_outer
=
tvm
.
reduce_axis
((
0
,
ic_chunk
//
groups
),
name
=
'ic_outer'
)
ic_f_inner
=
tvm
.
reduce_axis
((
0
,
ic_bn
//
n_elems
),
name
=
'ic_f_inner'
)
ic_s_inner
=
tvm
.
reduce_axis
((
0
,
n_elems
),
name
=
'ic_s_inner'
)
oshape
=
(
n
,
oc_chunk
,
out_height
,
out_width
,
oc_bn
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
occ
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
(
occ
*
oc_bn
//
(
oc_chunk
*
oc_bn
//
groups
))
*
(
ic_chunk
//
groups
)
+
ic_outer
,
oh
*
HSTR
+
kh
,
ow
*
WSTR
+
kw
,
ic_f_inner
*
n_elems
+
ic_s_inner
]
.
astype
(
out_dtype
)
*
kernel
[
occ
,
ic_outer
,
kh
,
kw
,
ic_f_inner
,
oc_block
,
ic_s_inner
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
,
ic_outer
,
ic_f_inner
,
ic_s_inner
]),
name
=
'conv2d_NCHWc_int8'
,
tag
=
"conv2d_NCHWc_int8"
)
def
conv2d_winograd_weight_transform
(
kernel
,
tile_size
):
def
conv2d_winograd_weight_transform
(
kernel
,
tile_size
):
"""Weight transformation for winograd
"""Weight transformation for winograd
...
...
topi/python/topi/x86/__init__.py
View file @
eb220d92
...
@@ -6,6 +6,7 @@ from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
...
@@ -6,6 +6,7 @@ from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
from
.binarize_pack
import
schedule_binarize_pack
from
.binarize_pack
import
schedule_binarize_pack
from
.binary_dense
import
schedule_binary_dense
from
.binary_dense
import
schedule_binary_dense
from
.nn
import
*
from
.nn
import
*
from
.conv2d_int8
import
*
from
.injective
import
*
from
.injective
import
*
from
.pooling
import
schedule_pool
,
schedule_adaptive_pool
from
.pooling
import
schedule_pool
,
schedule_adaptive_pool
from
.bitserial_conv2d
import
schedule_bitserial_conv2d
from
.bitserial_conv2d
import
schedule_bitserial_conv2d
...
...
topi/python/topi/x86/conv2d.py
View file @
eb220d92
...
@@ -263,58 +263,6 @@ def schedule_conv2d(cfg, outs):
...
@@ -263,58 +263,6 @@ def schedule_conv2d(cfg, outs):
traverse
(
outs
[
0
]
.
op
)
traverse
(
outs
[
0
]
.
op
)
return
s
return
s
@autotvm.register_topi_schedule
(
generic
.
schedule_conv2d_nhwc_pack
,
'cpu'
,
[
'direct'
])
def
schedule_conv2d_nhwc_pack
(
cfg
,
outs
):
"""Create schedule for tensors"""
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
output_op
=
outs
[
0
]
.
op
scheduled_ops
=
[]
def
traverse
(
op
):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if
tag
.
is_broadcast
(
op
.
tag
):
if
op
not
in
s
.
outputs
:
s
[
op
]
.
compute_inline
()
else
:
# inject custom schedule
if
len
(
op
.
axis
)
==
4
:
# schedule bias + bn + relu
n
,
h
,
w
,
c
=
op
.
axis
fused
=
s
[
op
]
.
fuse
(
n
,
h
,
w
)
s
[
op
]
.
parallel
(
fused
)
s
[
op
]
.
vectorize
(
c
)
for
tensor
in
op
.
input_tensors
:
if
isinstance
(
tensor
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
tensor
.
op
not
in
scheduled_ops
:
traverse
(
tensor
.
op
)
if
'conv2d_nhwc_pack_int8'
in
op
.
tag
:
conv_out
=
op
.
output
(
0
)
kernel
=
conv_out
.
op
.
input_tensors
[
1
]
data_vec
=
conv_out
.
op
.
input_tensors
[
0
]
data
=
data_vec
.
op
.
input_tensors
[
0
]
\
if
isinstance
(
data_vec
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
not
in
data_vec
.
op
.
tag
\
else
data_vec
if
isinstance
(
data
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
in
data
.
op
.
tag
:
data_pad
=
data
data
=
data_pad
.
op
.
input_tensors
[
0
]
args
=
[
s
,
cfg
,
data_vec
,
conv_out
,
outs
[
0
]]
if
data
.
dtype
==
'uint8'
:
kh
,
kw
,
_
,
_
,
_
=
get_const_tuple
(
kernel
.
shape
)
if
kh
==
1
and
kw
==
1
:
conv2d_avx_1x1
.
_schedule_conv_nhwc_pack_int8
(
*
args
)
else
:
raise
ValueError
(
"Only support 1x1 kernel with "
"schedule_conv2d_nhwc_pack."
)
else
:
raise
ValueError
(
"Not support this data type {} with "
"schedule_conv2d_nhwc_pack. Only support int8"
.
format
(
data
.
dtype
))
scheduled_ops
.
append
(
op
)
traverse
(
output_op
)
return
s
@generic.schedule_conv2d_nhwc.register
(
"cpu"
)
@generic.schedule_conv2d_nhwc.register
(
"cpu"
)
def
schedule_conv2d_nhwc
(
outs
):
def
schedule_conv2d_nhwc
(
outs
):
"""Create schedule for tensors"""
"""Create schedule for tensors"""
...
@@ -477,51 +425,53 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
...
@@ -477,51 +425,53 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
depthwise_conv2d_NCHWc
)
new_attrs
[
'out_layout'
],
out_dtype
],
depthwise_conv2d_NCHWc
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
else
:
if
_is_int8_hw_support
(
data
.
dtype
,
kernel
.
dtype
,
target
):
# Convert kernel data layout from 4D to 7D
n_elems
=
4
out_channel
,
_
,
kh
,
kw
=
get_const_tuple
(
kernel
.
shape
)
data_expr
,
kernel_expr
=
inputs
kernel_IHWO
=
F
.
transpose
(
kernel_expr
,
axes
=
(
1
,
2
,
3
,
0
))
kernel_IHWOo
=
F
.
reshape
(
kernel_IHWO
,
(
in_channel
,
kh
,
kw
,
out_channel
//
oc_bn
,
oc_bn
))
kernel_OHWoI
=
F
.
transpose
(
kernel_IHWOo
,
axes
=
(
3
,
1
,
2
,
4
,
0
))
kernel_OHWoIi
=
F
.
reshape
(
kernel_OHWoI
,
(
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
))
kernel_OHWoIie
=
F
.
reshape
(
kernel_OHWoIi
,
(
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
//
n_elems
,
n_elems
))
kernel_OIHWioe
=
F
.
transpose
(
kernel_OHWoIie
,
axes
=
(
0
,
4
,
1
,
2
,
5
,
3
,
6
))
copy_inputs
=
[
data_expr
,
kernel_OIHWioe
]
# Store altered operator's config
new_kernel
=
tvm
.
placeholder
((
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
//
n_elems
,
n_elems
))
new_workload
=
autotvm
.
task
.
args_to_workload
(
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
conv2d_NCHWc
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
else
:
out_channel
,
_
,
kh
,
kw
=
get_const_tuple
(
kernel
.
shape
)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs
[
'kernel_layout'
]
=
'OIHW
%
di
%
do'
%
(
ic_bn
,
oc_bn
)
# Store altered operator's config
new_kernel
=
tvm
.
placeholder
((
out_channel
//
oc_bn
,
in_channel
//
ic_bn
,
kh
,
kw
,
ic_bn
,
oc_bn
),
dtype
=
kernel
.
dtype
)
new_workload
=
autotvm
.
task
.
args_to_workload
(
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
conv2d_NCHWc
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
if
is_depthwise
:
if
F
.
__name__
==
'nnvm.symbol'
:
if
F
.
__name__
==
'nnvm.symbol'
:
logging
.
warning
(
"Use native layout for depthwise convolution on NNVM."
)
logging
.
warning
(
"Use native layout for depthwise convolution on NNVM."
)
return
None
return
None
return
F
.
nn
.
contrib_depthwise_conv2d_nchwc
(
*
copy_inputs
,
**
new_attrs
)
return
F
.
nn
.
contrib_depthwise_conv2d_nchwc
(
*
copy_inputs
,
**
new_attrs
)
else
:
if
_is_int8_hw_support
(
data
.
dtype
,
kernel
.
dtype
,
target
):
# Convert kernel data layout from 4D to 7D
n_elems
=
4
out_channel
,
_
,
kh
,
kw
=
get_const_tuple
(
kernel
.
shape
)
data_expr
,
kernel_expr
=
inputs
kernel_IHWO
=
F
.
transpose
(
kernel_expr
,
axes
=
(
1
,
2
,
3
,
0
))
kernel_IHWOo
=
F
.
reshape
(
kernel_IHWO
,
(
in_channel
,
kh
,
kw
,
out_channel
//
oc_bn
,
oc_bn
))
kernel_OHWoI
=
F
.
transpose
(
kernel_IHWOo
,
axes
=
(
3
,
1
,
2
,
4
,
0
))
kernel_OHWoIi
=
F
.
reshape
(
kernel_OHWoI
,
(
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
))
kernel_OHWoIie
=
F
.
reshape
(
kernel_OHWoIi
,
(
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
//
n_elems
,
n_elems
))
kernel_OIHWioe
=
F
.
transpose
(
kernel_OHWoIie
,
axes
=
(
0
,
4
,
1
,
2
,
5
,
3
,
6
))
copy_inputs
=
[
data_expr
,
kernel_OIHWioe
]
# Store altered operator's config
new_kernel
=
tvm
.
placeholder
((
out_channel
//
oc_bn
,
kh
,
kw
,
oc_bn
,
in_channel
//
ic_bn
,
ic_bn
//
n_elems
,
n_elems
))
new_workload
=
autotvm
.
task
.
args_to_workload
(
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
conv2d_NCHWc
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
if
F
.
__name__
==
'nnvm.symbol'
:
if
F
.
__name__
==
'nnvm.symbol'
:
return
F
.
contrib
.
conv2d_NCHWc
(
*
copy_inputs
,
**
new_attrs
)
logging
.
warning
(
"Use native layout for int8 convolution on NNVM."
)
return
F
.
nn
.
contrib_conv2d_nchwc
(
*
copy_inputs
,
**
new_attrs
)
return
None
return
F
.
nn
.
contrib_conv2d_nchwc_int8
(
*
copy_inputs
,
**
new_attrs
)
out_channel
,
_
,
kh
,
kw
=
get_const_tuple
(
kernel
.
shape
)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs
[
'kernel_layout'
]
=
'OIHW
%
di
%
do'
%
(
ic_bn
,
oc_bn
)
# Store altered operator's config
new_kernel
=
tvm
.
placeholder
((
out_channel
//
oc_bn
,
in_channel
//
ic_bn
,
kh
,
kw
,
ic_bn
,
oc_bn
),
dtype
=
kernel
.
dtype
)
new_workload
=
autotvm
.
task
.
args_to_workload
(
[
new_data
,
new_kernel
,
strides
,
padding
,
dilation
,
new_attrs
[
layout_name
],
new_attrs
[
'out_layout'
],
out_dtype
],
conv2d_NCHWc
)
dispatch_ctx
.
update
(
target
,
new_workload
,
cfg
)
if
F
.
__name__
==
'nnvm.symbol'
:
return
F
.
contrib
.
conv2d_NCHWc
(
*
copy_inputs
,
**
new_attrs
)
return
F
.
nn
.
contrib_conv2d_nchwc
(
*
copy_inputs
,
**
new_attrs
)
@conv2d_infer_layout.register
(
"cpu"
)
@conv2d_infer_layout.register
(
"cpu"
)
...
@@ -544,95 +494,27 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
...
@@ -544,95 +494,27 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
):
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
):
# layout and out_layout are not used here,
# layout and out_layout are not used here,
# we keep them for debug convenience when dumping autotvm workload
# we keep them for debug convenience when dumping autotvm workload
HPAD
,
WPAD
=
padding
if
isinstance
(
padding
,
(
tuple
,
list
))
else
(
padding
,
padding
)
HSTR
,
WSTR
=
strides
if
isinstance
(
strides
,
(
tuple
,
list
))
else
(
strides
,
strides
)
dilation_h
,
dilation_w
=
dilation
if
isinstance
(
dilation
,
(
tuple
,
list
))
\
else
(
dilation
,
dilation
)
n
,
ic_chunk
,
ih
,
iw
,
ic_bn
=
get_const_tuple
(
data
.
shape
)
n
,
ic_chunk
,
ih
,
iw
,
ic_bn
=
get_const_tuple
(
data
.
shape
)
in_channel
=
ic_chunk
*
ic_bn
in_channel
=
ic_chunk
*
ic_bn
target
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
oc_chunk
,
ic_chunk_group
,
kernel_height
,
kernel_width
,
_
,
oc_bn
=
\
if
_is_int8_hw_support
(
data
.
dtype
,
kernel
.
dtype
,
target
):
oc_chunk
,
ic_chunk_group
,
kernel_height
,
kernel_width
,
_
,
oc_bn
,
_
=
\
get_const_tuple
(
kernel
.
shape
)
else
:
oc_chunk
,
ic_chunk_group
,
kernel_height
,
kernel_width
,
_
,
oc_bn
=
\
get_const_tuple
(
kernel
.
shape
)
get_const_tuple
(
kernel
.
shape
)
num_filter
=
oc_chunk
*
oc_bn
num_filter
=
oc_chunk
*
oc_bn
groups
=
ic_chunk
//
ic_chunk_group
dilated_kernel_h
=
(
kernel_height
-
1
)
*
dilation_h
+
1
dilated_kernel_w
=
(
kernel_width
-
1
)
*
dilation_w
+
1
# If no config was set, we can fallback to NCHW config.
if
cfg
.
is_fallback
:
if
cfg
.
is_fallback
:
_get_default_config
(
cfg
,
tvm
.
placeholder
((
n
,
in_channel
,
ih
,
iw
),
dtype
=
data
.
dtype
),
_get_default_config
(
cfg
,
tvm
.
placeholder
((
n
,
in_channel
,
ih
,
iw
),
dtype
=
data
.
dtype
),
tvm
.
placeholder
((
num_filter
,
in_channel
,
kernel_height
,
kernel_width
),
tvm
.
placeholder
((
num_filter
,
in_channel
,
kernel_height
,
kernel_width
),
dtype
=
kernel
.
dtype
),
dtype
=
kernel
.
dtype
),
strides
,
padding
,
out_dtype
)
strides
,
padding
,
out_dtype
)
# output shape
return
nn
.
conv2d_NCHWc_compute
(
data
,
out_height
=
(
ih
+
2
*
HPAD
-
dilated_kernel_h
)
//
HSTR
+
1
kernel
,
out_width
=
(
iw
+
2
*
WPAD
-
dilated_kernel_w
)
//
WSTR
+
1
strides
,
oshape
=
(
n
,
oc_chunk
,
out_height
,
out_width
,
oc_bn
)
padding
,
dilation
,
# DOPAD
layout
,
DOPAD
=
(
HPAD
!=
0
or
WPAD
!=
0
)
out_layout
,
if
DOPAD
:
out_dtype
)
data_pad
=
pad
(
data
,
(
0
,
0
,
HPAD
,
WPAD
,
0
),
name
=
"data_pad"
)
else
:
data_pad
=
data
ic
=
tvm
.
reduce_axis
((
0
,
in_channel
),
name
=
'ic'
)
kh
=
tvm
.
reduce_axis
((
0
,
kernel_height
),
name
=
'kh'
)
kw
=
tvm
.
reduce_axis
((
0
,
kernel_width
),
name
=
'kw'
)
if
_is_int8_hw_support
(
data
.
dtype
,
kernel
.
dtype
,
target
)
and
groups
==
1
:
assert
out_dtype
==
"int32"
,
\
"INT8 convolution requires input dtype = uint8 and output dtype=int32"
# Intel performs dot product of 2 "4" Int8 values
# Current implementation requires ic_bn to be a multiple of 4
n_elems
=
4
assert
ic_bn
%
n_elems
==
0
ic_outer
=
tvm
.
reduce_axis
((
0
,
in_channel
//
ic_bn
),
name
=
'ic_outer'
)
ic_f_inner
=
tvm
.
reduce_axis
((
0
,
ic_bn
//
n_elems
),
name
=
'ic_f_inner'
)
ic_s_inner
=
tvm
.
reduce_axis
((
0
,
n_elems
),
name
=
'ic_s_inner'
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
ic_outer
,
oh
*
HSTR
+
kh
*
dilation_h
,
ow
*
WSTR
+
kw
*
dilation_w
,
ic_f_inner
*
n_elems
+
ic_s_inner
]
.
astype
(
out_dtype
)
*
kernel
[
oc_chunk
,
ic_outer
,
kh
,
kw
,
ic_f_inner
,
oc_block
,
ic_s_inner
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
,
ic_outer
,
ic_f_inner
,
ic_s_inner
]),
name
=
'conv2d_NCHWc_int8'
,
tag
=
"conv2d_NCHWc_int8"
)
if
_is_int8_hw_support
(
data
.
dtype
,
kernel
.
dtype
,
target
):
# for int8 group conv support
n_elems
=
4
ic_chunk
=
in_channel
//
ic_bn
ic_outer
=
tvm
.
reduce_axis
((
0
,
ic_chunk
//
groups
),
name
=
'ic_outer'
)
ic_f_inner
=
tvm
.
reduce_axis
((
0
,
ic_bn
//
n_elems
),
name
=
'ic_f_inner'
)
ic_s_inner
=
tvm
.
reduce_axis
((
0
,
n_elems
),
name
=
'ic_s_inner'
)
oshape
=
(
n
,
oc_chunk
,
out_height
,
out_width
,
oc_bn
)
return
tvm
.
compute
(
oshape
,
lambda
n
,
occ
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
(
occ
*
oc_bn
//
(
oc_chunk
*
oc_bn
//
groups
))
*
\
(
ic_chunk
//
groups
)
+
ic_outer
,
oh
*
HSTR
+
kh
,
ow
*
WSTR
+
kw
,
ic_f_inner
*
n_elems
+
ic_s_inner
]
.
astype
(
out_dtype
)
*
kernel
[
occ
,
ic_outer
,
kh
,
kw
,
ic_f_inner
,
oc_block
,
ic_s_inner
]
.
astype
(
out_dtype
),
axis
=
[
kh
,
kw
,
ic_outer
,
ic_f_inner
,
ic_s_inner
]),
name
=
'conv2d_NCHWc_int8'
,
tag
=
"conv2d_NCHWc_int8"
)
# else: fp implementation
return
tvm
.
compute
(
oshape
,
lambda
n
,
oc_chunk
,
oh
,
ow
,
oc_block
:
tvm
.
sum
(
data_pad
[
n
,
ic
//
ic_bn
,
oh
*
HSTR
+
kh
*
dilation_h
,
ow
*
WSTR
+
kw
*
dilation_w
,
ic
%
ic_bn
]
.
astype
(
out_dtype
)
*
kernel
[
oc_chunk
,
ic
//
ic_bn
,
kh
,
kw
,
ic
%
ic_bn
,
oc_block
],
axis
=
[
ic
,
kh
,
kw
]),
name
=
'conv2d_NCHWc'
,
tag
=
"conv2d_NCHWc"
)
@autotvm.register_topi_schedule
(
generic
.
schedule_conv2d_NCHWc
,
'cpu'
,
[
'direct'
])
@autotvm.register_topi_schedule
(
generic
.
schedule_conv2d_NCHWc
,
'cpu'
,
[
'direct'
])
...
@@ -664,19 +546,11 @@ def _schedule_conv2d_NCHWc(cfg, outs):
...
@@ -664,19 +546,11 @@ def _schedule_conv2d_NCHWc(cfg, outs):
args
=
[
s
,
cfg
,
data_vec
,
conv_out
,
outs
[
0
]]
args
=
[
s
,
cfg
,
data_vec
,
conv_out
,
outs
[
0
]]
target
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
target
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
if
_is_int8_hw_support
(
data
.
dtype
,
kernel
.
dtype
,
target
):
_
,
_
,
kh
,
kw
,
_
,
_
,
=
get_const_tuple
(
kernel
.
shape
)
# int8 conv kernel is 7-dim
if
kh
==
1
and
kw
==
1
:
_
,
_
,
kh
,
kw
,
_
,
_
,
_
=
get_const_tuple
(
kernel
.
shape
)
conv2d_avx_1x1
.
_schedule_conv_NCHWc
(
*
args
)
if
kh
==
1
and
kw
==
1
:
conv2d_avx_1x1
.
_schedule_conv_NCHWc_int8
(
*
args
)
else
:
conv2d_avx_common
.
_schedule_conv_NCHWc_int8
(
*
args
)
else
:
else
:
_
,
_
,
kh
,
kw
,
_
,
_
,
=
get_const_tuple
(
kernel
.
shape
)
conv2d_avx_common
.
_schedule_conv_NCHWc
(
*
args
)
if
kh
==
1
and
kw
==
1
:
conv2d_avx_1x1
.
_schedule_conv_NCHWc
(
*
args
)
else
:
conv2d_avx_common
.
_schedule_conv_NCHWc
(
*
args
)
scheduled_ops
.
append
(
op
)
scheduled_ops
.
append
(
op
)
...
...
topi/python/topi/x86/conv2d_int8.py
0 → 100644
View file @
eb220d92
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D int8 schedule on x86"""
import
tvm
from
tvm
import
autotvm
from
..
import
generic
,
tag
from
..util
import
get_const_tuple
from
..nn.conv2d
import
conv2d_NCHWc_int8
from
..
import
nn
from
.conv2d
import
_get_default_config
from
.
import
conv2d_avx_1x1
,
conv2d_avx_common
@autotvm.register_topi_compute
(
conv2d_NCHWc_int8
,
'cpu'
,
'direct'
)
def
_declaration_conv_NCHWc_int8
(
cfg
,
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
):
n
,
ic_chunk
,
ih
,
iw
,
ic_bn
=
get_const_tuple
(
data
.
shape
)
in_channel
=
ic_chunk
*
ic_bn
oc_chunk
,
_
,
kernel_height
,
kernel_width
,
_
,
oc_bn
,
_
=
\
get_const_tuple
(
kernel
.
shape
)
num_filter
=
oc_chunk
*
oc_bn
# If config is not set, we can reuse the default config for NCHW.
if
cfg
.
is_fallback
:
_get_default_config
(
cfg
,
tvm
.
placeholder
((
n
,
in_channel
,
ih
,
iw
),
dtype
=
data
.
dtype
),
tvm
.
placeholder
((
num_filter
,
in_channel
,
kernel_height
,
kernel_width
),
dtype
=
kernel
.
dtype
),
strides
,
padding
,
out_dtype
)
return
nn
.
conv2d_NCHWc_int8_compute
(
data
,
kernel
,
strides
,
padding
,
dilation
,
layout
,
out_layout
,
out_dtype
)
@autotvm.register_topi_schedule
(
generic
.
schedule_conv2d_NCHWc_int8
,
'cpu'
,
[
'direct'
])
def
_schedule_conv2d_NCHWc_int8
(
cfg
,
outs
):
"""Create schedule for tensors"""
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
scheduled_ops
=
[]
def
traverse
(
op
):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if
tag
.
is_broadcast
(
op
.
tag
):
if
op
not
in
s
.
outputs
:
s
[
op
]
.
compute_inline
()
for
tensor
in
op
.
input_tensors
:
if
isinstance
(
tensor
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
tensor
.
op
not
in
scheduled_ops
:
traverse
(
tensor
.
op
)
if
'conv2d_NCHWc_int8'
in
op
.
tag
:
conv_out
=
op
.
output
(
0
)
kernel
=
conv_out
.
op
.
input_tensors
[
1
]
data_vec
=
conv_out
.
op
.
input_tensors
[
0
]
data
=
data_vec
.
op
.
input_tensors
[
0
]
\
if
isinstance
(
data_vec
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
not
in
data_vec
.
op
.
tag
\
else
data_vec
if
isinstance
(
data
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
in
data
.
op
.
tag
:
data_pad
=
data
data
=
data_pad
.
op
.
input_tensors
[
0
]
args
=
[
s
,
cfg
,
data_vec
,
conv_out
,
outs
[
0
]]
target
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
# int8 conv kernel is 7-dim
_
,
_
,
kh
,
kw
,
_
,
_
,
_
=
get_const_tuple
(
kernel
.
shape
)
if
kh
==
1
and
kw
==
1
:
conv2d_avx_1x1
.
_schedule_conv_NCHWc_int8
(
*
args
)
else
:
conv2d_avx_common
.
_schedule_conv_NCHWc_int8
(
*
args
)
scheduled_ops
.
append
(
op
)
traverse
(
outs
[
0
]
.
op
)
return
s
@autotvm.register_topi_schedule
(
generic
.
schedule_conv2d_nhwc_pack
,
'cpu'
,
[
'direct'
])
def
schedule_conv2d_nhwc_pack
(
cfg
,
outs
):
"""Create schedule for tensors"""
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
output_op
=
outs
[
0
]
.
op
scheduled_ops
=
[]
def
traverse
(
op
):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if
tag
.
is_broadcast
(
op
.
tag
):
if
op
not
in
s
.
outputs
:
s
[
op
]
.
compute_inline
()
else
:
# inject custom schedule
if
len
(
op
.
axis
)
==
4
:
# schedule bias + bn + relu
n
,
h
,
w
,
c
=
op
.
axis
fused
=
s
[
op
]
.
fuse
(
n
,
h
,
w
)
s
[
op
]
.
parallel
(
fused
)
s
[
op
]
.
vectorize
(
c
)
for
tensor
in
op
.
input_tensors
:
if
isinstance
(
tensor
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
tensor
.
op
not
in
scheduled_ops
:
traverse
(
tensor
.
op
)
if
'conv2d_nhwc_pack_int8'
in
op
.
tag
:
conv_out
=
op
.
output
(
0
)
kernel
=
conv_out
.
op
.
input_tensors
[
1
]
data_vec
=
conv_out
.
op
.
input_tensors
[
0
]
data
=
data_vec
.
op
.
input_tensors
[
0
]
\
if
isinstance
(
data_vec
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
not
in
data_vec
.
op
.
tag
\
else
data_vec
if
isinstance
(
data
.
op
,
tvm
.
tensor
.
ComputeOp
)
and
"pad"
in
data
.
op
.
tag
:
data_pad
=
data
data
=
data_pad
.
op
.
input_tensors
[
0
]
args
=
[
s
,
cfg
,
data_vec
,
conv_out
,
outs
[
0
]]
if
data
.
dtype
==
'uint8'
:
kh
,
kw
,
_
,
_
,
_
=
get_const_tuple
(
kernel
.
shape
)
if
kh
==
1
and
kw
==
1
:
conv2d_avx_1x1
.
_schedule_conv_nhwc_pack_int8
(
*
args
)
else
:
raise
ValueError
(
"Only support 1x1 kernel with "
"schedule_conv2d_nhwc_pack."
)
else
:
raise
ValueError
(
"Not support this data type {} with "
"schedule_conv2d_nhwc_pack. Only support int8"
.
format
(
data
.
dtype
))
scheduled_ops
.
append
(
op
)
traverse
(
output_op
)
return
s
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment