Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
be8fa6ac
Commit
be8fa6ac
authored
Aug 06, 2019
by
Zhi
Committed by
Thierry Moreau
Aug 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][frontend] clean up tf frontend (#3710)
* clean up tf frontend * fix get_relay_op
parent
8d5de5ed
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
197 deletions
+47
-197
python/tvm/relay/frontend/common.py
+27
-3
python/tvm/relay/frontend/mxnet.py
+2
-8
python/tvm/relay/frontend/tensorflow.py
+18
-186
No files found.
python/tvm/relay/frontend/common.py
View file @
be8fa6ac
...
...
@@ -17,6 +17,8 @@
"""Common utilities"""
from
__future__
import
absolute_import
as
_abs
import
logging
import
tvm
from
topi.util
import
get_const_tuple
from
..
import
expr
as
_expr
from
..
import
module
as
_module
...
...
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
return
default
def
get_relay_op
(
op_name
):
"""Get the callable function from Relay based on operator name.
Parameters
...
...
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
if
op
is
not
None
:
break
if
not
op
:
raise
RuntimeError
(
"Unable to map op_name {} to relay"
.
format
(
op_name
))
raise
tvm
.
error
.
OpNotImplemented
(
"Unable to map op_name {} to relay"
.
format
(
op_name
))
return
op
class
ExprTable
(
object
):
"""Table storing Relay expressions by names."""
def
__init__
(
self
):
...
...
@@ -298,21 +302,27 @@ class AttrCvt(object):
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
disables : list
A list of attributes that is disabled in relay. Log warnings.
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
...
...
@@ -329,6 +339,14 @@ class AttrCvt(object):
self
.
_custom_check
=
custom_check
def
__call__
(
self
,
inputs
,
attrs
,
*
args
):
self
.
_ignores
.
append
(
'_output_shapes'
)
self
.
_ignores
.
append
(
'_input_shapes'
)
self
.
_ignores
.
append
(
'T'
)
self
.
_ignores
.
append
(
'use_cudnn_on_gpu'
)
self
.
_ignores
.
append
(
'_node_name'
)
self
.
_ignores
.
append
(
'is_training'
)
self
.
_ignores
.
append
(
'_target_layout'
)
# apply custom check
if
self
.
_custom_check
:
func
,
msg
=
self
.
_custom_check
...
...
@@ -348,7 +366,8 @@ class AttrCvt(object):
new_attrs
=
{}
for
k
in
attrs
.
keys
():
if
k
in
self
.
_excludes
:
raise
NotImplementedError
(
"Attribute {} not supported yet."
.
format
(
k
))
raise
NotImplementedError
(
'Attribute
%
s in operator
%
s is not'
+
' supported.'
,
k
,
op_name
)
elif
k
in
self
.
_disables
:
logging
.
warning
(
"Attribute
%
s is disabled in relay.sym.
%
s"
,
k
,
op_name
)
elif
k
in
self
.
_ignores
:
...
...
@@ -401,6 +420,7 @@ class AttrCvt(object):
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
return
attr
[
key
]
def
get_name
(
node
):
name
=
''
if
hasattr
(
node
,
"name_hint"
):
...
...
@@ -410,17 +430,19 @@ def get_name(node):
def
infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
node
if
isinstance
(
node
,
_module
.
Module
)
else
_module
.
Module
.
from_expr
(
node
)
mod
=
_transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
infer_shape
(
inputs
):
"""A method to get the output shape of an intermediate node in the graph."""
out_type
=
infer_type
(
inputs
)
out_shapes
=
get_const_tuple
(
out_type
.
checked_type
.
shape
)
return
out_shapes
def
infer_channels
(
inputs
,
transpose
=
False
):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
...
...
@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
channels
=
out_shapes
[
0
][
0
]
if
not
transpose
else
out_shapes
[
0
][
1
]
return
channels
def
new_var
(
name_hint
,
type_annotation
=
None
,
shape
=
None
,
dtype
=
"float32"
):
return
_expr
.
var
(
name_hint
,
type_annotation
,
shape
,
dtype
)
class
Renamer
(
object
):
"""A simply renamer for operators.
...
...
python/tvm/relay/frontend/mxnet.py
View file @
be8fa6ac
...
...
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
import
json
import
tvm
from
..
import
analysis
,
transform
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..
import
module
as
_module
from
...
import
nd
as
_nd
from
.common
import
StrAttrsDict
from
.common
import
infer_type
as
_infer_type
from
.nnvm_common
import
_rename
,
_binop_scalar
,
_rbinop_scalar
,
_reduce
from
.nnvm_common
import
_arg_reduce
,
_init_op
,
_softmax_op
,
_cast
from
.nnvm_common
import
_clip
,
_transpose
,
_upsampling
...
...
@@ -41,13 +42,6 @@ _activation_map = {
"relu"
:
_op
.
nn
.
relu
}
def
_infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
_mx_fully_connected
(
inputs
,
attrs
):
import
mxnet
as
mx
units
=
attrs
.
get_int
(
"num_hidden"
)
...
...
python/tvm/relay/frontend/tensorflow.py
View file @
be8fa6ac
...
...
@@ -19,20 +19,21 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
import
logging
import
warnings
from
collections
import
defaultdict
# Numpy support
import
numpy
as
np
import
tvm
from
topi.util
import
get_const_tuple
from
..
import
analysis
from
..
import
transform
as
_transform
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..expr_functor
import
ExprMutator
from
..
import
module
as
_module
from
.common
import
AttrCvt
,
get_relay_op
from
.common
import
infer_type
as
_infer_type
from
.common
import
infer_shape
as
_infer_shape
from
.common
import
infer_channels
as
_infer_channels
__all__
=
[
'from_tensorflow'
]
...
...
@@ -50,140 +51,6 @@ def _infer_value(input_val, params):
m
.
run
()
return
m
.
get_output
(
0
)
def
_get_relay_op
(
op_name
):
ops
=
[
_op
,
_op
.
nn
,
_op
.
image
,
_op
.
vision
]
for
operator
in
ops
:
try
:
op
=
getattr
(
operator
,
op_name
)
return
op
except
AttributeError
:
continue
raise
tvm
.
error
.
OpNotImplemented
(
'Operator {} is not supported for frontend TensorFlow.'
.
format
(
op_name
))
class
AttrCvt
(
object
):
"""Common attribute converter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
disables : list
A list of attributes that is disabled in relay. Log warnings.
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
"""
def
__init__
(
self
,
op_name
,
transforms
=
None
,
excludes
=
None
,
disables
=
None
,
ignores
=
None
,
extras
=
None
,
custom_check
=
None
):
self
.
_op_name
=
op_name
self
.
_transforms
=
transforms
if
transforms
else
{}
self
.
_excludes
=
excludes
if
excludes
else
[]
self
.
_disables
=
disables
if
disables
else
[]
self
.
_ignores
=
ignores
if
ignores
else
[]
self
.
_extras
=
extras
if
extras
else
{}
self
.
_custom_check
=
custom_check
def
__call__
(
self
,
inputs
,
attrs
,
*
args
):
self
.
_ignores
.
append
(
'_output_shapes'
)
self
.
_ignores
.
append
(
'_input_shapes'
)
self
.
_ignores
.
append
(
'T'
)
self
.
_ignores
.
append
(
'use_cudnn_on_gpu'
)
self
.
_ignores
.
append
(
'_node_name'
)
self
.
_ignores
.
append
(
'is_training'
)
self
.
_ignores
.
append
(
'_target_layout'
)
# apply custom check
if
self
.
_custom_check
:
func
,
msg
=
self
.
_custom_check
if
not
func
(
attrs
):
raise
RuntimeError
(
"Check failed: {}"
.
format
(
msg
))
# get new op_name
if
isinstance
(
self
.
_op_name
,
str
):
op_name
=
self
.
_op_name
else
:
assert
callable
(
self
.
_op_name
),
"op_name can either be string or callable"
op_name
=
self
.
_op_name
(
attrs
)
# convert attributes
new_attrs
=
{}
for
k
in
attrs
.
keys
():
if
k
in
self
.
_excludes
:
raise
tvm
.
error
.
OpAttributeUnImplemented
(
'Attribute {} in operator {} is not supported.'
.
format
(
k
,
op_name
))
elif
k
in
self
.
_disables
:
logging
.
warning
(
"Attribute
%
s is disabled in relay.
%
s"
,
k
,
op_name
)
elif
k
in
self
.
_ignores
:
logging
.
debug
(
"Attribute
%
s is ignored in relay.
%
s"
,
k
,
op_name
)
elif
k
in
self
.
_transforms
:
new_name
,
defaults
,
transform
=
self
.
_parse_default
(
self
.
_transforms
[
k
])
if
defaults
is
None
:
new_attr
=
self
.
_required_attr
(
attrs
,
k
)
else
:
new_attr
=
attrs
.
get
(
k
,
None
)
if
new_attr
is
None
:
new_attrs
[
new_name
]
=
defaults
else
:
new_attrs
[
new_name
]
=
transform
(
new_attr
)
else
:
# copy
new_attrs
[
k
]
=
attrs
[
k
]
# add extras
new_attrs
.
update
(
self
.
_extras
)
return
_get_relay_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_parse_default
(
self
,
target
):
"""Helper function to parse default values."""
if
not
isinstance
(
target
,
(
list
,
tuple
)):
k
,
v
,
t
=
target
,
None
,
lambda
x
:
x
elif
len
(
target
)
==
1
:
k
,
v
,
t
=
target
[
0
],
None
,
lambda
x
:
x
elif
len
(
target
)
==
2
:
k
,
v
,
t
=
target
[
0
],
target
[
1
],
lambda
x
:
x
elif
len
(
target
)
>
2
:
k
,
v
,
t
=
target
[
0
],
target
[
1
],
target
[
2
]
else
:
k
=
None
# should raise
if
not
isinstance
(
k
,
str
):
msg
=
"{} is not a valid target, (name, default) expected."
.
format
(
target
)
raise
ValueError
(
msg
)
return
k
,
v
,
t
def
_parse_bool
(
self
,
value
):
"""Helper function to parse default boolean values."""
if
isinstance
(
value
,
str
):
return
value
.
strip
()
.
lower
()
in
[
'true'
,
'1'
,
't'
,
'y'
,
'yes'
]
return
bool
(
value
)
def
_required_attr
(
self
,
attr
,
key
):
"""Wrapper for getting required attributes."""
assert
isinstance
(
attr
,
dict
)
if
key
not
in
attr
:
raise
tvm
.
error
.
OpAttributeRequired
(
'Attribute {} not found in operator {}'
.
format
(
key
,
self
.
_op_name
))
return
attr
[
key
]
def
_get_pad_pair
(
input1d
,
kernel1d
,
stride1d
):
if
input1d
%
stride1d
==
0
:
pad
=
max
(
kernel1d
-
stride1d
,
0
)
...
...
@@ -195,12 +62,6 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
return
[
pad_before
,
pad_after
]
def
_get_name_hint
(
node
):
name
=
''
if
hasattr
(
node
,
"name_hint"
):
name
=
node
.
name_hint
return
name
def
_math_name_picker
(
surfix
):
def
_impl
(
attr
):
return
'broadcast_'
+
surfix
...
...
@@ -222,30 +83,6 @@ def _dimension_constraint():
return
False
return
_dim_check
,
"Only 2d kernel supported."
def
_infer_channels
(
node
,
params
,
transpose
=
False
):
"""A hack for getting 'channels' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
"""
out_shape
=
_infer_shape
(
node
,
params
)
channels
=
out_shape
[
0
]
if
not
transpose
else
out_shape
[
1
]
return
channels
def
_infer_out_shapes
(
inputs
,
params
):
"""A method to get the output shape of intermediate nodes in the relay graph."""
return
[
_infer_shape
(
inputs
,
params
)]
def
_infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
_transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
_infer_shape
(
node
,
params
=
None
):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type
=
_infer_type
(
node
)
return
get_const_tuple
(
out_type
.
checked_type
.
shape
)
def
_get_param
(
params
,
input_node
):
return
params
.
pop
(
input_node
.
name_hint
)
.
asnumpy
()
...
...
@@ -280,7 +117,7 @@ def _argx(func, func_name):
def
_elemwise
(
name
):
def
_impl
(
inputs
,
attr
,
params
):
assert
len
(
inputs
)
==
2
,
"{} take 2 inputs, {} given"
.
format
(
name
,
len
(
inputs
))
return
_
get_relay_op
(
name
)(
*
inputs
)
return
get_relay_op
(
name
)(
*
inputs
)
return
_impl
def
_pooling
(
name
):
...
...
@@ -300,7 +137,7 @@ def _pooling(name):
else
:
msg
=
'Value {} of attribute "data_format" of operator Pooling '
\
'is not valid.'
raise
tvm
.
error
.
OpAttributeInvalid
(
msg
.
format
(
attr
s
[
'data_format'
]))
raise
tvm
.
error
.
OpAttributeInvalid
(
msg
.
format
(
attr
[
'data_format'
]))
if
attr
[
'_target_layout'
]
==
"NCHW"
and
attr
[
'data_format'
]
==
"NHWC"
:
tmp_shape
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
...
...
@@ -539,7 +376,7 @@ def _crop_and_resize():
res_crop
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
size
)
# 2) Resize
res_resize
=
_
get_relay_op
(
'resize'
)(
res_crop
,
**
attrs
)
res_resize
=
get_relay_op
(
'resize'
)(
res_crop
,
**
attrs
)
out
=
_op
.
concatenate
([
out
,
res_resize
],
axis
=
0
)
if
out
else
res_resize
return
out
return
_impl
...
...
@@ -598,7 +435,7 @@ def _check_numerics():
def
_matmul
():
def
_impl
(
inputs
,
attr
,
params
):
channels
=
_infer_channels
(
inputs
[
1
],
params
,
not
attr
[
'transpose_b'
])
channels
=
_infer_channels
(
inputs
[
1
],
not
attr
[
'transpose_b'
])
if
attr
[
'transpose_a'
]:
inputs
[
0
]
=
_op
.
transpose
(
inputs
[
0
],
axes
=
(
1
,
0
))
if
not
attr
[
'transpose_b'
]:
...
...
@@ -615,15 +452,10 @@ def _batch_matmul():
adj_y
=
attr
[
'adj_y'
]
input_x
=
_op
.
transpose
(
inputs
[
0
],
axes
=
[
0
,
2
,
1
])
if
adj_x
else
inputs
[
0
]
input_y
=
_op
.
transpose
(
inputs
[
1
],
axes
=
[
0
,
2
,
1
])
if
not
adj_y
else
inputs
[
1
]
ret
=
_
get_relay_op
(
'batch_matmul'
)(
input_x
,
input_y
)
ret
=
get_relay_op
(
'batch_matmul'
)(
input_x
,
input_y
)
return
ret
return
_impl
def
_undef
():
def
_impl
(
inputs
,
attr
,
params
):
return
_sym
.
__undef__
()
return
_impl
def
_identity
():
def
_impl
(
inputs
,
attr
,
params
):
return
inputs
[
0
]
...
...
@@ -985,7 +817,7 @@ def _stridedSlice():
if
begin_mask
or
end_mask
or
ellipsis_mask
or
new_axis_mask
or
shrink_axis_mask
:
begin
,
end
,
stride
,
fshape_indices
=
_transform_mask
(
stride_dim
,
ellipsis_mask
)
out
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
end
,
strides
=
stride
)
out_shape
=
_infer_shape
(
out
,
params
)
out_shape
=
_infer_shape
(
out
)
if
not
fshape_indices
:
fshape_indices
=
range
(
len
(
out_shape
))
...
...
@@ -1178,8 +1010,8 @@ def _softplus():
exp_out
=
AttrCvt
(
'exp'
)(
inputs
,
attr
)
inputs
.
append
(
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
))
rh
=
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
add_out
=
_
get_relay_op
(
'add'
)(
exp_out
,
rh
)
return
_
get_relay_op
(
'log'
)(
add_out
)
add_out
=
get_relay_op
(
'add'
)(
exp_out
,
rh
)
return
get_relay_op
(
'log'
)(
add_out
)
return
_impl
def
_topk
():
...
...
@@ -1200,7 +1032,7 @@ def _floordiv():
def
_impl
(
inputs
,
attr
,
params
):
assert
len
(
inputs
)
==
2
div
=
AttrCvt
(
'divide'
)(
inputs
,
attr
)
return
_
get_relay_op
(
'floor'
)(
div
)
return
get_relay_op
(
'floor'
)(
div
)
return
_impl
def
_logical
(
name
):
...
...
@@ -1234,7 +1066,7 @@ def _space_to_batch_nd():
axes
=
[
2
*
i
+
2
for
i
in
range
(
M
)]
+
[
0
]
+
[
2
*
i
+
1
for
i
in
range
(
M
)]
+
\
list
(
range
(
1
+
2
*
M
,
1
+
2
*
M
+
remaining_shape_length
))
permuted_reshaped_padded
=
tvm
.
relay
.
transpose
(
reshaped_padded
,
axes
=
axes
)
permuted_reshaped_padded_shape
=
_infer_shape
(
permuted_reshaped_padded
,
params
)
permuted_reshaped_padded_shape
=
_infer_shape
(
permuted_reshaped_padded
)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
...
...
@@ -1277,7 +1109,7 @@ def _batch_to_space_nd():
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape
=
_infer_shape
(
reshaped_permuted
,
params
)
reshaped_permuted_shape
=
_infer_shape
(
reshaped_permuted
)
cropped
=
reshaped_permuted
for
axis
in
range
(
1
,
M
+
1
):
crop
=
crops
[
axis
-
1
]
...
...
@@ -1305,8 +1137,8 @@ def _log1p():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
def
_impl
(
inputs
,
attr
,
params
):
one
=
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
add_out
=
_
get_relay_op
(
'add'
)(
inputs
[
0
],
one
)
return
_
get_relay_op
(
'log'
)(
add_out
)
add_out
=
get_relay_op
(
'add'
)(
inputs
[
0
],
one
)
return
get_relay_op
(
'log'
)(
add_out
)
return
_impl
# compatible operators that do NOT require any conversion.
...
...
@@ -2399,7 +2231,7 @@ class GraphProto(object):
convert_map
=
convert_map
if
convert_map
else
_convert_map
convert_map_rnn
=
_convert_map_rnn
if
op_name
in
identity_list
:
sym
=
_
get_relay_op
(
op_name
)(
*
inputs
,
**
attrs
)
sym
=
get_relay_op
(
op_name
)(
*
inputs
,
**
attrs
)
elif
op_name
in
convert_map
:
sym
=
convert_map
[
op_name
](
inputs
,
attrs
,
self
.
_params
)
elif
op_name
in
convert_map_rnn
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment