Commit be8fa6ac authored Aug 06, 2019 by Zhi, committed by Thierry Moreau on Aug 06, 2019
[relay][frontend] clean up tf frontend (#3710)
* clean up tf frontend
* fix get_relay_op
parent 8d5de5ed
Showing 3 changed files with 29 additions and 11 deletions.
python/tvm/relay/frontend/common.py        +27  -3
python/tvm/relay/frontend/mxnet.py          +2  -8
python/tvm/relay/frontend/tensorflow.py     +0  -0
python/tvm/relay/frontend/common.py
@@ -17,6 +17,8 @@
 """Common utilities"""
 from __future__ import absolute_import as _abs
 import logging
 import tvm
 from topi.util import get_const_tuple
 from .. import expr as _expr
 from .. import module as _module
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default


 def get_relay_op(op_name):
     """Get the callable function from Relay based on operator name.

     Parameters
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
             if op is not None:
                 break
     if not op:
-        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
+        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
     return op


 class ExprTable(object):
     """Table storing Relay expressions by names."""

     def __init__(self):
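For orientation, a minimal usage sketch of get_relay_op after this change (illustrative only; assumes a TVM build from around this commit, with a made-up unregistered op name):

# Illustrative sketch, not part of the commit: get_relay_op looks the name up
# on the Relay operator namespaces and now raises tvm.error.OpNotImplemented
# instead of RuntimeError when no match is found.
import tvm
from tvm.relay.frontend.common import get_relay_op

relu_op = get_relay_op("relu")            # resolves to the Relay relu callable
try:
    get_relay_op("definitely_not_an_op")  # hypothetical name, never registered
except tvm.error.OpNotImplemented as err:
    print(err)                            # "Unable to map op_name ... to relay"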
@@ -298,21 +302,27 @@ class AttrCvt(object):
         If set as str, returned operator name is the str.
         If set as callable, returned operator is the str returned by calling:
         `op_name = func(attr)`
     transforms : dict of `new_name, or (new_name, default_value, transform function)`
         If only a new_name is provided, it's like renaming the attribute name.
         If default_value if provided, then the attribute is considered as optional.
         If transform function is provided, the original attribute value is handled
         by transform function.
     excludes : list
         A list of excluded attributes that should `NOT` appear.
         Raise NotImplementedError if occurred.
     disables : list
         A list of attributes that is disabled in relay. Log warnings.
     ignores : list
         A list of attributes that is ignored in relay. Debug level logging.
     extras : dict
         A series of additional attributes should be added anyway to the returned
         attribute dict.
     custom_check : callable
         A custom function takes attribute, and return True/False.
         Raise RuntimeError if not bool(True) returned.
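To illustrate the parameters documented above, a hedged sketch of building an AttrCvt converter; the attribute names below are hypothetical examples, not taken from this commit:

# Illustrative sketch only: AttrCvt maps a framework attribute dict onto a
# Relay operator call. All attribute names here are made-up examples.
from tvm.relay.frontend.common import AttrCvt

convert = AttrCvt(
    op_name="conv2d",
    transforms={
        "kernel_shape": "kernel_size",      # rename only
        "dilations": ("dilation", (1, 1)),  # rename with a default value
    },
    ignores=["data_format"],      # dropped, logged at debug level
    disables=["some_attr"],       # dropped, logged as a warning
    excludes=["bad_attr"])        # raises NotImplementedError if present
# call_expr = convert(inputs, attrs)   # returns the converted Relay call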
@@ -329,6 +339,14 @@ class AttrCvt(object):
         self._custom_check = custom_check

     def __call__(self, inputs, attrs, *args):
+        self._ignores.append('_output_shapes')
+        self._ignores.append('_input_shapes')
+        self._ignores.append('T')
+        self._ignores.append('use_cudnn_on_gpu')
+        self._ignores.append('_node_name')
+        self._ignores.append('is_training')
+        self._ignores.append('_target_layout')
         # apply custom check
         if self._custom_check:
             func, msg = self._custom_check
@@ -348,7 +366,8 @@ class AttrCvt(object):
         new_attrs = {}
         for k in attrs.keys():
             if k in self._excludes:
-                raise NotImplementedError("Attribute {} not supported yet.".format(k))
+                raise NotImplementedError('Attribute %s in operator %s is not' +
+                                          ' supported.', k, op_name)
             elif k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
@@ -401,6 +420,7 @@ class AttrCvt(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return attr[key]


 def get_name(node):
     name = ''
     if hasattr(node, "name_hint"):
@@ -410,17 +430,19 @@ def get_name(node):
 def infer_type(node):
     """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
+    mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
     mod = _transform.InferType()(mod)
     entry = mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body


 def infer_shape(inputs):
     """A method to get the output shape of an intermediate node in the graph."""
     out_type = infer_type(inputs)
     out_shapes = get_const_tuple(out_type.checked_type.shape)
     return out_shapes


 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
     these attributes. We check the shape of weights provided to get the number.
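A small usage sketch of the helpers shown above; after this change infer_type also accepts an already-built relay.Module, not just an expression (sketch assumes the Relay API of this period):

# Hedged sketch, not part of the commit: type and shape inference on a tiny graph.
from tvm import relay
from tvm.relay.frontend.common import infer_type, infer_shape

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
y = relay.nn.relu(x)
typed = infer_type(y)        # a relay.Module argument is now accepted as well
print(typed.checked_type)    # Tensor[(1, 3, 224, 224), float32]
print(infer_shape(y))        # (1, 3, 224, 224)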
@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
     channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
     return channels


 def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"):
     return _expr.var(name_hint, type_annotation, shape, dtype)


 class Renamer(object):
     """A simply renamer for operators.
python/tvm/relay/frontend/mxnet.py
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
 import json
 import tvm
-from .. import analysis, transform
+from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
 from .. import module as _module
 from ... import nd as _nd
 from .common import StrAttrsDict
+from .common import infer_type as _infer_type
 from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
 from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
 from .nnvm_common import _clip, _transpose, _upsampling
@@ -41,13 +42,6 @@ _activation_map = {
     "relu": _op.nn.relu
 }

-def _infer_type(node):
-    """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
-
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
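The local _infer_type removed here is superseded by the shared helper imported at the top of the file, so the MXNet frontend now uses the same implementation as the other frontends. A one-line sketch of what takes its place:

# Illustrative only: mxnet.py now reuses the common helper instead of its own copy.
from tvm.relay.frontend.common import infer_type as _infer_type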
python/tvm/relay/frontend/tensorflow.py
This diff is collapsed.