Commit bb7df695
Authored Jun 24, 2018 by Siva; committed by Tianqi Chen, Jun 24, 2018

[NNVM][CONVOLUTION] Group convolution generalization for NHWC (#1232)

Parent: 5b33e7b8

Showing 4 changed files with 172 additions and 13 deletions:

    nnvm/python/nnvm/frontend/tensorflow.py          +127   -8
    nnvm/python/nnvm/top/nn.py                        +18   -2
    nnvm/src/top/nn/convolution.cc                     +2   -1
    nnvm/tests/python/compiler/test_top_level2.py     +25   -2

nnvm/python/nnvm/frontend/tensorflow.py

@@ -33,6 +33,8 @@ class AttrCvt(object):
         self._ignores.append('_input_shapes')
         self._ignores.append('T')
         self._ignores.append('use_cudnn_on_gpu')
+        self._ignores.append('_node_name')
+        self._ignores.append('is_training')
         return AttrConvert(self._op_name, self._transforms, self._excludes,
                            self._disables, self._ignores, self._extras,
                            self._custom_check)(inputs, attrs, *args)
@@ -230,6 +232,85 @@ def _conv():
                       custom_check=_dimension_constraint())(inputs, attr)
     return _impl
 
+def _depthwise_conv():
+    def _impl(inputs, attr, params):
+        attr['data_format'] = attr['data_format'].decode("utf-8")
+        input_shapes = attr['_input_shapes'][inputs[0]]
+
+        # Extract kernel shape from params
+        conv_param_weights = params[inputs[1].list_output_names()[0]]
+
+        if attr['data_format'] == 'NHWC':
+            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
+            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            attr['channels'] = input_shapes[0][3] * depth_mult
+            if 'dilations' in attr:
+                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
+        elif attr['data_format'] == 'NCHW':
+            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
+            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            attr['channels'] = input_shapes[0][1] * depth_mult
+            if 'dilations' in attr:
+                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
+        else:
+            raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
+
+        # Fix strides
+        attr['strides'] = (attr['strides'][1], attr['strides'][2])
+
+        # Fix groups
+        attr['groups'] = attr['channels']
+
+        # Fix padding
+        attr['padding'] = attr['padding'].decode("utf-8")
+
+        if attr['padding'] == 'VALID':
+            attr['padding'] = [0, 0]
+        elif attr['padding'] == 'SAME':
+            stride_h, stride_w = attr['strides']
+            kernel_h, kernel_w = attr['kernel_shape']
+            if attr['data_format'] == 'NHWC':
+                in_h = input_shapes[0][1]
+                in_w = input_shapes[0][2]
+            else:
+                in_h = input_shapes[0][2]
+                in_w = input_shapes[0][3]
+
+            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
+
+            if attr['data_format'] == 'NHWC':
+                inputs[0] = _sym.pad(data=inputs[0],
+                                     pad_width=((0, 0),
+                                                (pad_v[0], pad_v[1]),
+                                                (pad_h[0], pad_h[1]),
+                                                (0, 0)))
+            else:
+                inputs[0] = _sym.pad(data=inputs[0],
+                                     pad_width=((0, 0),
+                                                (0, 0),
+                                                (pad_v[0], pad_v[1]),
+                                                (pad_h[0], pad_h[1])))
+
+            attr['padding'] = [0, 0]
+        else:
+            raise TypeError("Unsupported padding type : {}".format(attr['padding']))
+
+        if 'kernel_layout' not in attr:
+            attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
+
+        return AttrCvt(
+            op_name=_dimension_picker('conv'),
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'data_format': 'layout',
+                'dilations': ('dilation', (0, 0)),
+                'group': ('groups', 1)},
+            extras={'use_bias': len(inputs) == 3},
+            custom_check=_dimension_constraint())(inputs, attr)
+    return _impl
+
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
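
Note: the SAME branch above leans on _get_pad_pair, a helper defined earlier in tensorflow.py and not shown in this diff. A minimal sketch of what it is assumed to compute, TensorFlow's SAME padding split into a (before, after) pair for one spatial axis:

    # Sketch only: the real helper lives outside this hunk and may differ
    # in detail; this mirrors TensorFlow's SAME padding rule.
    def _get_pad_pair(input1d, kernel1d, stride1d):
        if input1d % stride1d == 0:
            pad = max(kernel1d - stride1d, 0)
        else:
            pad = max(kernel1d - (input1d % stride1d), 0)
        pad_before = pad // 2          # SAME places the smaller half first
        pad_after = pad - pad_before
        return [pad_before, pad_after]

    # e.g. in_h=18, kernel_h=3, stride_h=1 -> [1, 1], matching the
    # padding=(1, 1) used by the grouped conv2d tests in this commit.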
@@ -358,9 +439,27 @@ def _batch_norm():
             op_name='batch_norm',
             transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
             extras={'axis': 3}, # Fix axis
+            ignores=['data_format'],
             disables=['momentum'])(new_inputs, attr)
     return _impl
 
+def _relu6():
+    def _impl(inputs, attr, params):
+        return _sym.clip(inputs[0], a_min=0, a_max=6)
+    return _impl
+
+def _shape():
+    def _impl(inputs, attr, params):
+        input_shapes = attr['_input_shapes'][inputs[0]]
+
+        # Fix the -1 dimensions to 1
+        input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]]
+        params[attr['_node_name']] = tvm.nd.array(input_shapes[0])
+
+        return _sym.Variable(name=attr['_node_name'],
+                             shape=params[attr['_node_name']].shape)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
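
Note: NNVM has no dedicated Relu6 operator, so _relu6 lowers it to a clip. A quick numpy sketch of the identity it relies on (illustrative, not part of the commit):

    import numpy as np

    x = np.array([-2.0, 0.5, 7.3])
    # Relu6(x) = min(max(x, 0), 6) is exactly clip(x, 0, 6), which is
    # what _sym.clip(inputs[0], a_min=0, a_max=6) emits.
    assert np.allclose(np.clip(x, 0, 6), np.minimum(np.maximum(x, 0), 6))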
@@ -392,6 +491,10 @@ _convert_map = {
     'Add'                               : _elemwise('add'),
     'Rsqrt'                             : _rsqrt(),
     'Squeeze'                           : _squeeze(),
+    'FusedBatchNorm'                    : _batch_norm(),
+    'Relu6'                             : _relu6(),
+    'DepthwiseConv2dNative'             : _depthwise_conv(),
+    'Shape'                             : _shape(),
 }
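
Note: each entry in this table is a closure factory; the graph converter dispatches on the TensorFlow op name and invokes the returned converter. A simplified sketch of the consuming side (abridged from this file's operator-conversion path, names illustrative):

    op_name = node.op                      # e.g. "DepthwiseConv2dNative"
    if op_name in _identity_list:
        sym = get_nnvm_op(op_name)(*inputs, **attrs)
    elif op_name in _convert_map:
        sym = _convert_map[op_name](inputs, attrs, params)
    else:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))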
@@ -458,9 +561,13 @@ class GraphProto(object):
                 self._num_input += 1
                 self._nodes[node.name] = _sym.Variable(name=node.name)
 
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList(shape) \
-                    for shape in self._parse_attr(node.attr)['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                        [tensor_util.TensorShapeProtoToList(shape) \
+                         for shape in self._parse_attr(node.attr)['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
             elif node.op == "Const":
                 # Assuming first Const node as Graph Input node
                 if self._input_node == '':
@@ -476,17 +583,29 @@ class GraphProto(object):
                     raise NotImplementedError( \
                         "Const {} couldn't be converted to Param.".format(node.name))
 
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList(shape) \
-                    for shape in self._parse_attr(node.attr)['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                        [tensor_util.TensorShapeProtoToList(shape) \
+                         for shape in self._parse_attr(node.attr)['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
             else:
                 attr = self._parse_attr(node.attr)
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList(shape) for shape in attr['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                        [tensor_util.TensorShapeProtoToList(shape) \
+                         for shape in attr['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
 
                 # Pass the parsed shapes instead
                 attr["_output_shapes"] = self._output_shapes[node.name]
 
+                # Pass the node name too in attr
+                attr["_node_name"] = node.name
+
                 try:
                     inputs = [self._nodes[i] for i in node.input]
                     input_shapes = {}
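
Note: the new try/except blocks assume every node carries an _output_shapes attribute, which TensorFlow only attaches when the graph is exported with shape information. A hedged sketch of producing such a frozen graph with the TF 1.x APIs (the trivial graph and the node name "output" are illustrative):

    import tensorflow as tf
    from tensorflow.python.framework import graph_util

    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=(1, 18, 18, 32), name="x")
        y = tf.identity(x, name="output")
        with tf.Session() as sess:
            # add_shapes=True attaches the _output_shapes attribute that
            # the frontend's KeyError message asks for.
            graph_def = sess.graph.as_graph_def(add_shapes=True)
            frozen = graph_util.convert_variables_to_constants(
                sess, graph_def, ["output"])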

nnvm/python/nnvm/top/nn.py

@@ -84,6 +84,7 @@ def compute_conv2d(attrs, inputs, _):
     groups = attrs.get_int("groups")
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
+    kernel_layout = attrs["kernel_layout"]
     assert layout == "NCHW" or layout == "NHWC"
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
@@ -97,10 +98,18 @@ def compute_conv2d(attrs, inputs, _):
     if groups == 1:
         out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
-    elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
+    elif layout == "NCHW" and \
+         groups == get_const_int(inputs[0].shape[1]) and \
+         groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
+    elif layout == "NHWC" and \
+         kernel_layout == "HWOI" and \
+         groups == get_const_int(inputs[0].shape[3]) and \
+         groups == channels:
+        out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding)
     else:
         raise ValueError("not support arbitrary group number for now")
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
         expand_axis = 1 if layout == "NCHW" else 0
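
Note: for the NHWC depthwise case exercised by the new test, the branch selection works out as follows (values are illustrative, taken from test_grouped_conv2d_nhwc):

    # data (1, 18, 18, 32): NHWC puts channels at axis 3, so
    # get_const_int(inputs[0].shape[3]) == 32.
    # kernel HWOI (3, 3, 32, 1): depth multiplier 1, groups == channels == 32.
    layout, kernel_layout = "NHWC", "HWOI"
    groups = channels = in_channels = 32
    assert (layout == "NHWC" and kernel_layout == "HWOI"
            and groups == in_channels and groups == channels)
    # -> dispatches to topi.nn.depthwise_conv2d_nhwc; any other grouped
    #    configuration still raises "not support arbitrary group number for now".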
@@ -112,13 +121,20 @@ def compute_conv2d(attrs, inputs, _):
 def schedule_conv2d(attrs, outs, target):
     """Schedule definition of conv2d"""
     groups = attrs.get_int("groups")
+    channels = attrs.get_int("channels")
     layout = attrs["layout"]
+    kernel_layout = attrs["kernel_layout"]
     with tvm.target.create(target):
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
         elif groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
-        return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NCHW":
+            return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI":
+            return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+        else:
+            raise ValueError("No compatible schedule")
 
 @reg.register_alter_op_layout("conv2d")
 def alter_conv2d_layout(attrs, inputs, tinfos):

nnvm/src/top/nn/convolution.cc

@@ -79,7 +79,8 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
                      param.kernel_size[1]});
       wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
-      wshape[0] *= param.groups;
+
+      wshape[kernel_layout.indexof('O')] *= param.groups;
 
       NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
       if (param.use_bias) {
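
Note: the one-line change matters because after ConvertLayout the weight shape is in kernel_layout order, so the output-channel axis is no longer guaranteed to be index 0. A worked sketch with the NHWC test's shapes (illustrative values):

    # Per-group OIHW weight shape: (channels/groups, in_channels/groups, kh, kw)
    #   = (1, 1, 3, 3) for channels = groups = 32 and a 3x3 kernel.
    # ConvertLayout(OIHW -> HWOI) gives (3, 3, 1, 1).
    wshape = [3, 3, 1, 1]
    o_axis = "HWOI".index("O")   # == 2, what kernel_layout.indexof('O') returns
    wshape[o_axis] *= 32         # -> (3, 3, 32, 1), matching kshape in the test
    # The old wshape[0] *= groups scaled the H axis for HWOI kernels; for the
    # default OIHW layout indexof('O') == 0, so NCHW behavior is unchanged.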

nnvm/tests/python/compiler/test_top_level2.py

@@ -58,7 +58,7 @@ def test_dilated_conv2d():
     np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
-def test_grouped_conv2d():
+def test_grouped_conv2d_nchw():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
                    name="y", padding=(1,1))
@@ -80,6 +80,28 @@ def test_grouped_conv2d():
     c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
     np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
+def test_grouped_conv2d_nhwc():
+    x = sym.Variable("x")
+    y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
+                   name="y", padding=(1,1), layout="NHWC", kernel_layout='HWOI')
+    dtype = "float32"
+    dshape = (1, 18, 18, 32)
+    kshape = (3, 3, 32, 1)
+    oshape = (1, 18, 18, 32)
+    shape_dict = {"x": dshape}
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        bias = tvm.nd.array(np.random.uniform(size=kshape[2]).astype(dtype))
+        m.run(x=data, y_weight=kernel, y_bias=bias)
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        c_np = topi.testing.depthwise_conv2d_python_nhwc(
+            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        c_np = c_np + bias.asnumpy().reshape(1, 1, kshape[2])
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
 def test_conv2d_transpose():
     x = sym.Variable("x")
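
Note: the reference check differs from the NCHW test only in the bias broadcast: with channels last, the bias reshapes to (1, 1, C) rather than (C, 1, 1). A small numpy illustration:

    import numpy as np
    c = np.zeros((1, 18, 18, 32))        # NHWC result from the reference conv
    bias = np.arange(32, dtype=c.dtype)
    out = c + bias.reshape(1, 1, 32)     # broadcasts over N, H, W
    assert out.shape == (1, 18, 18, 32)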
@@ -269,7 +291,8 @@ def test_resize_bilinear():
 if __name__ == "__main__":
     test_conv2d()
     test_dilated_conv2d()
-    test_grouped_conv2d()
+    test_grouped_conv2d_nchw()
+    test_grouped_conv2d_nhwc()
     test_conv2d_transpose()
     test_max_pool2d()
     test_avg_pool2d()