Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
be8fa6ac
Commit
be8fa6ac
authored
Aug 06, 2019
by
Zhi
Committed by
Thierry Moreau
Aug 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][frontend] clean up tf frontend (#3710)
* clean up tf frontend * fix get_relay_op
parent
8d5de5ed
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
11 deletions
+29
-11
python/tvm/relay/frontend/common.py
+27
-3
python/tvm/relay/frontend/mxnet.py
+2
-8
python/tvm/relay/frontend/tensorflow.py
+0
-0
No files found.
python/tvm/relay/frontend/common.py
View file @
be8fa6ac
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
"""Common utilities"""
"""Common utilities"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
logging
import
logging
import
tvm
from
topi.util
import
get_const_tuple
from
topi.util
import
get_const_tuple
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
module
as
_module
...
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
...
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
return
default
return
default
def
get_relay_op
(
op_name
):
def
get_relay_op
(
op_name
):
"""Get the callable function from Relay based on operator name.
"""Get the callable function from Relay based on operator name.
Parameters
Parameters
...
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
...
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
if
op
is
not
None
:
if
op
is
not
None
:
break
break
if
not
op
:
if
not
op
:
raise
RuntimeError
(
"Unable to map op_name {} to relay"
.
format
(
op_name
))
raise
tvm
.
error
.
OpNotImplemented
(
"Unable to map op_name {} to relay"
.
format
(
op_name
))
return
op
return
op
class
ExprTable
(
object
):
class
ExprTable
(
object
):
"""Table storing Relay expressions by names."""
"""Table storing Relay expressions by names."""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -298,21 +302,27 @@ class AttrCvt(object):
...
@@ -298,21 +302,27 @@ class AttrCvt(object):
If set as str, returned operator name is the str.
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
If transform function is provided, the original attribute value is handled
by transform function.
by transform function.
excludes : list
excludes : list
A list of excluded attributes that should `NOT` appear.
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
Raise NotImplementedError if occurred.
disables : list
disables : list
A list of attributes that is disabled in relay. Log warnings.
A list of attributes that is disabled in relay. Log warnings.
ignores : list
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
extras : dict
A series of additional attributes should be added anyway to the returned
A series of additional attributes should be added anyway to the returned
attribute dict.
attribute dict.
custom_check : callable
custom_check : callable
A custom function takes attribute, and return True/False.
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
Raise RuntimeError if not bool(True) returned.
...
@@ -329,6 +339,14 @@ class AttrCvt(object):
...
@@ -329,6 +339,14 @@ class AttrCvt(object):
self
.
_custom_check
=
custom_check
self
.
_custom_check
=
custom_check
def
__call__
(
self
,
inputs
,
attrs
,
*
args
):
def
__call__
(
self
,
inputs
,
attrs
,
*
args
):
self
.
_ignores
.
append
(
'_output_shapes'
)
self
.
_ignores
.
append
(
'_input_shapes'
)
self
.
_ignores
.
append
(
'T'
)
self
.
_ignores
.
append
(
'use_cudnn_on_gpu'
)
self
.
_ignores
.
append
(
'_node_name'
)
self
.
_ignores
.
append
(
'is_training'
)
self
.
_ignores
.
append
(
'_target_layout'
)
# apply custom check
# apply custom check
if
self
.
_custom_check
:
if
self
.
_custom_check
:
func
,
msg
=
self
.
_custom_check
func
,
msg
=
self
.
_custom_check
...
@@ -348,7 +366,8 @@ class AttrCvt(object):
...
@@ -348,7 +366,8 @@ class AttrCvt(object):
new_attrs
=
{}
new_attrs
=
{}
for
k
in
attrs
.
keys
():
for
k
in
attrs
.
keys
():
if
k
in
self
.
_excludes
:
if
k
in
self
.
_excludes
:
raise
NotImplementedError
(
"Attribute {} not supported yet."
.
format
(
k
))
raise
NotImplementedError
(
'Attribute
%
s in operator
%
s is not'
+
' supported.'
,
k
,
op_name
)
elif
k
in
self
.
_disables
:
elif
k
in
self
.
_disables
:
logging
.
warning
(
"Attribute
%
s is disabled in relay.sym.
%
s"
,
k
,
op_name
)
logging
.
warning
(
"Attribute
%
s is disabled in relay.sym.
%
s"
,
k
,
op_name
)
elif
k
in
self
.
_ignores
:
elif
k
in
self
.
_ignores
:
...
@@ -401,6 +420,7 @@ class AttrCvt(object):
...
@@ -401,6 +420,7 @@ class AttrCvt(object):
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
return
attr
[
key
]
return
attr
[
key
]
def
get_name
(
node
):
def
get_name
(
node
):
name
=
''
name
=
''
if
hasattr
(
node
,
"name_hint"
):
if
hasattr
(
node
,
"name_hint"
):
...
@@ -410,17 +430,19 @@ def get_name(node):
...
@@ -410,17 +430,19 @@ def get_name(node):
def
infer_type
(
node
):
def
infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
node
if
isinstance
(
node
,
_module
.
Module
)
else
_module
.
Module
.
from_expr
(
node
)
mod
=
_transform
.
InferType
()(
mod
)
mod
=
_transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
infer_shape
(
inputs
):
def
infer_shape
(
inputs
):
"""A method to get the output shape of an intermediate node in the graph."""
"""A method to get the output shape of an intermediate node in the graph."""
out_type
=
infer_type
(
inputs
)
out_type
=
infer_type
(
inputs
)
out_shapes
=
get_const_tuple
(
out_type
.
checked_type
.
shape
)
out_shapes
=
get_const_tuple
(
out_type
.
checked_type
.
shape
)
return
out_shapes
return
out_shapes
def
infer_channels
(
inputs
,
transpose
=
False
):
def
infer_channels
(
inputs
,
transpose
=
False
):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
these attributes. We check the shape of weights provided to get the number.
...
@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
...
@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
channels
=
out_shapes
[
0
][
0
]
if
not
transpose
else
out_shapes
[
0
][
1
]
channels
=
out_shapes
[
0
][
0
]
if
not
transpose
else
out_shapes
[
0
][
1
]
return
channels
return
channels
def
new_var
(
name_hint
,
def
new_var
(
name_hint
,
type_annotation
=
None
,
type_annotation
=
None
,
shape
=
None
,
shape
=
None
,
dtype
=
"float32"
):
dtype
=
"float32"
):
return
_expr
.
var
(
name_hint
,
type_annotation
,
shape
,
dtype
)
return
_expr
.
var
(
name_hint
,
type_annotation
,
shape
,
dtype
)
class
Renamer
(
object
):
class
Renamer
(
object
):
"""A simply renamer for operators.
"""A simply renamer for operators.
...
...
python/tvm/relay/frontend/mxnet.py
View file @
be8fa6ac
...
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
...
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
import
json
import
json
import
tvm
import
tvm
from
..
import
analysis
,
transform
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..
import
op
as
_op
from
..
import
module
as
_module
from
..
import
module
as
_module
from
...
import
nd
as
_nd
from
...
import
nd
as
_nd
from
.common
import
StrAttrsDict
from
.common
import
StrAttrsDict
from
.common
import
infer_type
as
_infer_type
from
.nnvm_common
import
_rename
,
_binop_scalar
,
_rbinop_scalar
,
_reduce
from
.nnvm_common
import
_rename
,
_binop_scalar
,
_rbinop_scalar
,
_reduce
from
.nnvm_common
import
_arg_reduce
,
_init_op
,
_softmax_op
,
_cast
from
.nnvm_common
import
_arg_reduce
,
_init_op
,
_softmax_op
,
_cast
from
.nnvm_common
import
_clip
,
_transpose
,
_upsampling
from
.nnvm_common
import
_clip
,
_transpose
,
_upsampling
...
@@ -41,13 +42,6 @@ _activation_map = {
...
@@ -41,13 +42,6 @@ _activation_map = {
"relu"
:
_op
.
nn
.
relu
"relu"
:
_op
.
nn
.
relu
}
}
def
_infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
_mx_fully_connected
(
inputs
,
attrs
):
def
_mx_fully_connected
(
inputs
,
attrs
):
import
mxnet
as
mx
import
mxnet
as
mx
units
=
attrs
.
get_int
(
"num_hidden"
)
units
=
attrs
.
get_int
(
"num_hidden"
)
...
...
python/tvm/relay/frontend/tensorflow.py
View file @
be8fa6ac
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment