Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9bb16872
Commit
9bb16872
authored
Jun 13, 2019
by
Yong Wu
Committed by
Tianqi Chen
Jun 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Frontend] Add a bunch of ops in tf converter (#3270)
parent
c9e96d9f
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
305 additions
and
22 deletions
+305
-22
python/tvm/relay/frontend/tensorflow.py
+41
-12
tests/python/frontend/tensorflow/test_forward.py
+264
-10
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
9bb16872
...
@@ -777,12 +777,12 @@ def _sum():
...
@@ -777,12 +777,12 @@ def _sum():
ignores
=
[
'name'
,
'Tidx'
])([
inputs
[
0
]],
attr
)
ignores
=
[
'name'
,
'Tidx'
])([
inputs
[
0
]],
attr
)
return
_impl
return
_impl
def
_reduce
_all
(
):
def
_reduce
(
op
):
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
axis
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()
axis
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()
axis
=
tuple
(
axis
)
axis
=
tuple
(
axis
)
return
AttrCvt
(
return
AttrCvt
(
op_name
=
'all'
,
op_name
=
op
,
extras
=
{
'axis'
:
axis
},
extras
=
{
'axis'
:
axis
},
transforms
=
{
'keep_dims'
:
'keepdims'
},
transforms
=
{
'keep_dims'
:
'keepdims'
},
ignores
=
[
'name'
,
'Tidx'
])([
inputs
[
0
]],
attr
)
ignores
=
[
'name'
,
'Tidx'
])([
inputs
[
0
]],
attr
)
...
@@ -807,6 +807,14 @@ def _gather():
...
@@ -807,6 +807,14 @@ def _gather():
'Taxis'
,
'_class'
])(
new_input
,
attr
)
'Taxis'
,
'_class'
])(
new_input
,
attr
)
return
_impl
return
_impl
def
_gather_nd
():
"""GatherNd"""
def
_impl
(
inputs
,
attr
,
params
):
return
AttrCvt
(
op_name
=
"gather_nd"
,
ignores
=
[
'Tindices'
,
'Tparams'
,
\
'Taxis'
,
'_class'
])(
inputs
,
attr
)
return
_impl
def
_stridedSlice
():
def
_stridedSlice
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
"""Strided Slice.
"""Strided Slice.
...
@@ -971,15 +979,18 @@ def _rank():
...
@@ -971,15 +979,18 @@ def _rank():
def
_range
():
def
_range
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
start
=
_get_num_param
(
params
,
inputs
[
0
])
start
=
params
.
pop
(
inputs
[
0
]
.
name_hint
)
.
asnumpy
()[
0
]
limit
=
_get_num_param
(
params
,
inputs
[
1
])
limit
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()[
0
]
\
delta
=
_get_num_param
(
params
,
inputs
[
2
])
if
hasattr
(
inputs
[
1
],
"name_hint"
)
else
params
.
pop
(
'Rank'
)
.
asnumpy
()[
0
]
delta
=
params
.
pop
(
inputs
[
2
]
.
name_hint
)
.
asnumpy
()[
0
]
name
=
attr
[
"_node_name"
]
dtype
=
attr
[
'dtype'
]
.
name
if
'dtype'
in
attr
else
"int32"
params
[
name
]
=
tvm
.
nd
.
array
([
start
,
limit
,
delta
])
return
AttrCvt
(
return
[
_expr
.
var
(
name
,
op_name
=
"arange"
,
shape
=
params
[
name
]
.
shape
,
ignores
=
[
'Tidx'
],
dtype
=
'int32'
)]
extras
=
{
'start'
:
start
,
"stop"
:
limit
,
'step'
:
delta
,
'dtype'
:
dtype
})([],
attr
)
return
_impl
return
_impl
def
_elu
():
def
_elu
():
...
@@ -1099,6 +1110,13 @@ def _topk():
...
@@ -1099,6 +1110,13 @@ def _topk():
extras
=
{
'k'
:
k
,
'is_ascend'
:
False
,
'dtype'
:
'int32'
})(
inputs
,
attr
)
extras
=
{
'k'
:
k
,
'is_ascend'
:
False
,
'dtype'
:
'int32'
})(
inputs
,
attr
)
return
_impl
return
_impl
def
_floordiv
():
def
_impl
(
inputs
,
attr
,
params
):
assert
len
(
inputs
)
==
2
div
=
AttrCvt
(
'divide'
)(
inputs
,
attr
)
return
_get_relay_op
(
'floor'
)(
div
)
return
_impl
def
_logical
(
name
):
def
_logical
(
name
):
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
return
AttrCvt
(
op_name
=
name
)(
inputs
,
attr
)
return
AttrCvt
(
op_name
=
name
)(
inputs
,
attr
)
...
@@ -1207,8 +1225,9 @@ _identity_list = []
...
@@ -1207,8 +1225,9 @@ _identity_list = []
# for 1 to N mapping(composed), use custom callable functions
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
# for N to 1 mapping, currently not supported(?)
_convert_map
=
{
_convert_map
=
{
'Abs'
:
AttrCvt
(
'abs'
),
'Add'
:
_elemwise
(
'add'
),
'Add'
:
_elemwise
(
'add'
),
'All'
:
_reduce
_all
(
),
'All'
:
_reduce
(
'all'
),
'ArgMax'
:
_argx
(
_op
.
argmax
,
'argmax'
),
'ArgMax'
:
_argx
(
_op
.
argmax
,
'argmax'
),
'ArgMin'
:
_argx
(
_op
.
argmin
,
'argmin'
),
'ArgMin'
:
_argx
(
_op
.
argmin
,
'argmin'
),
'AvgPool'
:
_pooling
(
'avg_pool'
),
'AvgPool'
:
_pooling
(
'avg_pool'
),
...
@@ -1232,26 +1251,33 @@ _convert_map = {
...
@@ -1232,26 +1251,33 @@ _convert_map = {
'ExpandDims'
:
_expand_dims
(),
'ExpandDims'
:
_expand_dims
(),
'Fill'
:
_fill
(),
'Fill'
:
_fill
(),
'Floor'
:
AttrCvt
(
'floor'
),
'Floor'
:
AttrCvt
(
'floor'
),
'FloorDiv'
:
_floordiv
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'FusedBatchNormV2'
:
_fused_batch_norm
(),
'FusedBatchNormV2'
:
_fused_batch_norm
(),
'Gather'
:
_gather
(),
'Gather'
:
_gather
(),
'GatherNd'
:
_gather_nd
(),
'GatherV2'
:
_gather
(),
'GatherV2'
:
_gather
(),
'Greater'
:
_broadcast
(
'greater'
),
'Greater'
:
_broadcast
(
'greater'
),
'GreaterEqual'
:
_broadcast
(
'greater_equal'
),
'GreaterEqual'
:
_broadcast
(
'greater_equal'
),
'Identity'
:
_identity
(),
'Identity'
:
_identity
(),
'LeakyRelu'
:
AttrCvt
(
'leaky_relu'
),
'LeakyRelu'
:
AttrCvt
(
'leaky_relu'
),
'LeftShift'
:
AttrCvt
(
'left_shift'
),
'Less'
:
_broadcast
(
'less'
),
'Less'
:
_broadcast
(
'less'
),
'LessEqual'
:
_broadcast
(
'less_equal'
),
'LessEqual'
:
_broadcast
(
'less_equal'
),
'Log'
:
AttrCvt
(
'log'
),
'Log'
:
AttrCvt
(
'log'
),
'LogicalAnd'
:
_logical
(
'logical_and'
),
'LogicalAnd'
:
_logical
(
'logical_and'
),
'LogicalOr'
:
_logical
(
'logical_or'
),
'LogicalOr'
:
_logical
(
'logical_or'
),
'LogicalNot'
:
_logical
(
'logical_not'
),
'LogicalNot'
:
_logical
(
'logical_not'
),
'LogSoftmax'
:
AttrCvt
(
'log_softmax'
),
'LRN'
:
_lrn
(),
'LRN'
:
_lrn
(),
'MatMul'
:
_matmul
(),
'MatMul'
:
_matmul
(),
'Max'
:
_reduce
(
'max'
),
'MaxPool'
:
_pooling
(
'max_pool'
),
'MaxPool'
:
_pooling
(
'max_pool'
),
'Maximum'
:
_elemwise
(
'maximum'
),
'Maximum'
:
_elemwise
(
'maximum'
),
'Mean'
:
_mean
(),
'Mean'
:
_mean
(),
'Min'
:
_reduce
(
'min'
),
'Minimum'
:
_elemwise
(
'minimum'
),
'Minimum'
:
_elemwise
(
'minimum'
),
'Mod'
:
_elemwise
(
'mod'
),
'Mul'
:
_elemwise
(
'multiply'
),
'Mul'
:
_elemwise
(
'multiply'
),
'Neg'
:
AttrCvt
(
'negative'
),
'Neg'
:
AttrCvt
(
'negative'
),
'NotEqual'
:
_broadcast
(
'not_equal'
),
'NotEqual'
:
_broadcast
(
'not_equal'
),
...
@@ -1269,6 +1295,7 @@ _convert_map = {
...
@@ -1269,6 +1295,7 @@ _convert_map = {
'ResizeBilinear'
:
_resize_bilinear
(),
'ResizeBilinear'
:
_resize_bilinear
(),
'ResizeBicubic'
:
_resize_bilinear
(),
'ResizeBicubic'
:
_resize_bilinear
(),
'ReverseV2'
:
_reverse_v2
(),
'ReverseV2'
:
_reverse_v2
(),
'RightShift'
:
AttrCvt
(
'right_shift'
),
'Round'
:
AttrCvt
(
'round'
),
'Round'
:
AttrCvt
(
'round'
),
'Rsqrt'
:
_rsqrt
(),
'Rsqrt'
:
_rsqrt
(),
'Select'
:
_where
(),
'Select'
:
_where
(),
...
@@ -1292,7 +1319,9 @@ _convert_map = {
...
@@ -1292,7 +1319,9 @@ _convert_map = {
'Tile'
:
_tile
(),
'Tile'
:
_tile
(),
'TopKV2'
:
_topk
(),
'TopKV2'
:
_topk
(),
'Transpose'
:
_transpose
(),
'Transpose'
:
_transpose
(),
'TruncateMod'
:
_elemwise
(
'mod'
),
'Unpack'
:
_unpack
(),
'Unpack'
:
_unpack
(),
'ZerosLike'
:
AttrCvt
(
'zeros_like'
),
}
}
...
...
tests/python/frontend/tensorflow/test_forward.py
View file @
9bb16872
...
@@ -64,6 +64,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
...
@@ -64,6 +64,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout
=
layout
,
layout
=
layout
,
shape
=
shape_dict
,
shape
=
shape_dict
,
outputs
=
out_names
)
outputs
=
out_names
)
with
relay
.
build_config
(
opt_level
=
opt_level
):
with
relay
.
build_config
(
opt_level
=
opt_level
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
target_host
,
params
)
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
target_host
,
params
)
...
@@ -642,10 +643,53 @@ def test_forward_stridedslice():
...
@@ -642,10 +643,53 @@ def test_forward_stridedslice():
'float32'
,
shrink_axis_mask
=
8
,
new_axis_mask
=
1
,
ellipsis_mask
=
2
,
begin_mask
=
5
,
'float32'
,
shrink_axis_mask
=
8
,
new_axis_mask
=
1
,
ellipsis_mask
=
2
,
begin_mask
=
5
,
end_mask
=
8
)
end_mask
=
8
)
#######################################################################
# FloorDiv, RealDiv
# -----------------
def
_test_forward_divide
(
ip_shape
,
dtype
):
np_numer
=
np
.
random
.
uniform
(
-
100
,
100
,
size
=
ip_shape
)
.
astype
(
dtype
)
np_denomin
=
np
.
random
.
uniform
(
1
,
100
,
size
=
ip_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
numerator
=
tf
.
placeholder
(
dtype
,
ip_shape
,
name
=
"numer"
)
denominator
=
tf
.
placeholder
(
dtype
,
ip_shape
,
name
=
"denomin"
)
tf
.
math
.
divide
(
numerator
,
denominator
,
name
=
'RealDiv'
)
compare_tf_with_tvm
([
np_numer
,
np_denomin
],
[
'numer:0'
,
'denomin:0'
],
'RealDiv:0'
)
def
_test_forward_floordiv
(
ip_shape
,
dtype
):
np_numer
=
np
.
random
.
uniform
(
-
100
,
100
,
size
=
ip_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
numerator
=
tf
.
placeholder
(
dtype
,
ip_shape
,
name
=
"numer"
)
tf
.
math
.
floordiv
(
numerator
,
tf
.
constant
(
5
,
dtype
=
dtype
),
name
=
'FloorDiv'
)
compare_tf_with_tvm
([
np_numer
],
[
'numer:0'
],
'FloorDiv:0'
)
def
test_forward_divide
():
'''test FloorDiv, RealDiv'''
_test_forward_divide
((
4
,),
'int32'
)
_test_forward_divide
((
4
,
3
,
7
),
'float32'
)
_test_forward_floordiv
((
4
,
3
,
7
),
'float32'
)
#######################################################################
#######################################################################
# Gather, GatherV2
# TruncateMod
# ----------------
# -----------
def
_test_forward_truncatemod
(
ip_shape
,
dtype
):
np_data_1
=
np
.
random
.
uniform
(
-
100
,
100
,
size
=
ip_shape
)
.
astype
(
dtype
)
np_data_2
=
np
.
random
.
uniform
(
1
,
10
,
size
=
ip_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
in_data_1
=
tf
.
placeholder
(
dtype
,
ip_shape
,
name
=
"in_data_1"
)
in_data_2
=
tf
.
placeholder
(
dtype
,
ip_shape
,
name
=
"in_data_2"
)
tf
.
truncatemod
(
in_data_1
,
in_data_2
,
name
=
'truncatemod'
)
compare_tf_with_tvm
([
np_data_1
,
np_data_2
],
[
'in_data_1:0'
,
'in_data_2:0'
],
'truncatemod:0'
)
def
test_forward_truncatemod
():
'''test TruncateMod'''
_test_forward_truncatemod
((
4
,
3
,
7
),
'int32'
)
#######################################################################
# Gather, GatherV2, GatherNd
# --------------------------
def
_test_gather
(
ip_shape
,
indice_shape
,
indice_value
,
axis
,
dtype
):
def
_test_gather
(
ip_shape
,
indice_shape
,
indice_value
,
axis
,
dtype
):
""" One iteration of a GatherV2 """
""" One iteration of a GatherV2 """
...
@@ -718,6 +762,33 @@ def test_forward_gather_v1():
...
@@ -718,6 +762,33 @@ def test_forward_gather_v1():
_test_gather_v1
((
4
,
3
,
5
,
6
),
(
1
,
4
),
[[
2
,
1
,
0
,
0
]],
'float32'
)
_test_gather_v1
((
4
,
3
,
5
,
6
),
(
1
,
4
),
[[
2
,
1
,
0
,
0
]],
'float32'
)
def
test_forward_gather_nd
():
"""test operator GatherNd"""
np_data
=
np
.
random
.
uniform
(
1
,
100
,
size
=
(
2
,
2
))
.
astype
(
np
.
float32
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
tf
.
float32
,
(
2
,
2
),
name
=
"in_data"
)
tf
.
gather_nd
(
in_data
,
indices
=
[[
1
,
0
],
[
0
,
1
]],
name
=
"gather_nd"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'gather_nd:0'
)
#######################################################################
# BiasAdd
# -------
def
test_forward_bias_add
():
"""test Op BiasAdd"""
def
check_bias_add
(
lh_shpae
,
rh_shape
,
dtype
):
tf
.
reset_default_graph
()
lh_data
=
np
.
random
.
uniform
(
size
=
lh_shpae
)
.
astype
(
dtype
)
rh_data
=
np
.
random
.
uniform
(
size
=
rh_shape
)
.
astype
(
dtype
)
lft_data
=
tf
.
placeholder
(
dtype
,
name
=
"lft_data"
)
rgt_data
=
tf
.
placeholder
(
dtype
,
name
=
"rgt_data"
)
tf
.
nn
.
bias_add
(
lft_data
,
rgt_data
,
name
=
"BiasAdd"
)
compare_tf_with_tvm
([
lh_data
,
rh_data
],
[
'lft_data:0'
,
'rgt_data:0'
],
'BiasAdd:0'
)
check_bias_add
((
10
,
8
,
16
,
32
),
(
32
,),
dtype
=
"int32"
)
check_bias_add
((
10
,
20
),
(
20
,),
dtype
=
"float32"
)
#######################################################################
#######################################################################
# Split
# Split
# -----
# -----
...
@@ -1109,6 +1180,32 @@ def test_forward_pack():
...
@@ -1109,6 +1180,32 @@ def test_forward_pack():
_test_pack
(
axis
,
[
3
])
_test_pack
(
axis
,
[
3
])
_test_pack
(
0
,
[])
_test_pack
(
0
,
[])
#######################################################################
# Unpack
# ------
def
_test_forward_unpack
(
in_shape
,
axis
,
dtype
):
"""test operator Unpack"""
np_data
=
np
.
random
.
uniform
(
-
100
,
100
,
size
=
in_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"in_data"
)
tf
.
unstack
(
in_data
,
axis
=
axis
,
name
=
"Unpack"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'Unpack:0'
)
def
test_forward_unpack
():
_test_forward_unpack
((
3
,),
0
,
'int32'
)
_test_forward_unpack
((
3
,),
-
1
,
'int16'
)
_test_forward_unpack
((
21
,
23
,
3
),
2
,
'float32'
)
#######################################################################
# Range
# -----
def
test_forward_range
():
"""test operator Range"""
tf
.
reset_default_graph
()
tf
.
range
(
1
,
18
,
3
,
name
=
"range"
)
compare_tf_with_tvm
([],
[],
'range:0'
)
#######################################################################
#######################################################################
# Pad
# Pad
# ---
# ---
...
@@ -1182,7 +1279,7 @@ def test_forward_logical():
...
@@ -1182,7 +1279,7 @@ def test_forward_logical():
#######################################################################
#######################################################################
# Where, Select
# Where, Select
# -------------
# -------------
def
test_where
():
def
test_
forward_
where
():
''' Where: return elements depending on conditions'''
''' Where: return elements depending on conditions'''
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
...
@@ -1553,6 +1650,22 @@ def test_forward_tanh():
...
@@ -1553,6 +1650,22 @@ def test_forward_tanh():
tf
.
nn
.
tanh
(
in1
)
tf
.
nn
.
tanh
(
in1
)
compare_tf_with_tvm
(
inp_array
,
'Placeholder:0'
,
'Tanh:0'
)
compare_tf_with_tvm
(
inp_array
,
'Placeholder:0'
,
'Tanh:0'
)
#######################################################################
# Softmax
# -------
def
test_forward_softmax
():
"""test operator Softmax """
def
check_softmax
(
in_shape
,
axis
,
dtype
):
np_data
=
np
.
random
.
uniform
(
-
100
,
100
,
size
=
in_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"in_data"
)
tf
.
nn
.
softmax
(
in_data
,
axis
=
axis
,
name
=
"Softmax"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'Softmax:0'
)
check_softmax
((
2
,
3
,
5
),
2
,
"float32"
)
check_softmax
((
2
,
3
,
5
),
-
1
,
"float32"
)
#######################################################################
#######################################################################
# Tensor
# Tensor
# ------
# ------
...
@@ -1565,6 +1678,29 @@ def test_forward_round():
...
@@ -1565,6 +1678,29 @@ def test_forward_round():
tf
.
round
(
in_data
,
name
=
"round"
)
tf
.
round
(
in_data
,
name
=
"round"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'round:0'
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'round:0'
)
def
test_forward_abs
():
"""test operator Abs"""
np_data
=
np
.
random
.
uniform
(
1
,
100
,
size
=
(
9
,
11
))
.
astype
(
np
.
float32
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
tf
.
float32
,
(
9
,
11
),
name
=
"in_data"
)
tf
.
math
.
abs
(
in_data
,
name
=
"abs"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'abs:0'
)
def
_test_forward_zeros_like
(
in_shape
,
dtype
):
np_data
=
np
.
random
.
uniform
(
-
10
,
10
,
size
=
in_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"in_data"
)
tf
.
zeros_like
(
in_data
,
name
=
"zeros_like"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'zeros_like:0'
)
def
test_forward_zeros_like
():
if
tf
.
__version__
<
LooseVersion
(
'1.2'
):
_test_forward_zeros_like
((
2
,
3
),
"int32"
)
_test_forward_zeros_like
((
2
,
3
,
5
),
"int8"
)
_test_forward_zeros_like
((
2
,
3
,
5
,
7
),
"uint16"
)
_test_forward_zeros_like
((
2
,
3
,
11
),
"float32"
)
_test_forward_zeros_like
((
2
,
3
,
11
),
"float64"
)
def
_test_forward_reverse_v2
(
in_shape
,
axis
,
dtype
):
def
_test_forward_reverse_v2
(
in_shape
,
axis
,
dtype
):
np_data
=
np
.
random
.
uniform
(
-
10
,
10
,
size
=
in_shape
)
.
astype
(
dtype
)
np_data
=
np
.
random
.
uniform
(
-
10
,
10
,
size
=
in_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
tf
.
reset_default_graph
()
...
@@ -1588,6 +1724,14 @@ def test_forward_sign():
...
@@ -1588,6 +1724,14 @@ def test_forward_sign():
tf
.
sign
(
in_data
,
name
=
"sign"
)
tf
.
sign
(
in_data
,
name
=
"sign"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'sign:0'
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'sign:0'
)
def
test_forward_square
():
"""test operator Square """
np_data
=
np
.
random
.
uniform
(
1
,
100
,
size
=
(
2
,
3
,
5
))
.
astype
(
np
.
float32
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
tf
.
float32
,
(
2
,
3
,
5
),
name
=
"in_data"
)
tf
.
square
(
in_data
,
name
=
"square"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'square:0'
)
def
test_forward_pow_exp
():
def
test_forward_pow_exp
():
"""test Pow and Exp """
"""test Pow and Exp """
np_in1
=
np
.
random
.
uniform
(
-
2
,
2
,
size
=
(
5
,
7
,
11
))
.
astype
(
np
.
float32
)
np_in1
=
np
.
random
.
uniform
(
-
2
,
2
,
size
=
(
5
,
7
,
11
))
.
astype
(
np
.
float32
)
...
@@ -1616,6 +1760,14 @@ def test_forward_negative():
...
@@ -1616,6 +1760,14 @@ def test_forward_negative():
tf
.
negative
(
in_data
,
name
=
"negative"
)
tf
.
negative
(
in_data
,
name
=
"negative"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'negative:0'
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'negative:0'
)
def
test_forward_log_softmax
():
"""test operator LogSoftmax"""
np_data
=
np
.
random
.
uniform
(
1
,
100
,
size
=
(
9
,
11
))
.
astype
(
np
.
float32
)
tf
.
reset_default_graph
()
in_data
=
tf
.
placeholder
(
tf
.
float32
,
(
9
,
11
),
name
=
"in_data"
)
tf
.
math
.
log_softmax
(
in_data
,
name
=
"LogSoftmax"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'LogSoftmax:0'
)
def
test_forward_softplus
():
def
test_forward_softplus
():
"""test operator Softplus"""
"""test operator Softplus"""
np_data
=
np
.
random
.
uniform
(
1
,
10
,
size
=
(
2
,
3
,
5
))
.
astype
(
np
.
float32
)
np_data
=
np
.
random
.
uniform
(
1
,
10
,
size
=
(
2
,
3
,
5
))
.
astype
(
np
.
float32
)
...
@@ -1640,6 +1792,34 @@ def test_forward_sqrt():
...
@@ -1640,6 +1792,34 @@ def test_forward_sqrt():
tf
.
sqrt
(
in_data
,
name
=
"sqrt"
)
tf
.
sqrt
(
in_data
,
name
=
"sqrt"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'sqrt:0'
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'sqrt:0'
)
def
_test_forward_right_shift
(
in_shape
,
dtype
):
"""test operator RightShift"""
lh_data
=
np
.
random
.
randint
(
1
,
3
,
size
=
in_shape
)
.
astype
(
dtype
)
rh_data
=
np
.
random
.
randint
(
1
,
8
,
size
=
in_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
lft_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"lft_data"
)
rgt_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"rgt_data"
)
tf
.
bitwise
.
right_shift
(
lft_data
,
rgt_data
,
name
=
"RightShift"
)
compare_tf_with_tvm
([
lh_data
,
rh_data
],
[
'lft_data:0'
,
'rgt_data:0'
],
'RightShift:0'
)
def
test_forward_right_shift
():
_test_forward_right_shift
((
7
,),
'int32'
)
_test_forward_right_shift
((
3
,
11
),
'int16'
)
def
_test_forward_left_shift
(
in_shape
,
dtype
):
"""test operator LeftShift"""
lh_data
=
np
.
random
.
randint
(
100
,
1000000
,
size
=
in_shape
)
.
astype
(
dtype
)
rh_data
=
np
.
random
.
randint
(
1
,
3
,
size
=
in_shape
)
.
astype
(
dtype
)
tf
.
reset_default_graph
()
lft_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"lft_data"
)
rgt_data
=
tf
.
placeholder
(
dtype
,
in_shape
,
name
=
"rgt_data"
)
tf
.
bitwise
.
left_shift
(
lft_data
,
rgt_data
,
name
=
"LeftShift"
)
compare_tf_with_tvm
([
lh_data
,
rh_data
],
[
'lft_data:0'
,
'rgt_data:0'
],
'LeftShift:0'
)
def
test_forward_left_shift
():
_test_forward_left_shift
((
10
,),
'int32'
)
_test_forward_left_shift
((
224
,
224
,
3
),
'int16'
)
#######################################################################
#######################################################################
# Mean
# Mean
# ----
# ----
...
@@ -1652,13 +1832,13 @@ def test_forward_mean():
...
@@ -1652,13 +1832,13 @@ def test_forward_mean():
compare_tf_with_tvm
(
inp_array
,
'Placeholder:0'
,
'Mean:0'
,
no_gpu
=
True
)
compare_tf_with_tvm
(
inp_array
,
'Placeholder:0'
,
'Mean:0'
,
no_gpu
=
True
)
check_mean
((
10
,
8
,
16
,
32
))
check_mean
((
10
,
8
,
16
,
32
))
check_mean
((
10
,
8
,
16
,
32
),
axis
=
(
2
,
3
))
check_mean
((
10
,
8
,
16
,
32
),
axis
=
(
2
,
3
))
check_mean
((
10
,
8
,
16
,
32
),
axis
=
(
1
,
2
),
keepdims
=
True
)
check_mean
((
10
,
8
,
16
,
32
),
axis
=
(
1
,
2
),
keepdims
=
True
)
#######################################################################
#######################################################################
# All
# All
, Max, Min
# ---
# ---
----------
def
test_forward_all
():
def
test_forward_
reduce_
all
():
"""Test the All operator."""
"""Test the All operator."""
np_data
=
np
.
random
.
choice
([
True
,
False
],
size
=
(
5
,
7
,
11
))
np_data
=
np
.
random
.
choice
([
True
,
False
],
size
=
(
5
,
7
,
11
))
tf
.
reset_default_graph
()
tf
.
reset_default_graph
()
...
@@ -1666,6 +1846,30 @@ def test_forward_all():
...
@@ -1666,6 +1846,30 @@ def test_forward_all():
tf
.
reduce_all
(
in_data
,
name
=
"all"
)
tf
.
reduce_all
(
in_data
,
name
=
"all"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'all:0'
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'all:0'
)
def
test_forward_reduce_max
():
def
check_max
(
ishape
,
axis
,
keepdims
,
dtype
):
tf
.
reset_default_graph
()
np_data
=
np
.
random
.
uniform
(
size
=
ishape
)
.
astype
(
dtype
)
in_data
=
tf
.
placeholder
(
dtype
,
name
=
"in_data"
)
tf
.
math
.
reduce_max
(
in_data
,
axis
=
axis
,
keepdims
=
keepdims
,
name
=
"reduce_max"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'reduce_max:0'
)
check_max
((
10
,
8
,
16
,
32
),
axis
=
(
-
1
),
keepdims
=
True
,
dtype
=
"int32"
)
check_max
((
10
,
8
,
16
,
32
),
axis
=
(
2
,
3
),
keepdims
=
True
,
dtype
=
"float32"
)
check_max
((
10
,
8
,
16
,
32
),
axis
=
(
1
,
2
),
keepdims
=
True
,
dtype
=
'float32'
)
def
test_forward_reduce_min
():
def
check_min
(
ishape
,
axis
,
keepdims
,
dtype
):
tf
.
reset_default_graph
()
np_data
=
np
.
random
.
uniform
(
size
=
ishape
)
.
astype
(
dtype
)
in_data
=
tf
.
placeholder
(
dtype
,
name
=
"in_data"
)
tf
.
math
.
reduce_min
(
in_data
,
axis
=
axis
,
keepdims
=
keepdims
,
name
=
"reduce_max"
)
compare_tf_with_tvm
([
np_data
],
[
'in_data:0'
],
'reduce_max:0'
)
check_min
((
10
,
8
,
16
,
32
),
axis
=
(
-
1
),
keepdims
=
True
,
dtype
=
"int32"
)
check_min
((
10
,
8
,
16
,
32
),
axis
=
(
2
,
3
),
keepdims
=
True
,
dtype
=
"float32"
)
check_min
((
10
,
8
,
16
,
32
),
axis
=
(
1
,
2
),
keepdims
=
True
,
dtype
=
'float32'
)
#######################################################################
#######################################################################
# Relational operators
# Relational operators
# --------------------
# --------------------
...
@@ -1724,6 +1928,38 @@ def test_forward_reduce_prod():
...
@@ -1724,6 +1928,38 @@ def test_forward_reduce_prod():
#######################################################################
#######################################################################
# Maximum, Minimum
# ----------------
def
test_forward_maximum
():
"""test Op Maximum"""
def
check_maximum
(
lh_shape
,
rh_shape
,
dtype
):
tf
.
reset_default_graph
()
lh_data
=
np
.
random
.
uniform
(
size
=
lh_shape
)
.
astype
(
dtype
)
rh_data
=
np
.
random
.
uniform
(
size
=
rh_shape
)
.
astype
(
dtype
)
lft_data
=
tf
.
placeholder
(
dtype
,
name
=
"lft_data"
)
rgt_data
=
tf
.
placeholder
(
dtype
,
name
=
"rgt_data"
)
tf
.
math
.
maximum
(
lft_data
,
rgt_data
,
name
=
"maximum"
)
compare_tf_with_tvm
([
lh_data
,
rh_data
],
[
'lft_data:0'
,
'rgt_data:0'
],
'maximum:0'
)
check_maximum
((
10
,
8
,
16
,
32
),
(
1
,),
dtype
=
"int32"
)
check_maximum
((
10
,
8
,
16
,
32
),
(
10
,
8
,
16
,
32
),
dtype
=
"float32"
)
def
test_forward_minimum
():
"""test Op Minimum"""
def
check_minimum
(
lh_shape
,
rh_shape
,
dtype
):
tf
.
reset_default_graph
()
lh_data
=
np
.
random
.
uniform
(
size
=
lh_shape
)
.
astype
(
dtype
)
rh_data
=
np
.
random
.
uniform
(
size
=
rh_shape
)
.
astype
(
dtype
)
lft_data
=
tf
.
placeholder
(
dtype
,
name
=
"lft_data"
)
rgt_data
=
tf
.
placeholder
(
dtype
,
name
=
"rgt_data"
)
tf
.
math
.
minimum
(
lft_data
,
rgt_data
,
name
=
"minimum"
)
compare_tf_with_tvm
([
lh_data
,
rh_data
],
[
'lft_data:0'
,
'rgt_data:0'
],
'minimum:0'
)
check_minimum
((
10
,
8
,
16
,
32
),
(
1
,),
dtype
=
"int32"
)
check_minimum
((
10
,
8
,
16
,
32
),
(
10
,
8
,
16
,
32
),
dtype
=
"float32"
)
#######################################################################
# PlaceholderWithDefault
# PlaceholderWithDefault
# ----------------------
# ----------------------
def
test_placeholder
():
def
test_placeholder
():
...
@@ -1740,6 +1976,7 @@ def test_placeholder():
...
@@ -1740,6 +1976,7 @@ def test_placeholder():
compare_tf_with_tvm
([
in_data1
,
in_data2
],
[
'place1:0'
,
'in2:0'
],
'out2:0'
,
init_global_variables
=
True
)
compare_tf_with_tvm
([
in_data1
,
in_data2
],
[
'place1:0'
,
'in2:0'
],
'out2:0'
,
init_global_variables
=
True
)
#######################################################################
#######################################################################
# Main
# Main
# ----
# ----
...
@@ -1756,14 +1993,22 @@ if __name__ == '__main__':
...
@@ -1756,14 +1993,22 @@ if __name__ == '__main__':
test_forward_fill
()
test_forward_fill
()
test_forward_crop
()
test_forward_crop
()
test_forward_pad
()
test_forward_pad
()
test_forward_unpack
()
test_forward_gather
()
test_forward_gather
()
test_forward_gather_v1
()
test_forward_gather_v1
()
test_forward_gather_nd
()
test_forward_stridedslice
()
test_forward_stridedslice
()
test_forward_split
()
test_forward_split
()
test_forward_unstack
()
test_forward_unstack
()
test_forward_tile
()
test_forward_tile
()
test_forward_top_k_v2
()
test_forward_top_k_v2
()
test_forward_clip_by_value
()
test_forward_clip_by_value
()
test_forward_maximum
()
test_forward_minimum
()
test_forward_range
()
test_forward_right_shift
()
test_forward_left_shift
()
test_forward_truncatemod
()
# Activations
# Activations
test_forward_sigmoid
()
test_forward_sigmoid
()
...
@@ -1780,17 +2025,26 @@ if __name__ == '__main__':
...
@@ -1780,17 +2025,26 @@ if __name__ == '__main__':
test_forward_sign
()
test_forward_sign
()
test_forward_log
()
test_forward_log
()
test_forward_negative
()
test_forward_negative
()
test_forward_divide
()
test_forward_abs
()
test_forward_softplus
()
test_forward_softplus
()
test_forward_sqrt
()
test_forward_sqrt
()
test_forward_rsqrt
()
test_forward_rsqrt
()
test_forward_expand_dims
()
test_forward_expand_dims
()
test_forward_square
()
test_forward_softmax
()
test_forward_log_softmax
()
test_forward_bias_add
()
test_forward_zeros_like
()
# Reductions
# Reductions
test_forward_argminmax
()
test_forward_argminmax
()
test_forward_reduce
()
test_forward_reduce
()
test_forward_mean
()
test_forward_mean
()
test_forward_reduce_prod
()
test_forward_reduce_prod
()
test_forward_all
()
test_forward_reduce_all
()
test_forward_reduce_max
()
test_forward_reduce_min
()
# General
# General
test_forward_multi_input
()
test_forward_multi_input
()
...
@@ -1826,7 +2080,7 @@ if __name__ == '__main__':
...
@@ -1826,7 +2080,7 @@ if __name__ == '__main__':
# Relational ops
# Relational ops
test_forward_rel_ops
()
test_forward_rel_ops
()
test_forward_logical
()
test_forward_logical
()
test_where
()
test_
forward_
where
()
test_forward_matmul
()
test_forward_matmul
()
# TODO missing tests: rank, range
# TODO missing tests: rank, range
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment