Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
770ac84e
Commit
770ac84e
authored
Jun 06, 2019
by
Alexey Romanov
Committed by
Tianqi Chen
Jun 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Frontend] Simplify parameter handling in Tensorflow frontend (#2993)
parent
5999f7a6
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
141 additions
and
140 deletions
+141
-140
python/tvm/relay/frontend/tensorflow.py
+90
-100
tests/python/frontend/tensorflow/test_forward.py
+49
-30
topi/python/topi/util.py
+2
-10
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
770ac84e
...
@@ -63,7 +63,7 @@ def _get_relay_op(op_name):
...
@@ -63,7 +63,7 @@ def _get_relay_op(op_name):
return
op
return
op
class
AttrCvt
(
object
):
class
AttrCvt
(
object
):
"""Common attribute conveter. An AttrConverter instance is a callable:
"""Common attribute conve
r
ter. An AttrConverter instance is a callable:
```
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
new_op_name, new_attr = attr_converter(attrs)
...
@@ -222,17 +222,37 @@ def _dimension_constraint():
...
@@ -222,17 +222,37 @@ def _dimension_constraint():
return
False
return
False
return
_dim_check
,
"Only 2d kernel supported."
return
_dim_check
,
"Only 2d kernel supported."
def
_infer_channels
(
inputs
,
params
,
transpose
=
False
):
def
_infer_channels
(
node
,
params
,
transpose
=
False
):
"""A hack for getting 'chann
le
s' or 'units' since tensorflow don't provide
"""A hack for getting 'chann
el
s' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
these attributes. We check the shape of weights provided to get the number.
"""
"""
out_type
=
ir_pass
.
infer_type
(
inputs
)
out_shape
=
_infer_shape
(
node
,
params
)
out_shapes
=
[
get_const_tuple
(
out_type
.
checked_type
.
shape
)]
channels
=
out_shape
[
0
]
if
not
transpose
else
out_shape
[
1
]
channels
=
out_shapes
[
0
][
0
]
if
not
transpose
else
out_shapes
[
0
][
1
]
return
channels
return
channels
def
_infer_out_shapes
(
inputs
,
params
):
"""A method to get the output shape of intermediate nodes in the relay graph."""
return
[
_infer_shape
(
inputs
,
params
)]
def
_infer_shape
(
node
,
params
=
None
):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type
=
ir_pass
.
infer_type
(
node
)
return
get_const_tuple
(
out_type
.
checked_type
.
shape
)
def
_get_param
(
params
,
input_node
):
return
params
.
pop
(
input_node
.
name_hint
)
.
asnumpy
()
def
_get_num_param
(
params
,
input_node
):
return
_get_param
(
params
,
input_node
)[
0
]
def
_get_list_param
(
params
,
input_node
):
return
_get_param
(
params
,
input_node
)
.
tolist
()
def
_get_tuple_param
(
params
,
input_node
):
return
tuple
(
_get_param
(
params
,
input_node
))
def
_rsqrt
():
def
_rsqrt
():
def
_impl
(
inputs
,
attr
,
*
arg
s
):
def
_impl
(
inputs
,
attr
,
param
s
):
inputs
.
append
(
tvm
.
relay
.
const
(
-
0.5
,
attr
[
'T'
]
.
name
))
inputs
.
append
(
tvm
.
relay
.
const
(
-
0.5
,
attr
[
'T'
]
.
name
))
return
AttrCvt
(
op_name
=
"power"
)(
inputs
,
attr
)
return
AttrCvt
(
op_name
=
"power"
)(
inputs
,
attr
)
return
_impl
return
_impl
...
@@ -243,16 +263,15 @@ def _argx(func, func_name):
...
@@ -243,16 +263,15 @@ def _argx(func, func_name):
try
:
try
:
# In Tensorflow, `axis` argument is a Tensor, not attribute. We
# In Tensorflow, `axis` argument is a Tensor, not attribute. We
# support the case where it inputs from a scalar constant.
# support the case where it inputs from a scalar constant.
axis_input_name
=
inputs
[
1
]
.
name_hint
axis_input_value
=
[
_get_num_param
(
params
,
inputs
[
1
])]
axis_input_vlaue
=
[
params
[
axis_input_name
]
.
asnumpy
()[
0
]]
except
(
IndexError
,
KeyError
):
except
(
IndexError
,
KeyError
):
raise
TypeError
(
\
raise
TypeError
(
\
"Unsupported argument for `{}` : `axis` should be a constant"
.
format
(
func_name
))
"Unsupported argument for `{}` : `axis` should be a constant"
.
format
(
func_name
))
return
func
(
inputs
[
0
],
axis
=
axis_input_v
la
ue
,
keepdims
=
False
)
return
func
(
inputs
[
0
],
axis
=
axis_input_v
al
ue
,
keepdims
=
False
)
return
_impl
return
_impl
def
_elemwise
(
name
):
def
_elemwise
(
name
):
def
_impl
(
inputs
,
attr
,
*
arg
s
):
def
_impl
(
inputs
,
attr
,
param
s
):
assert
len
(
inputs
)
==
2
,
"{} take 2 inputs, {} given"
.
format
(
name
,
len
(
inputs
))
assert
len
(
inputs
)
==
2
,
"{} take 2 inputs, {} given"
.
format
(
name
,
len
(
inputs
))
return
_get_relay_op
(
name
)(
*
inputs
)
return
_get_relay_op
(
name
)(
*
inputs
)
return
_impl
return
_impl
...
@@ -472,7 +491,7 @@ def _cast():
...
@@ -472,7 +491,7 @@ def _cast():
def
_expand_dims
():
def
_expand_dims
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
dim_input
=
inputs
.
pop
(
1
)
dim_input
=
inputs
.
pop
(
1
)
axis
=
params
.
pop
(
_get_name_hint
(
dim_input
))
.
asnumpy
()[
0
]
axis
=
_get_num_param
(
params
,
dim_input
)
return
AttrCvt
(
op_name
=
"expand_dims"
,
ignores
=
[
'Tdim'
,
'N'
],
return
AttrCvt
(
op_name
=
"expand_dims"
,
ignores
=
[
'Tdim'
,
'N'
],
extras
=
{
'axis'
:
int
(
axis
),
'num_newaxis'
:
1
})(
inputs
,
attr
)
extras
=
{
'axis'
:
int
(
axis
),
'num_newaxis'
:
1
})(
inputs
,
attr
)
return
_impl
return
_impl
...
@@ -527,21 +546,19 @@ def _identity():
...
@@ -527,21 +546,19 @@ def _identity():
def
_concatV2
():
def
_concatV2
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
pop_node
=
inputs
.
pop
(
len
(
inputs
)
-
1
)
pop_node
=
inputs
.
pop
(
len
(
inputs
)
-
1
)
axis
=
params
[
pop_node
.
name_hint
]
axis
=
int
(
_get_num_param
(
params
,
pop_node
))
params
.
pop
(
pop_node
.
name_hint
)
return
AttrCvt
(
return
AttrCvt
(
op_name
=
"concatenate"
,
ignores
=
[
'T'
,
'N'
,
'Tidx'
],
op_name
=
"concatenate"
,
ignores
=
[
'T'
,
'N'
,
'Tidx'
],
extras
=
{
'axis'
:
int
(
axis
.
asnumpy
()[
0
])
})([
inputs
],
attr
)
extras
=
{
'axis'
:
axis
})([
inputs
],
attr
)
return
_impl
return
_impl
def
_concat
():
def
_concat
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
pop_node
=
inputs
.
pop
(
0
)
pop_node
=
inputs
.
pop
(
0
)
axis
=
params
[
pop_node
.
name_hint
]
axis
=
int
(
_get_num_param
(
params
,
pop_node
))
params
.
pop
(
pop_node
.
name_hint
)
return
AttrCvt
(
return
AttrCvt
(
op_name
=
"concatenate"
,
ignores
=
[
'N'
],
op_name
=
"concatenate"
,
ignores
=
[
'N'
],
extras
=
{
'axis'
:
int
(
axis
.
asnumpy
()[
0
])
})([
inputs
],
attr
)
extras
=
{
'axis'
:
axis
})([
inputs
],
attr
)
return
_impl
return
_impl
def
_pack
():
def
_pack
():
...
@@ -565,8 +582,8 @@ def _tile():
...
@@ -565,8 +582,8 @@ def _tile():
def
_slice
():
def
_slice
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
begin
=
params
.
pop
(
_get_name_hint
(
inputs
[
1
]))
.
asnumpy
()
.
tolist
(
)
begin
=
_get_list_param
(
params
,
inputs
[
1
]
)
size
=
params
.
pop
(
_get_name_hint
(
inputs
[
2
]))
.
asnumpy
()
.
tolist
(
)
size
=
_get_list_param
(
params
,
inputs
[
2
]
)
data_shape
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
data_shape
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
data_dim
=
len
(
data_shape
)
data_dim
=
len
(
data_shape
)
end
=
size
end
=
size
...
@@ -581,24 +598,18 @@ def _slice():
...
@@ -581,24 +598,18 @@ def _slice():
def
_reshape
():
def
_reshape
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
pop_node
=
inputs
.
pop
(
1
)
try
:
try
:
pop_node
=
inputs
[
1
]
shape_arg
=
_get_tuple_param
(
params
,
pop_node
)
shape_arg
=
params
.
pop
(
pop_node
.
name_hint
)
inputs
.
pop
(
1
)
return
AttrCvt
(
op_name
=
"reshape"
,
extras
=
{
'newshape'
:
tuple
(
shape_arg
.
asnumpy
())},
ignores
=
[
'Tshape'
])(
inputs
,
attr
)
except
AttributeError
:
except
AttributeError
:
# Shape operator is already pruned, hence
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
# try to infer shape by precompute prune if possible.
params_new
=
_infer_value
(
inputs
[
1
]
,
params
)
params_new
=
_infer_value
(
pop_node
,
params
)
inputs
.
pop
(
1
)
shape_arg
=
tuple
(
params_new
.
asnumpy
()
.
astype
(
'int64'
)
.
flatten
()
)
return
AttrCvt
(
return
AttrCvt
(
op_name
=
"reshape"
,
op_name
=
"reshape"
,
extras
=
{
'newshape'
:
tuple
(
params_new
.
asnumpy
()
.
astype
(
'int64'
)
.
flatten
())
},
extras
=
{
'newshape'
:
shape_arg
},
ignores
=
[
'Tshape'
])(
inputs
,
attr
)
ignores
=
[
'Tshape'
])(
inputs
,
attr
)
return
_impl
return
_impl
...
@@ -737,9 +748,10 @@ def _fill():
...
@@ -737,9 +748,10 @@ def _fill():
if
-
1
in
output_shape
:
if
-
1
in
output_shape
:
output_shape
=
_infer_value
(
inputs
[
0
],
params
)
.
asnumpy
()
.
reshape
([
-
1
])
.
tolist
()
output_shape
=
_infer_value
(
inputs
[
0
],
params
)
.
asnumpy
()
.
reshape
([
-
1
])
.
tolist
()
fill_arg
=
params
.
pop
(
inputs
.
pop
(
1
)
.
name_hint
)
fill_arg
=
_get_num_param
(
params
,
inputs
.
pop
(
1
))
return
_op
.
full
(
tvm
.
relay
.
const
(
fill_arg
.
asnumpy
()[
0
],
attr
[
'T'
]
.
name
),
dtype
=
attr
[
'T'
]
.
name
output_shape
,
attr
[
'T'
]
.
name
)
return
_op
.
full
(
tvm
.
relay
.
const
(
fill_arg
,
dtype
),
output_shape
,
dtype
)
return
_impl
return
_impl
def
_lrn
():
def
_lrn
():
...
@@ -757,9 +769,7 @@ def _lrn():
...
@@ -757,9 +769,7 @@ def _lrn():
def
_sum
():
def
_sum
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
axis
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()
axis
=
_get_tuple_param
(
params
,
inputs
[
1
])
# convert to tuple for preventing invalid parameter format error
axis
=
tuple
(
axis
)
return
AttrCvt
(
return
AttrCvt
(
op_name
=
'sum'
,
op_name
=
'sum'
,
extras
=
{
'axis'
:
axis
},
extras
=
{
'axis'
:
axis
},
...
@@ -786,25 +796,17 @@ def _square():
...
@@ -786,25 +796,17 @@ def _square():
def
_gather
():
def
_gather
():
"GatherV2, Gather"
"GatherV2, Gather"
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
axis
=
0
if
len
(
inputs
)
>
2
:
if
len
(
inputs
)
>
2
:
axis
=
params
[
inputs
.
pop
(
2
)
.
name_hint
]
.
asnumpy
()[
0
]
axis
=
_get_num_param
(
params
,
inputs
.
pop
(
2
))
new_input
=
[]
else
:
new_input
.
append
(
inputs
.
pop
(
0
))
axis
=
0
new_input
.
append
(
inputs
.
pop
(
0
))
new_input
=
inputs
[
0
:
2
]
return
AttrCvt
(
op_name
=
"take"
,
return
AttrCvt
(
op_name
=
"take"
,
extras
=
{
'axis'
:
tvm
.
const
(
axis
,
'int32'
)},
extras
=
{
'axis'
:
tvm
.
const
(
axis
,
'int32'
)},
ignores
=
[
'Tindices'
,
'Tparams'
,
'validate_indices'
,
\
ignores
=
[
'Tindices'
,
'Tparams'
,
'validate_indices'
,
'Taxis'
,
'_class'
])(
new_input
,
attr
)
'Taxis'
,
'_class'
])(
new_input
,
attr
)
return
_impl
return
_impl
def
_infer_out_shapes
(
inputs
,
params
):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type
=
ir_pass
.
infer_type
(
inputs
)
out_shapes
=
[
get_const_tuple
(
out_type
.
checked_type
.
shape
)]
return
out_shapes
def
_stridedSlice
():
def
_stridedSlice
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
"""Strided Slice.
"""Strided Slice.
...
@@ -812,9 +814,9 @@ def _stridedSlice():
...
@@ -812,9 +814,9 @@ def _stridedSlice():
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
"""
begin
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()
.
tolist
(
)
begin
=
_get_list_param
(
params
,
inputs
[
1
]
)
end
=
params
.
pop
(
inputs
[
2
]
.
name_hint
)
.
asnumpy
()
.
tolist
(
)
end
=
_get_list_param
(
params
,
inputs
[
2
]
)
stride
=
params
.
pop
(
inputs
[
3
]
.
name_hint
)
.
asnumpy
()
.
tolist
(
)
stride
=
_get_list_param
(
params
,
inputs
[
3
]
)
begin_mask
=
int
(
attr
.
get
(
'begin_mask'
,
0
))
begin_mask
=
int
(
attr
.
get
(
'begin_mask'
,
0
))
end_mask
=
int
(
attr
.
get
(
'end_mask'
,
0
))
end_mask
=
int
(
attr
.
get
(
'end_mask'
,
0
))
ellipsis_mask
=
int
(
attr
.
get
(
'ellipsis_mask'
,
0
))
ellipsis_mask
=
int
(
attr
.
get
(
'ellipsis_mask'
,
0
))
...
@@ -889,7 +891,7 @@ def _stridedSlice():
...
@@ -889,7 +891,7 @@ def _stridedSlice():
if
begin_mask
or
end_mask
or
ellipsis_mask
or
new_axis_mask
or
shrink_axis_mask
:
if
begin_mask
or
end_mask
or
ellipsis_mask
or
new_axis_mask
or
shrink_axis_mask
:
begin
,
end
,
stride
,
fshape_indices
=
_transform_mask
(
stride_dim
,
ellipsis_mask
)
begin
,
end
,
stride
,
fshape_indices
=
_transform_mask
(
stride_dim
,
ellipsis_mask
)
out
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
end
,
strides
=
stride
)
out
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
end
,
strides
=
stride
)
out_shape
=
_infer_
out_shapes
(
out
,
params
)[
0
]
out_shape
=
_infer_
shape
(
out
,
params
)
if
not
fshape_indices
:
if
not
fshape_indices
:
fshape_indices
=
range
(
len
(
out_shape
))
fshape_indices
=
range
(
len
(
out_shape
))
...
@@ -910,19 +912,14 @@ def _stridedSlice():
...
@@ -910,19 +912,14 @@ def _stridedSlice():
def
_pad
(
name
):
def
_pad
(
name
):
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
padlist_key
=
inputs
[
1
]
.
name_hint
padlist
=
_get_param
(
params
,
inputs
[
1
])
if
padlist_key
in
params
:
paddings
=
tuple
(
tuple
(
l
)
for
l
in
padlist
)
padlist
=
params
.
pop
(
padlist_key
)
.
asnumpy
()
else
:
raise
tvm
.
error
.
OpAttributeRequired
(
'Attribute {} not found in operator Pad.'
.
format
(
padlist_key
))
paddings
=
tuple
([
tuple
(
l
)
for
l
in
padlist
])
attr
[
'pad_width'
]
=
paddings
attr
[
'pad_width'
]
=
paddings
attr
[
'pad_value'
]
=
0
attr
[
'pad_value'
]
=
0
new_inputs
=
[
inputs
[
0
]]
new_inputs
=
[
inputs
[
0
]]
if
name
==
'PadV2'
:
if
name
==
'PadV2'
:
constant_values
=
params
.
pop
(
inputs
[
2
]
.
name_hint
)
.
asnumpy
(
)
constant_values
=
_get_num_param
(
params
,
inputs
[
2
]
)
attr
[
'pad_value'
]
=
constant_values
[
0
]
attr
[
'pad_value'
]
=
constant_values
return
AttrCvt
(
return
AttrCvt
(
op_name
=
'pad'
,
op_name
=
'pad'
,
ignores
=
[
'Tpaddings'
],)(
new_inputs
,
attr
)
ignores
=
[
'Tpaddings'
],)(
new_inputs
,
attr
)
...
@@ -932,10 +929,9 @@ def _transpose():
...
@@ -932,10 +929,9 @@ def _transpose():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
# If perm is not specified, axes is left empty,
# If perm is not specified, axes is left empty,
# otherwise its value is get from params
# otherwise its value is get from params
param_name
=
_get_name_hint
(
inputs
[
1
])
try
:
if
param_name
in
params
:
axes
=
_get_list_param
(
params
,
inputs
[
1
])
axes
=
tuple
(
params
.
get
(
param_name
)
.
asnumpy
())
except
(
IndexError
,
KeyError
):
else
:
axes
=
None
axes
=
None
return
_op
.
transpose
(
inputs
[
0
],
axes
=
axes
)
return
_op
.
transpose
(
inputs
[
0
],
axes
=
axes
)
return
_impl
return
_impl
...
@@ -947,7 +943,7 @@ def _where():
...
@@ -947,7 +943,7 @@ def _where():
def
_reverse_v2
():
def
_reverse_v2
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
axis
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()[
0
]
axis
=
_get_num_param
(
params
,
inputs
[
1
])
return
AttrCvt
(
return
AttrCvt
(
op_name
=
"reverse"
,
op_name
=
"reverse"
,
ignores
=
[
'Tidx'
],
ignores
=
[
'Tidx'
],
...
@@ -968,9 +964,9 @@ def _rank():
...
@@ -968,9 +964,9 @@ def _rank():
def
_range
():
def
_range
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
start
=
params
.
pop
(
inputs
[
0
]
.
name_hint
)
.
asnumpy
()[
0
]
start
=
_get_num_param
(
params
,
inputs
[
0
])
limit
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()[
0
]
limit
=
_get_num_param
(
params
,
inputs
[
1
])
delta
=
params
.
pop
(
inputs
[
2
]
.
name_hint
)
.
asnumpy
()[
0
]
delta
=
_get_num_param
(
params
,
inputs
[
2
])
name
=
attr
[
"_node_name"
]
name
=
attr
[
"_node_name"
]
params
[
name
]
=
tvm
.
nd
.
array
([
start
,
limit
,
delta
])
params
[
name
]
=
tvm
.
nd
.
array
([
start
,
limit
,
delta
])
...
@@ -981,25 +977,27 @@ def _range():
...
@@ -981,25 +977,27 @@ def _range():
def
_elu
():
def
_elu
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
alpha
=
tvm
.
relay
.
const
(
-
1.0
,
attr
[
'T'
]
.
name
)
dtype
=
attr
[
'T'
]
.
name
return
alpha
*
_op
.
nn
.
relu
(
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
\
alpha
=
tvm
.
relay
.
const
(
-
1.0
,
dtype
)
return
alpha
*
_op
.
nn
.
relu
(
tvm
.
relay
.
const
(
1
,
dtype
)
\
-
_op
.
exp
(
inputs
[
0
]))
+
_op
.
nn
.
relu
(
inputs
[
0
])
-
_op
.
exp
(
inputs
[
0
]))
+
_op
.
nn
.
relu
(
inputs
[
0
])
return
_impl
return
_impl
def
_selu
():
def
_selu
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
alpha
=
tvm
.
relay
.
const
(
-
1.6732632423543772848170429916717
,
attr
[
'T'
]
.
name
)
dtype
=
attr
[
'T'
]
.
name
gamma
=
tvm
.
relay
.
const
(
1.0507009873554804934193349852946
,
attr
[
'T'
]
.
name
)
alpha
=
tvm
.
relay
.
const
(
-
1.6732632423543772848170429916717
,
dtype
)
return
gamma
*
(
alpha
*
_op
.
nn
.
relu
(
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
\
gamma
=
tvm
.
relay
.
const
(
1.0507009873554804934193349852946
,
dtype
)
return
gamma
*
(
alpha
*
_op
.
nn
.
relu
(
tvm
.
relay
.
const
(
1
,
dtype
)
\
-
_op
.
exp
(
inputs
[
0
]))
+
_op
.
nn
.
relu
(
inputs
[
0
]))
-
_op
.
exp
(
inputs
[
0
]))
+
_op
.
nn
.
relu
(
inputs
[
0
]))
return
_impl
return
_impl
def
_mean
():
def
_mean
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
axis
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
axis
=
_get_tuple_param
(
params
,
inputs
[
1
]
)
return
AttrCvt
(
op_name
=
"mean"
,
ignores
=
[
'Tdim'
,
'Tidx'
],
return
AttrCvt
(
op_name
=
"mean"
,
ignores
=
[
'Tdim'
,
'Tidx'
],
transforms
=
{
'keep_dims'
:
'keepdims'
},
transforms
=
{
'keep_dims'
:
'keepdims'
},
extras
=
{
'axis'
:
tuple
(
axis
.
asnumpy
())
})([
inputs
[
0
]],
attr
)
extras
=
{
'axis'
:
axis
})([
inputs
[
0
]],
attr
)
return
_impl
return
_impl
def
_broadcast
(
name
):
def
_broadcast
(
name
):
...
@@ -1025,8 +1023,7 @@ def _split(has_size_vector):
...
@@ -1025,8 +1023,7 @@ def _split(has_size_vector):
if
has_size_vector
:
if
has_size_vector
:
input_node_index
=
0
input_node_index
=
0
input_axis_index
=
2
input_axis_index
=
2
size_splits_input_name
=
_get_name_hint
(
inputs
[
1
])
size_splits
=
_get_param
(
params
,
inputs
[
1
])
size_splits
=
params
[
size_splits_input_name
]
.
asnumpy
()
section_beginnings
=
np
.
cumsum
(
size_splits
)[:
-
1
]
section_beginnings
=
np
.
cumsum
(
size_splits
)[:
-
1
]
indices_or_sections
=
tuple
(
section_beginnings
)
indices_or_sections
=
tuple
(
section_beginnings
)
else
:
else
:
...
@@ -1034,8 +1031,7 @@ def _split(has_size_vector):
...
@@ -1034,8 +1031,7 @@ def _split(has_size_vector):
input_axis_index
=
0
input_axis_index
=
0
indices_or_sections
=
attr
[
'num_split'
]
indices_or_sections
=
attr
[
'num_split'
]
input_node
=
inputs
[
input_node_index
]
input_node
=
inputs
[
input_node_index
]
axis_input_name
=
_get_name_hint
(
inputs
[
input_axis_index
])
axis_input_value
=
_get_num_param
(
params
,
inputs
[
input_axis_index
])
axis_input_value
=
params
[
axis_input_name
]
.
asnumpy
()[
0
]
except
(
IndexError
,
KeyError
):
except
(
IndexError
,
KeyError
):
raise
TypeError
(
\
raise
TypeError
(
\
"Unsupported argument for split: `axis` and `num_or_size_splits` "
\
"Unsupported argument for split: `axis` and `num_or_size_splits` "
\
...
@@ -1105,8 +1101,8 @@ def _space_to_batch_nd():
...
@@ -1105,8 +1101,8 @@ def _space_to_batch_nd():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
input_node
=
inputs
[
0
]
input_node
=
inputs
[
0
]
input_shape
=
attr
[
'_input_shapes'
][
input_node
]
input_shape
=
attr
[
'_input_shapes'
][
input_node
]
block_shape
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()
.
tolist
(
)
block_shape
=
_get_list_param
(
params
,
inputs
[
1
]
)
paddings
=
params
.
pop
(
inputs
[
2
]
.
name_hint
)
.
asnumpy
()
.
tolist
(
)
paddings
=
_get_list_param
(
params
,
inputs
[
2
]
)
N
=
len
(
input_shape
)
N
=
len
(
input_shape
)
M
=
len
(
block_shape
)
M
=
len
(
block_shape
)
batch
=
input_shape
[
0
]
batch
=
input_shape
[
0
]
...
@@ -1127,7 +1123,7 @@ def _space_to_batch_nd():
...
@@ -1127,7 +1123,7 @@ def _space_to_batch_nd():
axes
=
[
2
*
i
+
2
for
i
in
range
(
M
)]
+
[
0
]
+
[
2
*
i
+
1
for
i
in
range
(
M
)]
+
\
axes
=
[
2
*
i
+
2
for
i
in
range
(
M
)]
+
[
0
]
+
[
2
*
i
+
1
for
i
in
range
(
M
)]
+
\
list
(
range
(
1
+
2
*
M
,
1
+
2
*
M
+
remaining_shape_length
))
list
(
range
(
1
+
2
*
M
,
1
+
2
*
M
+
remaining_shape_length
))
permuted_reshaped_padded
=
tvm
.
relay
.
transpose
(
reshaped_padded
,
axes
=
axes
)
permuted_reshaped_padded
=
tvm
.
relay
.
transpose
(
reshaped_padded
,
axes
=
axes
)
permuted_reshaped_padded_shape
=
_infer_
out_shapes
(
permuted_reshaped_padded
,
params
)[
0
]
permuted_reshaped_padded_shape
=
_infer_
shape
(
permuted_reshaped_padded
,
params
)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
...
@@ -1144,8 +1140,8 @@ def _batch_to_space_nd():
...
@@ -1144,8 +1140,8 @@ def _batch_to_space_nd():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
input_node
=
inputs
[
0
]
input_node
=
inputs
[
0
]
input_shape
=
attr
[
'_input_shapes'
][
input_node
]
input_shape
=
attr
[
'_input_shapes'
][
input_node
]
block_shape
=
params
.
pop
(
inputs
[
1
]
.
name_hint
)
.
asnumpy
()
.
tolist
(
)
block_shape
=
_get_list_param
(
params
,
inputs
[
1
]
)
crops
=
params
.
pop
(
inputs
[
2
]
.
name_hint
)
.
asnumpy
()
.
tolist
(
)
crops
=
_get_list_param
(
params
,
inputs
[
2
]
)
M
=
len
(
block_shape
)
M
=
len
(
block_shape
)
batch
=
input_shape
[
0
]
batch
=
input_shape
[
0
]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
...
@@ -1170,7 +1166,7 @@ def _batch_to_space_nd():
...
@@ -1170,7 +1166,7 @@ def _batch_to_space_nd():
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
# input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape
=
_infer_
out_shapes
(
reshaped_permuted
,
params
)[
0
]
reshaped_permuted_shape
=
_infer_
shape
(
reshaped_permuted
,
params
)
cropped
=
reshaped_permuted
cropped
=
reshaped_permuted
for
axis
in
range
(
1
,
M
+
1
):
for
axis
in
range
(
1
,
M
+
1
):
crop
=
crops
[
axis
-
1
]
crop
=
crops
[
axis
-
1
]
...
@@ -1971,23 +1967,17 @@ class GraphProto(object):
...
@@ -1971,23 +1967,17 @@ class GraphProto(object):
# Infer shapes even without specifying "add_shapes=True"
# Infer shapes even without specifying "add_shapes=True"
if
output_shapes
==
[
None
]:
if
output_shapes
==
[
None
]:
out_shapes
=
[]
out_shapes
=
[
_infer_shape
(
node_item
)
for
node_item
in
self
.
_nodes
[
node
.
name
]]
for
node_item
in
self
.
_nodes
[
node
.
name
]:
out_type
=
ir_pass
.
infer_type
(
node_item
)
out_shapes
.
append
(
get_const_tuple
(
out_type
.
checked_type
.
shape
))
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
if
self
.
_output_shapes
[
node
.
name
]
and
shape
and
node
.
name
in
shape
:
if
self
.
_output_shapes
[
node
.
name
]
and
shape
and
node
.
name
in
shape
:
assert
self
.
_output_shapes
[
node
.
name
]
==
list
(
shape
[
node
.
name
])
assert
self
.
_output_shapes
[
node
.
name
]
==
list
(
shape
[
node
.
name
])
# Infer shapes if passed explicit
e
ly
# Infer shapes if passed explicitly
node_output
=
self
.
_nodes
[
node
.
name
]
node_output
=
self
.
_nodes
[
node
.
name
]
if
shape
and
(
not
self
.
_output_shapes
[
node
.
name
][
0
]
if
shape
and
(
not
self
.
_output_shapes
[
node
.
name
][
0
]
or
-
1
in
self
.
_output_shapes
[
node
.
name
][
0
]):
or
-
1
in
self
.
_output_shapes
[
node
.
name
][
0
]):
out_shapes
=
[]
out_shapes
=
[
_infer_shape
(
node_item
)
for
node_item
in
node_output
]
for
node_item
in
node_output
:
out_type
=
ir_pass
.
infer_type
(
node_item
)
out_shapes
.
append
(
get_const_tuple
(
out_type
.
checked_type
.
shape
))
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
self
.
_output_shapes
[
node
.
name
]
=
out_shapes
out
=
[]
out
=
[]
...
...
tests/python/frontend/tensorflow/test_forward.py
View file @
770ac84e
...
@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
...
@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout
=
None
layout
=
None
if
target
==
"cuda"
:
if
target
==
"cuda"
:
layout
=
"NCHW"
layout
=
"NCHW"
target_host
=
'llvm'
target_host
=
None
if
isinstance
(
input_data
,
list
):
shape_dict
=
{
e
:
i
.
shape
for
e
,
i
in
zip
(
input_node
,
input_data
)}
shape_dict
=
{}
dtype_dict
=
{}
for
i
,
e
in
enumerate
(
input_node
):
shape_dict
[
e
]
=
input_data
[
i
]
.
shape
dtype_dict
[
e
]
=
input_data
[
i
]
.
dtype
else
:
shape_dict
=
{
input_node
:
input_data
.
shape
}
dtype_dict
=
{
input_node
:
input_data
.
dtype
}
sym
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
sym
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
layout
=
layout
,
layout
=
layout
,
shape
=
shape_dict
,
shape
=
shape_dict
,
outputs
=
out_names
)
outputs
=
out_names
)
with
relay
.
build_config
(
opt_level
=
opt_level
):
with
relay
.
build_config
(
opt_level
=
opt_level
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
target_host
,
params
)
ctx
=
tvm
.
context
(
target
,
0
)
ctx
=
tvm
.
context
(
target
,
0
)
from
tvm.contrib
import
graph_runtime
from
tvm.contrib
import
graph_runtime
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
# set inputs
# set inputs
for
i
,
e
in
enumerate
(
input_node
):
for
e
,
i
in
zip
(
input_node
,
input_data
):
m
.
set_input
(
e
,
tvm
.
nd
.
array
(
i
nput_data
[
i
]
.
astype
(
input_data
[
i
]
.
dtype
)
))
m
.
set_input
(
e
,
tvm
.
nd
.
array
(
i
))
m
.
set_input
(
**
params
)
m
.
set_input
(
**
params
)
# execute
# execute
...
@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
...
@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
# get outputs
# get outputs
assert
out_names
is
None
or
num_output
==
len
(
out_names
),
(
assert
out_names
is
None
or
num_output
==
len
(
out_names
),
(
"out_names: {} num_output: {}"
.
format
(
out_names
,
num_output
))
"out_names: {} num_output: {}"
.
format
(
out_names
,
num_output
))
tvm_output_list
=
[]
tvm_output_list
=
[
m
.
get_output
(
i
)
.
asnumpy
()
for
i
in
range
(
num_output
)]
for
i
in
range
(
0
,
num_output
):
tvm_output
=
m
.
get_output
(
i
)
tvm_output_list
.
append
(
tvm_output
.
asnumpy
())
return
tvm_output_list
return
tvm_output_list
def
run_tf_graph
(
sess
,
input_data
,
input_node
,
output_node
):
def
run_tf_graph
(
sess
,
input_data
,
input_node
,
output_node
):
...
@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
...
@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
input_node
=
convert_to_list
(
input_node
)
input_node
=
convert_to_list
(
input_node
)
output_node
=
convert_to_list
(
output_node
)
output_node
=
convert_to_list
(
output_node
)
tensor
=
[
0
]
*
len
(
output_node
)
tensor
=
[
sess
.
graph
.
get_tensor_by_name
(
output_name
)
for
output_name
in
output_node
]
for
i
in
range
(
len
(
output_node
)):
tensor
[
i
]
=
sess
.
graph
.
get_tensor_by_name
(
output_node
[
i
])
input_dict
=
{}
input_dict
=
{
e
:
input_data
[
i
]
for
i
,
e
in
enumerate
(
input_node
)}
for
i
,
e
in
enumerate
(
input_node
):
input_dict
[
e
]
=
input_data
[
i
]
output_data
=
sess
.
run
(
tensor
,
input_dict
)
output_data
=
sess
.
run
(
tensor
,
input_dict
)
return
output_data
return
output_data
...
@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
...
@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
,
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
,
no_gpu
=
False
,
opt_level
=
3
):
no_gpu
=
False
,
opt_level
=
3
):
"""Generic function to generate and compare tensorflow and TVM output"""
"""Generic function to generate and compare tensorflow and TVM output"""
def
name_without_num
(
name
):
return
name
.
split
(
':'
)[
0
]
if
":"
in
name
else
name
out_name
=
convert_to_list
(
out_name
)
out_name
=
convert_to_list
(
out_name
)
out_node
=
[
0
]
*
len
(
out_name
)
out_node
=
[
name_without_num
(
name
)
for
name
in
out_name
]
for
i
in
range
(
len
(
out_name
)):
out_node
[
i
]
=
out_name
[
i
]
.
split
(
':'
)[
0
]
if
":"
in
out_name
[
i
]
else
out_name
[
i
]
in_data
=
convert_to_list
(
in_data
)
in_data
=
convert_to_list
(
in_data
)
in_name
=
convert_to_list
(
in_name
)
in_name
=
convert_to_list
(
in_name
)
in_node
=
[
0
]
*
len
(
in_name
)
in_node
=
[
name_without_num
(
name
)
for
name
in
in_name
]
for
i
in
range
(
len
(
in_name
)):
in_node
[
i
]
=
in_name
[
i
]
.
split
(
':'
)[
0
]
if
":"
in
in_name
[
i
]
else
in_name
[
i
]
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
if
init_global_variables
:
if
init_global_variables
:
sess
.
run
(
variables
.
global_variables_initializer
())
sess
.
run
(
variables
.
global_variables_initializer
())
...
@@ -578,6 +561,38 @@ def test_forward_variable():
...
@@ -578,6 +561,38 @@ def test_forward_variable():
#######################################################################
#######################################################################
# MatMul
# ------
def
_test_matmul
(
i
,
j
,
k
,
dtype
,
outer
=
None
):
""" One iteration of matmul """
A_shape_init
=
[
i
,
j
]
B_shape_init
=
[
j
,
k
]
for
transpose_a
in
[
False
,
True
]:
for
transpose_b
in
[
False
,
True
]:
outer
=
outer
or
[]
A_shape
=
outer
+
(
A_shape_init
[::
-
1
]
if
transpose_a
else
A_shape_init
)
B_shape
=
outer
+
(
B_shape_init
[::
-
1
]
if
transpose_b
else
B_shape_init
)
with
tf
.
Graph
()
.
as_default
():
A
=
tf
.
placeholder
(
shape
=
A_shape
,
dtype
=
dtype
,
name
=
'A'
)
B
=
tf
.
placeholder
(
shape
=
B_shape
,
dtype
=
dtype
,
name
=
'B'
)
result
=
tf
.
matmul
(
A
,
B
,
transpose_a
=
transpose_a
,
transpose_b
=
transpose_b
)
A_np
=
np
.
random
.
uniform
(
high
=
5.0
,
size
=
A_shape
)
.
astype
(
dtype
)
B_np
=
np
.
random
.
uniform
(
high
=
5.0
,
size
=
B_shape
)
.
astype
(
dtype
)
compare_tf_with_tvm
([
A_np
,
B_np
],
[
A
.
name
,
B
.
name
],
result
.
name
)
def
test_forward_matmul
():
""" Matmul op test"""
_test_matmul
(
1
,
3
,
6
,
'int32'
)
_test_matmul
(
5
,
3
,
1
,
'float64'
)
# TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
#######################################################################
# StridedSlice
# StridedSlice
# ------------
# ------------
...
@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
...
@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
test_forward_rel_ops
()
test_forward_rel_ops
()
test_forward_logical
()
test_forward_logical
()
test_where
()
test_where
()
test_forward_matmul
()
# TODO missing tests: rank, range
\ No newline at end of file
topi/python/topi/util.py
View file @
770ac84e
...
@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
...
@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int
out_tuple : tuple of int
The output.
The output.
"""
"""
out_tuple
=
()
return
tuple
(
get_const_int
(
elem
)
for
elem
in
in_tuple
)
for
elem
in
in_tuple
:
value
=
get_const_int
(
elem
)
out_tuple
=
out_tuple
+
(
value
,
)
return
out_tuple
def
get_float_tuple
(
in_tuple
):
def
get_float_tuple
(
in_tuple
):
...
@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
...
@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
out_tuple : tuple of float
out_tuple : tuple of float
The output.
The output.
"""
"""
out_tuple
=
()
return
tuple
(
get_const_float
(
elem
)
for
elem
in
in_tuple
)
for
elem
in
in_tuple
:
value
=
get_const_float
(
elem
)
out_tuple
=
out_tuple
+
(
value
,
)
return
out_tuple
def
simplify
(
expr
):
def
simplify
(
expr
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment