Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
046a3ed9
Commit
046a3ed9
authored
Sep 24, 2018
by
Siva
Committed by
Tianqi Chen
Sep 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FRONTEND][TENSORFLOW] NCHW layout support (Resnet V1/V2). (#1743)
parent
160e4107
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
150 additions
and
65 deletions
+150
-65
nnvm/python/nnvm/frontend/tensorflow.py
+50
-18
nnvm/tests/python/frontend/tensorflow/test_forward.py
+100
-47
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
046a3ed9
...
@@ -110,11 +110,6 @@ def _elemwise(name):
...
@@ -110,11 +110,6 @@ def _elemwise(name):
def
_impl
(
inputs
,
attr
,
*
args
):
def
_impl
(
inputs
,
attr
,
*
args
):
assert
len
(
inputs
)
==
2
,
"Math op take 2 inputs, {} given"
.
format
(
len
(
inputs
))
assert
len
(
inputs
)
==
2
,
"Math op take 2 inputs, {} given"
.
format
(
len
(
inputs
))
op_name
=
_math_name_picker
(
name
)(
attr
)
op_name
=
_math_name_picker
(
name
)(
attr
)
axis
=
int
(
attr
.
get
(
'axis'
,
0
))
conv_ops
=
[
"conv2d"
,
"conv2d_transpose"
]
if
op_name
==
'broadcast_add'
and
inputs
[
0
]
.
attr
(
'op_name'
)
in
conv_ops
:
# TODO: remove hard coded infershape
inputs
[
1
]
=
_sym
.
expand_dims
(
inputs
[
1
],
axis
=
axis
,
num_newaxis
=
2
)
return
get_nnvm_op
(
op_name
)(
*
inputs
)
return
get_nnvm_op
(
op_name
)(
*
inputs
)
return
_impl
return
_impl
...
@@ -128,8 +123,10 @@ def _pooling(name):
...
@@ -128,8 +123,10 @@ def _pooling(name):
if
attr
[
'data_format'
]
==
'NHWC'
:
if
attr
[
'data_format'
]
==
'NHWC'
:
attr
[
'kernel_shape'
]
=
(
attr
[
'ksize'
][
1
],
attr
[
'ksize'
][
2
])
attr
[
'kernel_shape'
]
=
(
attr
[
'ksize'
][
1
],
attr
[
'ksize'
][
2
])
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
elif
attr
[
'data_format'
]
==
'NCHW'
:
elif
attr
[
'data_format'
]
==
'NCHW'
:
attr
[
'kernel_shape'
]
=
(
attr
[
'ksize'
][
2
],
attr
[
'ksize'
][
3
])
attr
[
'kernel_shape'
]
=
(
attr
[
'ksize'
][
2
],
attr
[
'ksize'
][
3
])
attr
[
'strides'
]
=
(
attr
[
'strides'
][
2
],
attr
[
'strides'
][
3
])
else
:
else
:
raise
TypeError
(
"Unsupported data_format type : {}"
.
format
(
attr
[
'data_format'
]))
raise
TypeError
(
"Unsupported data_format type : {}"
.
format
(
attr
[
'data_format'
]))
...
@@ -140,9 +137,6 @@ def _pooling(name):
...
@@ -140,9 +137,6 @@ def _pooling(name):
attr
[
'data_format'
]
=
"NCHW"
attr
[
'data_format'
]
=
"NCHW"
flip_layout
=
True
flip_layout
=
True
# Fix strides
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
# Fix padding
# Fix padding
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
...
@@ -188,8 +182,15 @@ def _conv(opname):
...
@@ -188,8 +182,15 @@ def _conv(opname):
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
flip_layout
=
False
flip_layout
=
False
# NCHW Layout require weights transpose
if
attr
[
'data_format'
]
==
'NCHW'
:
tmp_shape
=
attr
[
'_input_shapes'
][
inputs
[
1
]][
0
]
tmp_shape
=
[
tmp_shape
[
ii
]
for
ii
in
(
3
,
2
,
0
,
1
)]
inputs
[
1
]
=
_sym
.
transpose
(
inputs
[
1
],
axes
=
(
3
,
2
,
0
,
1
))
attr
[
'_input_shapes'
][
inputs
[
1
]]
=
[
tmp_shape
]
input_shape
=
attr
[
'_input_shapes'
][
inputs
[
0
]][
0
]
input_shape
=
attr
[
'_input_shapes'
][
inputs
[
0
]][
0
]
weights_shape
=
params
[
inputs
[
1
]
.
list_output_names
()[
0
]]
.
shape
weights_shape
=
attr
[
'_input_shapes'
][
inputs
[
1
]][
0
]
if
attr
[
'_target_layout'
]
==
"NCHW"
and
attr
[
'data_format'
]
==
"NHWC"
:
if
attr
[
'_target_layout'
]
==
"NCHW"
and
attr
[
'data_format'
]
==
"NHWC"
:
input_shape
=
[
input_shape
[
ii
]
for
ii
in
(
0
,
3
,
1
,
2
)]
input_shape
=
[
input_shape
[
ii
]
for
ii
in
(
0
,
3
,
1
,
2
)]
...
@@ -202,6 +203,7 @@ def _conv(opname):
...
@@ -202,6 +203,7 @@ def _conv(opname):
inputs
[
1
]
=
_sym
.
transpose
(
inputs
[
1
],
axes
=
(
2
,
3
,
0
,
1
))
inputs
[
1
]
=
_sym
.
transpose
(
inputs
[
1
],
axes
=
(
2
,
3
,
0
,
1
))
attr
[
'data_format'
]
=
"NCHW"
attr
[
'data_format'
]
=
"NCHW"
attr
[
'strides'
]
=
[
attr
[
'strides'
][
ii
]
for
ii
in
(
0
,
3
,
1
,
2
)]
flip_layout
=
True
flip_layout
=
True
if
attr
[
'data_format'
]
==
'NHWC'
:
if
attr
[
'data_format'
]
==
'NHWC'
:
...
@@ -214,6 +216,7 @@ def _conv(opname):
...
@@ -214,6 +216,7 @@ def _conv(opname):
if
'dilations'
in
attr
:
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
0
],
attr
[
'dilations'
][
1
])
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
0
],
attr
[
'dilations'
][
1
])
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
elif
attr
[
'data_format'
]
==
'NCHW'
:
elif
attr
[
'data_format'
]
==
'NCHW'
:
depth_mult
,
_
,
kernel_h
,
kernel_w
=
weights_shape
depth_mult
,
_
,
kernel_h
,
kernel_w
=
weights_shape
attr
[
'kernel_shape'
]
=
(
weights_shape
[
2
],
weights_shape
[
3
])
attr
[
'kernel_shape'
]
=
(
weights_shape
[
2
],
weights_shape
[
3
])
...
@@ -226,6 +229,7 @@ def _conv(opname):
...
@@ -226,6 +229,7 @@ def _conv(opname):
if
'dilations'
in
attr
:
if
'dilations'
in
attr
:
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
2
],
attr
[
'dilations'
][
3
])
attr
[
'dilations'
]
=
(
attr
[
'dilations'
][
2
],
attr
[
'dilations'
][
3
])
attr
[
'strides'
]
=
(
attr
[
'strides'
][
2
],
attr
[
'strides'
][
3
])
else
:
else
:
raise
TypeError
(
"Unsupported data format type : {}"
.
format
(
attr
[
'data_format'
]))
raise
TypeError
(
"Unsupported data format type : {}"
.
format
(
attr
[
'data_format'
]))
...
@@ -233,9 +237,6 @@ def _conv(opname):
...
@@ -233,9 +237,6 @@ def _conv(opname):
if
opname
==
'depthwise'
:
if
opname
==
'depthwise'
:
attr
[
'groups'
]
=
attr
[
'channels'
]
attr
[
'groups'
]
=
attr
[
'channels'
]
# Fix strides
attr
[
'strides'
]
=
(
attr
[
'strides'
][
1
],
attr
[
'strides'
][
2
])
# Fix padding
# Fix padding
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
attr
[
'padding'
]
=
attr
[
'padding'
]
.
decode
(
"utf-8"
)
...
@@ -416,12 +417,27 @@ def _fused_batch_norm():
...
@@ -416,12 +417,27 @@ def _fused_batch_norm():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# NNVM: (data, gamma, beta, moving_mean, moving_varience)
# NNVM: (data, gamma, beta, moving_mean, moving_varience)
return
AttrCvt
(
axis
=
3
op_name
=
'batch_norm'
,
need_cast
=
False
transforms
=
{
'scale_after_normalization'
:
'scale'
,
'variance_epsilon'
:
'epsilon'
},
extras
=
{
'axis'
:
3
},
# Fix axis
if
'data_format'
in
attr
:
ignores
=
[
'data_format'
],
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
if
attr
[
'data_format'
]
==
'NCHW'
:
axis
=
1
if
'U'
in
attr
:
need_cast
=
True
inputs
[
0
]
=
_sym
.
cast
(
inputs
[
0
],
dtype
=
attr
[
'U'
]
.
name
)
out
=
AttrCvt
(
op_name
=
'batch_norm'
,
transforms
=
{
'scale_after_normalization'
:
'scale'
,
'variance_epsilon'
:
'epsilon'
},
extras
=
{
'axis'
:
axis
},
ignores
=
[
'data_format'
,
'U'
],
disables
=
[
'momentum'
])(
inputs
,
attr
)
disables
=
[
'momentum'
])(
inputs
,
attr
)
if
need_cast
:
out
=
_sym
.
cast
(
out
,
dtype
=
attr
[
'T'
]
.
name
)
return
out
return
_impl
return
_impl
def
_batch_norm
():
def
_batch_norm
():
...
@@ -432,10 +448,16 @@ def _batch_norm():
...
@@ -432,10 +448,16 @@ def _batch_norm():
# (data, gamma, beta, moving_mean, moving_var)
# (data, gamma, beta, moving_mean, moving_var)
new_inputs
=
[
inputs
[
0
],
inputs
[
4
],
inputs
[
3
],
inputs
[
1
],
inputs
[
2
]]
new_inputs
=
[
inputs
[
0
],
inputs
[
4
],
inputs
[
3
],
inputs
[
1
],
inputs
[
2
]]
axis
=
3
if
'data_format'
in
attr
:
attr
[
'data_format'
]
=
attr
[
'data_format'
]
.
decode
(
"utf-8"
)
if
attr
[
'data_format'
]
==
'NCHW'
:
axis
=
1
return
AttrCvt
(
return
AttrCvt
(
op_name
=
'batch_norm'
,
op_name
=
'batch_norm'
,
transforms
=
{
'scale_after_normalization'
:
'scale'
,
'variance_epsilon'
:
'epsilon'
},
transforms
=
{
'scale_after_normalization'
:
'scale'
,
'variance_epsilon'
:
'epsilon'
},
extras
=
{
'axis'
:
3
},
# Fix axis
extras
=
{
'axis'
:
axis
},
ignores
=
[
'data_format'
],
ignores
=
[
'data_format'
],
disables
=
[
'momentum'
])(
new_inputs
,
attr
)
disables
=
[
'momentum'
])(
new_inputs
,
attr
)
return
_impl
return
_impl
...
@@ -729,6 +751,14 @@ def _selu():
...
@@ -729,6 +751,14 @@ def _selu():
return
gamma
*
(
-
alpha
*
_sym
.
relu
(
1
-
_sym
.
exp
(
inputs
[
0
]))
+
_sym
.
relu
(
inputs
[
0
]))
return
gamma
*
(
-
alpha
*
_sym
.
relu
(
1
-
_sym
.
exp
(
inputs
[
0
]))
+
_sym
.
relu
(
inputs
[
0
]))
return
_impl
return
_impl
def
_mean
():
def
_impl
(
inputs
,
attr
,
params
):
axis
=
params
.
pop
(
inputs
[
1
]
.
list_output_names
()[
0
])
return
AttrCvt
(
op_name
=
"mean"
,
ignores
=
[
'Tdim'
,
'Tidx'
],
transforms
=
{
'keep_dims'
:
'keepdims'
},
extras
=
{
'axis'
:
tuple
(
axis
.
asnumpy
())})(
inputs
[
0
],
attr
)
return
_impl
# compatible operators that do NOT require any conversion.
# compatible operators that do NOT require any conversion.
_identity_list
=
[]
_identity_list
=
[]
...
@@ -773,6 +803,7 @@ _convert_map = {
...
@@ -773,6 +803,7 @@ _convert_map = {
'Rsqrt'
:
_rsqrt
(),
'Rsqrt'
:
_rsqrt
(),
'Squeeze'
:
_squeeze
(),
'Squeeze'
:
_squeeze
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'FusedBatchNormV2'
:
_fused_batch_norm
(),
'Relu6'
:
_relu6
(),
'Relu6'
:
_relu6
(),
'DepthwiseConv2dNative'
:
_conv
(
'depthwise'
),
'DepthwiseConv2dNative'
:
_conv
(
'depthwise'
),
'Shape'
:
_shape
(),
'Shape'
:
_shape
(),
...
@@ -787,6 +818,7 @@ _convert_map = {
...
@@ -787,6 +818,7 @@ _convert_map = {
'Rank'
:
_rank
(),
'Rank'
:
_rank
(),
'Transpose'
:
_transpose
(),
'Transpose'
:
_transpose
(),
'Tanh'
:
AttrCvt
(
'tanh'
),
'Tanh'
:
AttrCvt
(
'tanh'
),
'Mean'
:
_mean
(),
}
}
# _convert_map_rnn defines maps of rnn operator name to
# _convert_map_rnn defines maps of rnn operator name to
...
...
nnvm/tests/python/frontend/tensorflow/test_forward.py
View file @
046a3ed9
...
@@ -88,7 +88,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):
...
@@ -88,7 +88,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):
return
output_data
return
output_data
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
):
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
,
no_gpu
=
False
):
"""Generic function to generate and compare tensorflow and TVM output"""
"""Generic function to generate and compare tensorflow and TVM output"""
out_node
=
out_name
.
split
(
':'
)[
0
]
if
":"
in
out_name
else
out_name
out_node
=
out_name
.
split
(
':'
)[
0
]
if
":"
in
out_name
else
out_name
...
@@ -116,6 +116,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
...
@@ -116,6 +116,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
if
not
ctx
.
exist
:
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
print
(
"Skip because
%
s is not enabled"
%
device
)
continue
continue
if
no_gpu
and
device
==
'cuda'
:
continue
tvm_output
=
run_tvm_graph
(
final_graph_def
,
in_data
,
tvm_output
=
run_tvm_graph
(
final_graph_def
,
in_data
,
in_node
,
tf_output
.
shape
,
tf_output
.
dtype
,
target
=
device
)
in_node
,
tf_output
.
shape
,
tf_output
.
dtype
,
target
=
device
)
...
@@ -123,10 +125,20 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
...
@@ -123,10 +125,20 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False)
sess
.
close
()
sess
.
close
()
def
is_gpu_available
():
from
tensorflow.python.client
import
device_lib
local_device_protos
=
device_lib
.
list_local_devices
()
gpu_list
=
[
x
.
name
for
x
in
local_device_protos
if
x
.
device_type
==
'GPU'
]
if
len
(
gpu_list
)
<
0
:
print
(
"Tensorflow GPU:"
,
gpu_list
)
return
True
else
:
return
False
#######################################################################
#######################################################################
# Pooling
# Pooling
# -------
# -------
def
_test_pooling
(
input_shape
,
**
kwargs
):
def
_test_pooling
_iteration
(
input_shape
,
**
kwargs
):
""" One iteration of pool operation with given shapes and attributes """
""" One iteration of pool operation with given shapes and attributes """
x
=
-
np
.
arange
(
x
=
-
np
.
arange
(
...
@@ -143,61 +155,45 @@ def _test_pooling(input_shape, **kwargs):
...
@@ -143,61 +155,45 @@ def _test_pooling(input_shape, **kwargs):
compare_tf_with_tvm
(
x
,
'Placeholder:0'
,
out_name
)
compare_tf_with_tvm
(
x
,
'Placeholder:0'
,
out_name
)
def
_test_pooling
(
input_shape
,
**
kwargs
):
_test_pooling_iteration
(
input_shape
,
**
kwargs
)
if
is_gpu_available
():
input_shape
=
[
input_shape
[
ii
]
for
ii
in
(
0
,
3
,
1
,
2
)]
kwargs
[
'data_layout'
]
=
'NCHW'
_test_pooling_iteration
(
input_shape
,
**
kwargs
)
def
test_forward_pooling
():
def
test_forward_pooling
():
""" Pooling """
""" Pooling """
for
pool_type
in
[
'AVG'
,
'MAX'
]:
_test_pooling
(
input_shape
=
[
2
,
9
,
10
,
2
],
_test_pooling
(
input_shape
=
[
2
,
9
,
10
,
2
],
window_shape
=
[
1
,
1
],
window_shape
=
[
1
,
1
],
padding
=
'SAME'
,
padding
=
'SAME'
,
pooling_type
=
'MAX'
,
pooling_type
=
pool_type
,
dilation_rate
=
[
1
,
1
],
strides
=
[
1
,
1
])
_test_pooling
(
input_shape
=
[
2
,
9
,
10
,
2
],
window_shape
=
[
1
,
1
],
padding
=
'SAME'
,
pooling_type
=
'AVG'
,
dilation_rate
=
[
1
,
1
],
dilation_rate
=
[
1
,
1
],
strides
=
[
1
,
1
])
strides
=
[
1
,
1
])
_test_pooling
(
input_shape
=
[
2
,
10
,
9
,
2
],
_test_pooling
(
input_shape
=
[
2
,
10
,
9
,
2
],
window_shape
=
[
1
,
1
],
window_shape
=
[
1
,
1
],
padding
=
'SAME'
,
padding
=
'SAME'
,
pooling_type
=
'MAX'
,
pooling_type
=
pool_type
,
dilation_rate
=
[
1
,
1
],
strides
=
[
1
,
1
])
_test_pooling
(
input_shape
=
[
2
,
10
,
9
,
2
],
window_shape
=
[
1
,
1
],
padding
=
'SAME'
,
pooling_type
=
'AVG'
,
dilation_rate
=
[
1
,
1
],
dilation_rate
=
[
1
,
1
],
strides
=
[
1
,
1
])
strides
=
[
1
,
1
])
_test_pooling
(
input_shape
=
[
2
,
9
,
10
,
2
],
_test_pooling
(
input_shape
=
[
2
,
9
,
10
,
2
],
window_shape
=
[
2
,
1
],
window_shape
=
[
2
,
1
],
padding
=
'SAME'
,
padding
=
'SAME'
,
pooling_type
=
'MAX'
,
pooling_type
=
pool_type
,
dilation_rate
=
[
1
,
1
],
dilation_rate
=
[
1
,
1
],
strides
=
[
1
,
1
])
strides
=
[
1
,
1
])
_test_pooling
(
input_shape
=
[
2
,
9
,
10
,
2
],
window_shape
=
[
2
,
1
],
padding
=
'SAME'
,
pooling_type
=
'AVG'
,
dilation_rate
=
[
1
,
1
],
strides
=
[
2
,
1
])
_test_pooling
(
input_shape
=
[
2
,
10
,
9
,
2
],
_test_pooling
(
input_shape
=
[
2
,
10
,
9
,
2
],
window_shape
=
[
2
,
3
],
window_shape
=
[
2
,
3
],
padding
=
'SAME'
,
padding
=
'SAME'
,
pooling_type
=
'MAX'
,
pooling_type
=
pool_type
,
dilation_rate
=
[
1
,
1
],
dilation_rate
=
[
1
,
1
],
strides
=
[
2
,
1
])
strides
=
[
2
,
1
])
_test_pooling
(
input_shape
=
[
2
,
10
,
9
,
2
],
window_shape
=
[
2
,
3
],
padding
=
'SAME'
,
pooling_type
=
'AVG'
,
dilation_rate
=
[
1
,
1
],
strides
=
[
1
,
2
])
#######################################################################
#######################################################################
# Convolution
# Convolution
...
@@ -234,6 +230,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
...
@@ -234,6 +230,12 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
'Placeholder:0'
,
'Conv2D:0'
)
'Placeholder:0'
,
'Conv2D:0'
)
def
test_forward_convolution
():
def
test_forward_convolution
():
if
is_gpu_available
():
_test_convolution
([
4
,
176
,
8
,
8
],
[
1
,
1
,
176
,
32
],
[
1
,
1
],
[
1
,
1
],
'SAME'
,
'NCHW'
)
_test_convolution
([
4
,
19
,
17
,
17
],
[
3
,
3
,
19
,
19
],
[
1
,
1
],
[
2
,
2
],
'VALID'
,
'NCHW'
)
_test_convolution
([
4
,
124
,
17
,
17
],
[
1
,
1
,
124
,
19
],
[
1
,
1
],
[
1
,
1
],
'SAME'
,
'NCHW'
)
_test_convolution
([
4
,
12
,
17
,
17
],
[
3
,
3
,
12
,
32
],
[
1
,
1
],
[
2
,
2
],
'VALID'
,
'NCHW'
)
_test_convolution
([
4
,
8
,
8
,
176
],
[
1
,
1
,
176
,
32
],
[
1
,
1
],
[
1
,
1
],
'SAME'
,
'NHWC'
)
_test_convolution
([
4
,
8
,
8
,
176
],
[
1
,
1
,
176
,
32
],
[
1
,
1
],
[
1
,
1
],
'SAME'
,
'NHWC'
)
_test_convolution
([
4
,
17
,
17
,
19
],
[
3
,
3
,
19
,
19
],
[
1
,
1
],
[
2
,
2
],
'VALID'
,
'NHWC'
)
_test_convolution
([
4
,
17
,
17
,
19
],
[
3
,
3
,
19
,
19
],
[
1
,
1
],
[
2
,
2
],
'VALID'
,
'NHWC'
)
_test_convolution
([
4
,
17
,
17
,
124
],
[
1
,
1
,
124
,
19
],
[
1
,
1
],
[
1
,
1
],
'SAME'
,
'NHWC'
)
_test_convolution
([
4
,
17
,
17
,
124
],
[
1
,
1
,
124
,
19
],
[
1
,
1
],
[
1
,
1
],
'SAME'
,
'NHWC'
)
...
@@ -712,6 +714,25 @@ def test_forward_mobilenet():
...
@@ -712,6 +714,25 @@ def test_forward_mobilenet():
np
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
),
np
.
squeeze
(
tf_output
),
rtol
=
1e-5
,
atol
=
1e-5
)
np
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
),
np
.
squeeze
(
tf_output
),
rtol
=
1e-5
,
atol
=
1e-5
)
#######################################################################
#######################################################################
# ResnetV2
# ---------
def
test_forward_resnetv2
():
'''test resnet model'''
if
is_gpu_available
():
with
tf
.
Graph
()
.
as_default
():
graph_def
=
nnvm
.
testing
.
tf
.
get_workload
(
"ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb"
)
# Call the utility to import the graph definition into default graph.
graph_def
=
nnvm
.
testing
.
tf
.
ProcessGraphDefParam
(
graph_def
)
data
=
np
.
random
.
uniform
(
size
=
(
128
,
224
,
224
,
3
))
.
astype
(
'float32'
)
out_node
=
'ArgMax'
with
tf
.
Session
()
as
sess
:
tf_output
=
run_tf_graph
(
sess
,
data
,
'input_tensor:0'
,
out_node
+
':0'
)
tvm_output
=
run_tvm_graph
(
graph_def
,
data
,
'input_tensor'
,
tf_output
.
shape
,
'float32'
)
np
.
testing
.
assert_allclose
(
np
.
squeeze
(
tvm_output
),
np
.
squeeze
(
tf_output
),
rtol
=
1e-5
,
atol
=
1e-5
)
#######################################################################
# PTB
# PTB
# ---
# ---
dir
(
tf
.
contrib
)
dir
(
tf
.
contrib
)
...
@@ -947,37 +968,69 @@ def test_forward_tanh():
...
@@ -947,37 +968,69 @@ def test_forward_tanh():
compare_tf_with_tvm
(
inp_array
,
'Placeholder:0'
,
'Tanh:0'
)
compare_tf_with_tvm
(
inp_array
,
'Placeholder:0'
,
'Tanh:0'
)
#######################################################################
#######################################################################
# Mean
# ----
def
test_forward_mean
():
def
check_mean
(
ishape
,
**
kwargs
):
inp_array
=
np
.
random
.
uniform
(
size
=
ishape
)
.
astype
(
np
.
float32
)
with
tf
.
Graph
()
.
as_default
():
in1
=
tf
.
placeholder
(
shape
=
inp_array
.
shape
,
dtype
=
inp_array
.
dtype
)
tf
.
keras
.
backend
.
mean
(
in1
,
**
kwargs
)
compare_tf_with_tvm
(
inp_array
,
'Placeholder:0'
,
'Mean:0'
,
no_gpu
=
True
)
check_mean
((
10
,
8
,
16
,
32
))
check_mean
((
10
,
8
,
16
,
32
),
axis
=
(
2
,
3
))
check_mean
((
10
,
8
,
16
,
32
),
axis
=
(
1
,
2
),
keepdims
=
True
)
#######################################################################
# Main
# Main
# ----
# ----
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
# Transforms
test_forward_transpose
()
test_forward_transpose
()
test_forward_convolution
()
test_forward_pooling
()
test_forward_reshape
()
test_forward_reshape
()
test_forward_squeeze
()
test_forward_squeeze
()
test_forward_pack
()
test_forward_resize_bilinear
()
test_forward_pad
()
test_forward_gather
()
#test_forward_stridedslice()
# Activations
test_forward_sigmoid
()
test_forward_sigmoid
()
test_forward_relu
()
test_forward_leaky_relu
()
test_forward_elu
()
test_forward_selu
()
test_forward_tanh
()
# Reductions
test_forward_argminmax
()
test_forward_argminmax
()
test_forward_reduce
()
test_forward_reduce
()
test_forward_mean
()
# NN
test_forward_convolution
()
test_forward_pooling
()
if
tf
.
__version__
==
'1.4.1'
:
if
tf
.
__version__
==
'1.4.1'
:
_test_forward_concat_v2
()
_test_forward_concat_v2
()
test_forward_lrn
()
test_forward_l2_normalize
()
# General
test_forward_multi_input
()
test_forward_multi_input
()
test_forward_pack
()
test_forward_variable
()
# End to End
test_forward_inception_v3
()
test_forward_inception_v3
()
test_forward_inception_v1
()
test_forward_inception_v1
()
test_forward_mobilenet
()
test_forward_mobilenet
()
test_forward_variable
()
test_forward_resnetv2
()
test_forward_resize_bilinear
()
test_forward_pad
()
#test_forward_lstm()
#test_forward_stridedslice()
test_forward_gather
()
test_forward_ptb
()
test_forward_ptb
()
test_forward_lrn
()
test_forward_l2_normalize
()
# RNN
#test_forward_lstm()
# Elementwise
test_forward_ceil
()
test_forward_ceil
()
test_forward_floor
()
test_forward_floor
()
test_forward_relu
()
test_forward_leaky_relu
()
test_forward_elu
()
test_forward_selu
()
test_forward_tanh
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment