Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
be8fa6ac
Commit
be8fa6ac
authored
Aug 06, 2019
by
Zhi
Committed by
Thierry Moreau
Aug 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][frontend] clean up tf frontend (#3710)
* clean up tf frontend * fix get_relay_op
parent
8d5de5ed
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
197 deletions
+47
-197
python/tvm/relay/frontend/common.py
+27
-3
python/tvm/relay/frontend/mxnet.py
+2
-8
python/tvm/relay/frontend/tensorflow.py
+18
-186
No files found.
python/tvm/relay/frontend/common.py
View file @
be8fa6ac
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
"""Common utilities"""
"""Common utilities"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
logging
import
logging
import
tvm
from
topi.util
import
get_const_tuple
from
topi.util
import
get_const_tuple
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
module
as
_module
...
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
...
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
return
default
return
default
def
get_relay_op
(
op_name
):
def
get_relay_op
(
op_name
):
"""Get the callable function from Relay based on operator name.
"""Get the callable function from Relay based on operator name.
Parameters
Parameters
...
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
...
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
if
op
is
not
None
:
if
op
is
not
None
:
break
break
if
not
op
:
if
not
op
:
raise
RuntimeError
(
"Unable to map op_name {} to relay"
.
format
(
op_name
))
raise
tvm
.
error
.
OpNotImplemented
(
"Unable to map op_name {} to relay"
.
format
(
op_name
))
return
op
return
op
class
ExprTable
(
object
):
class
ExprTable
(
object
):
"""Table storing Relay expressions by names."""
"""Table storing Relay expressions by names."""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -298,21 +302,27 @@ class AttrCvt(object):
...
@@ -298,21 +302,27 @@ class AttrCvt(object):
If set as str, returned operator name is the str.
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
If transform function is provided, the original attribute value is handled
by transform function.
by transform function.
excludes : list
excludes : list
A list of excluded attributes that should `NOT` appear.
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
Raise NotImplementedError if occurred.
disables : list
disables : list
A list of attributes that is disabled in relay. Log warnings.
A list of attributes that is disabled in relay. Log warnings.
ignores : list
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
extras : dict
A series of additional attributes should be added anyway to the returned
A series of additional attributes should be added anyway to the returned
attribute dict.
attribute dict.
custom_check : callable
custom_check : callable
A custom function takes attribute, and return True/False.
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
Raise RuntimeError if not bool(True) returned.
...
@@ -329,6 +339,14 @@ class AttrCvt(object):
...
@@ -329,6 +339,14 @@ class AttrCvt(object):
self
.
_custom_check
=
custom_check
self
.
_custom_check
=
custom_check
def
__call__
(
self
,
inputs
,
attrs
,
*
args
):
def
__call__
(
self
,
inputs
,
attrs
,
*
args
):
self
.
_ignores
.
append
(
'_output_shapes'
)
self
.
_ignores
.
append
(
'_input_shapes'
)
self
.
_ignores
.
append
(
'T'
)
self
.
_ignores
.
append
(
'use_cudnn_on_gpu'
)
self
.
_ignores
.
append
(
'_node_name'
)
self
.
_ignores
.
append
(
'is_training'
)
self
.
_ignores
.
append
(
'_target_layout'
)
# apply custom check
# apply custom check
if
self
.
_custom_check
:
if
self
.
_custom_check
:
func
,
msg
=
self
.
_custom_check
func
,
msg
=
self
.
_custom_check
...
@@ -348,7 +366,8 @@ class AttrCvt(object):
...
@@ -348,7 +366,8 @@ class AttrCvt(object):
new_attrs
=
{}
new_attrs
=
{}
for
k
in
attrs
.
keys
():
for
k
in
attrs
.
keys
():
if
k
in
self
.
_excludes
:
if
k
in
self
.
_excludes
:
raise
NotImplementedError
(
"Attribute {} not supported yet."
.
format
(
k
))
raise
NotImplementedError
(
'Attribute
%
s in operator
%
s is not'
+
' supported.'
,
k
,
op_name
)
elif
k
in
self
.
_disables
:
elif
k
in
self
.
_disables
:
logging
.
warning
(
"Attribute
%
s is disabled in relay.sym.
%
s"
,
k
,
op_name
)
logging
.
warning
(
"Attribute
%
s is disabled in relay.sym.
%
s"
,
k
,
op_name
)
elif
k
in
self
.
_ignores
:
elif
k
in
self
.
_ignores
:
...
@@ -401,6 +420,7 @@ class AttrCvt(object):
...
@@ -401,6 +420,7 @@ class AttrCvt(object):
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
raise
AttributeError
(
"Required attribute {} not found."
.
format
(
key
))
return
attr
[
key
]
return
attr
[
key
]
def
get_name
(
node
):
def
get_name
(
node
):
name
=
''
name
=
''
if
hasattr
(
node
,
"name_hint"
):
if
hasattr
(
node
,
"name_hint"
):
...
@@ -410,17 +430,19 @@ def get_name(node):
...
@@ -410,17 +430,19 @@ def get_name(node):
def
infer_type
(
node
):
def
infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
node
if
isinstance
(
node
,
_module
.
Module
)
else
_module
.
Module
.
from_expr
(
node
)
mod
=
_transform
.
InferType
()(
mod
)
mod
=
_transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
infer_shape
(
inputs
):
def
infer_shape
(
inputs
):
"""A method to get the output shape of an intermediate node in the graph."""
"""A method to get the output shape of an intermediate node in the graph."""
out_type
=
infer_type
(
inputs
)
out_type
=
infer_type
(
inputs
)
out_shapes
=
get_const_tuple
(
out_type
.
checked_type
.
shape
)
out_shapes
=
get_const_tuple
(
out_type
.
checked_type
.
shape
)
return
out_shapes
return
out_shapes
def
infer_channels
(
inputs
,
transpose
=
False
):
def
infer_channels
(
inputs
,
transpose
=
False
):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
these attributes. We check the shape of weights provided to get the number.
...
@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
...
@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
channels
=
out_shapes
[
0
][
0
]
if
not
transpose
else
out_shapes
[
0
][
1
]
channels
=
out_shapes
[
0
][
0
]
if
not
transpose
else
out_shapes
[
0
][
1
]
return
channels
return
channels
def
new_var
(
name_hint
,
def
new_var
(
name_hint
,
type_annotation
=
None
,
type_annotation
=
None
,
shape
=
None
,
shape
=
None
,
dtype
=
"float32"
):
dtype
=
"float32"
):
return
_expr
.
var
(
name_hint
,
type_annotation
,
shape
,
dtype
)
return
_expr
.
var
(
name_hint
,
type_annotation
,
shape
,
dtype
)
class
Renamer
(
object
):
class
Renamer
(
object
):
"""A simply renamer for operators.
"""A simply renamer for operators.
...
...
python/tvm/relay/frontend/mxnet.py
View file @
be8fa6ac
...
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
...
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
import
json
import
json
import
tvm
import
tvm
from
..
import
analysis
,
transform
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..
import
op
as
_op
from
..
import
module
as
_module
from
..
import
module
as
_module
from
...
import
nd
as
_nd
from
...
import
nd
as
_nd
from
.common
import
StrAttrsDict
from
.common
import
StrAttrsDict
from
.common
import
infer_type
as
_infer_type
from
.nnvm_common
import
_rename
,
_binop_scalar
,
_rbinop_scalar
,
_reduce
from
.nnvm_common
import
_rename
,
_binop_scalar
,
_rbinop_scalar
,
_reduce
from
.nnvm_common
import
_arg_reduce
,
_init_op
,
_softmax_op
,
_cast
from
.nnvm_common
import
_arg_reduce
,
_init_op
,
_softmax_op
,
_cast
from
.nnvm_common
import
_clip
,
_transpose
,
_upsampling
from
.nnvm_common
import
_clip
,
_transpose
,
_upsampling
...
@@ -41,13 +42,6 @@ _activation_map = {
...
@@ -41,13 +42,6 @@ _activation_map = {
"relu"
:
_op
.
nn
.
relu
"relu"
:
_op
.
nn
.
relu
}
}
def
_infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
_mx_fully_connected
(
inputs
,
attrs
):
def
_mx_fully_connected
(
inputs
,
attrs
):
import
mxnet
as
mx
import
mxnet
as
mx
units
=
attrs
.
get_int
(
"num_hidden"
)
units
=
attrs
.
get_int
(
"num_hidden"
)
...
...
python/tvm/relay/frontend/tensorflow.py
View file @
be8fa6ac
...
@@ -19,20 +19,21 @@
...
@@ -19,20 +19,21 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
__future__
import
print_function
from
__future__
import
print_function
import
logging
import
warnings
import
warnings
from
collections
import
defaultdict
from
collections
import
defaultdict
# Numpy support
# Numpy support
import
numpy
as
np
import
numpy
as
np
import
tvm
import
tvm
from
topi.util
import
get_const_tuple
from
..
import
analysis
from
..
import
analysis
from
..
import
transform
as
_transform
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
op
as
_op
from
..
import
op
as
_op
from
..expr_functor
import
ExprMutator
from
..expr_functor
import
ExprMutator
from
..
import
module
as
_module
from
..
import
module
as
_module
from
.common
import
AttrCvt
,
get_relay_op
from
.common
import
infer_type
as
_infer_type
from
.common
import
infer_shape
as
_infer_shape
from
.common
import
infer_channels
as
_infer_channels
__all__
=
[
'from_tensorflow'
]
__all__
=
[
'from_tensorflow'
]
...
@@ -50,140 +51,6 @@ def _infer_value(input_val, params):
...
@@ -50,140 +51,6 @@ def _infer_value(input_val, params):
m
.
run
()
m
.
run
()
return
m
.
get_output
(
0
)
return
m
.
get_output
(
0
)
def
_get_relay_op
(
op_name
):
ops
=
[
_op
,
_op
.
nn
,
_op
.
image
,
_op
.
vision
]
for
operator
in
ops
:
try
:
op
=
getattr
(
operator
,
op_name
)
return
op
except
AttributeError
:
continue
raise
tvm
.
error
.
OpNotImplemented
(
'Operator {} is not supported for frontend TensorFlow.'
.
format
(
op_name
))
class
AttrCvt
(
object
):
"""Common attribute converter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
disables : list
A list of attributes that is disabled in relay. Log warnings.
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
"""
def
__init__
(
self
,
op_name
,
transforms
=
None
,
excludes
=
None
,
disables
=
None
,
ignores
=
None
,
extras
=
None
,
custom_check
=
None
):
self
.
_op_name
=
op_name
self
.
_transforms
=
transforms
if
transforms
else
{}
self
.
_excludes
=
excludes
if
excludes
else
[]
self
.
_disables
=
disables
if
disables
else
[]
self
.
_ignores
=
ignores
if
ignores
else
[]
self
.
_extras
=
extras
if
extras
else
{}
self
.
_custom_check
=
custom_check
def
__call__
(
self
,
inputs
,
attrs
,
*
args
):
self
.
_ignores
.
append
(
'_output_shapes'
)
self
.
_ignores
.
append
(
'_input_shapes'
)
self
.
_ignores
.
append
(
'T'
)
self
.
_ignores
.
append
(
'use_cudnn_on_gpu'
)
self
.
_ignores
.
append
(
'_node_name'
)
self
.
_ignores
.
append
(
'is_training'
)
self
.
_ignores
.
append
(
'_target_layout'
)
# apply custom check
if
self
.
_custom_check
:
func
,
msg
=
self
.
_custom_check
if
not
func
(
attrs
):
raise
RuntimeError
(
"Check failed: {}"
.
format
(
msg
))
# get new op_name
if
isinstance
(
self
.
_op_name
,
str
):
op_name
=
self
.
_op_name
else
:
assert
callable
(
self
.
_op_name
),
"op_name can either be string or callable"
op_name
=
self
.
_op_name
(
attrs
)
# convert attributes
new_attrs
=
{}
for
k
in
attrs
.
keys
():
if
k
in
self
.
_excludes
:
raise
tvm
.
error
.
OpAttributeUnImplemented
(
'Attribute {} in operator {} is not supported.'
.
format
(
k
,
op_name
))
elif
k
in
self
.
_disables
:
logging
.
warning
(
"Attribute
%
s is disabled in relay.
%
s"
,
k
,
op_name
)
elif
k
in
self
.
_ignores
:
logging
.
debug
(
"Attribute
%
s is ignored in relay.
%
s"
,
k
,
op_name
)
elif
k
in
self
.
_transforms
:
new_name
,
defaults
,
transform
=
self
.
_parse_default
(
self
.
_transforms
[
k
])
if
defaults
is
None
:
new_attr
=
self
.
_required_attr
(
attrs
,
k
)
else
:
new_attr
=
attrs
.
get
(
k
,
None
)
if
new_attr
is
None
:
new_attrs
[
new_name
]
=
defaults
else
:
new_attrs
[
new_name
]
=
transform
(
new_attr
)
else
:
# copy
new_attrs
[
k
]
=
attrs
[
k
]
# add extras
new_attrs
.
update
(
self
.
_extras
)
return
_get_relay_op
(
op_name
)(
*
inputs
,
**
new_attrs
)
def
_parse_default
(
self
,
target
):
"""Helper function to parse default values."""
if
not
isinstance
(
target
,
(
list
,
tuple
)):
k
,
v
,
t
=
target
,
None
,
lambda
x
:
x
elif
len
(
target
)
==
1
:
k
,
v
,
t
=
target
[
0
],
None
,
lambda
x
:
x
elif
len
(
target
)
==
2
:
k
,
v
,
t
=
target
[
0
],
target
[
1
],
lambda
x
:
x
elif
len
(
target
)
>
2
:
k
,
v
,
t
=
target
[
0
],
target
[
1
],
target
[
2
]
else
:
k
=
None
# should raise
if
not
isinstance
(
k
,
str
):
msg
=
"{} is not a valid target, (name, default) expected."
.
format
(
target
)
raise
ValueError
(
msg
)
return
k
,
v
,
t
def
_parse_bool
(
self
,
value
):
"""Helper function to parse default boolean values."""
if
isinstance
(
value
,
str
):
return
value
.
strip
()
.
lower
()
in
[
'true'
,
'1'
,
't'
,
'y'
,
'yes'
]
return
bool
(
value
)
def
_required_attr
(
self
,
attr
,
key
):
"""Wrapper for getting required attributes."""
assert
isinstance
(
attr
,
dict
)
if
key
not
in
attr
:
raise
tvm
.
error
.
OpAttributeRequired
(
'Attribute {} not found in operator {}'
.
format
(
key
,
self
.
_op_name
))
return
attr
[
key
]
def
_get_pad_pair
(
input1d
,
kernel1d
,
stride1d
):
def
_get_pad_pair
(
input1d
,
kernel1d
,
stride1d
):
if
input1d
%
stride1d
==
0
:
if
input1d
%
stride1d
==
0
:
pad
=
max
(
kernel1d
-
stride1d
,
0
)
pad
=
max
(
kernel1d
-
stride1d
,
0
)
...
@@ -195,12 +62,6 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
...
@@ -195,12 +62,6 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
return
[
pad_before
,
pad_after
]
return
[
pad_before
,
pad_after
]
def
_get_name_hint
(
node
):
name
=
''
if
hasattr
(
node
,
"name_hint"
):
name
=
node
.
name_hint
return
name
def
_math_name_picker
(
surfix
):
def
_math_name_picker
(
surfix
):
def
_impl
(
attr
):
def
_impl
(
attr
):
return
'broadcast_'
+
surfix
return
'broadcast_'
+
surfix
...
@@ -222,30 +83,6 @@ def _dimension_constraint():
...
@@ -222,30 +83,6 @@ def _dimension_constraint():
return
False
return
False
return
_dim_check
,
"Only 2d kernel supported."
return
_dim_check
,
"Only 2d kernel supported."
def
_infer_channels
(
node
,
params
,
transpose
=
False
):
"""A hack for getting 'channels' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
"""
out_shape
=
_infer_shape
(
node
,
params
)
channels
=
out_shape
[
0
]
if
not
transpose
else
out_shape
[
1
]
return
channels
def
_infer_out_shapes
(
inputs
,
params
):
"""A method to get the output shape of intermediate nodes in the relay graph."""
return
[
_infer_shape
(
inputs
,
params
)]
def
_infer_type
(
node
):
"""A method to infer the type of an intermediate node in the relay graph."""
mod
=
_module
.
Module
.
from_expr
(
node
)
mod
=
_transform
.
InferType
()(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_expr
.
Function
)
else
entry
.
body
def
_infer_shape
(
node
,
params
=
None
):
"""A method to get the output shape of an intermediate node in the relay graph."""
out_type
=
_infer_type
(
node
)
return
get_const_tuple
(
out_type
.
checked_type
.
shape
)
def
_get_param
(
params
,
input_node
):
def
_get_param
(
params
,
input_node
):
return
params
.
pop
(
input_node
.
name_hint
)
.
asnumpy
()
return
params
.
pop
(
input_node
.
name_hint
)
.
asnumpy
()
...
@@ -280,7 +117,7 @@ def _argx(func, func_name):
...
@@ -280,7 +117,7 @@ def _argx(func, func_name):
def
_elemwise
(
name
):
def
_elemwise
(
name
):
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
assert
len
(
inputs
)
==
2
,
"{} take 2 inputs, {} given"
.
format
(
name
,
len
(
inputs
))
assert
len
(
inputs
)
==
2
,
"{} take 2 inputs, {} given"
.
format
(
name
,
len
(
inputs
))
return
_
get_relay_op
(
name
)(
*
inputs
)
return
get_relay_op
(
name
)(
*
inputs
)
return
_impl
return
_impl
def
_pooling
(
name
):
def
_pooling
(
name
):
...
@@ -300,7 +137,7 @@ def _pooling(name):
...
@@ -300,7 +137,7 @@ def _pooling(name):
else
:
else
:
msg
=
'Value {} of attribute "data_format" of operator Pooling '
\
msg
=
'Value {} of attribute "data_format" of operator Pooling '
\
'is not valid.'
'is not valid.'
raise
tvm
.
error
.
OpAttributeInvalid
(
msg
.
format
(
attr
s
[
'data_format'
]))
raise
tvm
.
error
.
OpAttributeInvalid
(
msg
.
format
(
attr
[
'data_format'
]))
if
attr
[
'_target_layout'
]
==
"NCHW"
and
attr
[
'data_format'
]
==
"NHWC"
:
if
attr
[
'_target_layout'
]
==
"NCHW"
and
attr
[
'data_format'
]
==
"NHWC"
:
tmp_shape
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
tmp_shape
=
attr
[
'_input_shapes'
][
inputs
[
0
]]
...
@@ -539,7 +376,7 @@ def _crop_and_resize():
...
@@ -539,7 +376,7 @@ def _crop_and_resize():
res_crop
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
size
)
res_crop
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
size
)
# 2) Resize
# 2) Resize
res_resize
=
_
get_relay_op
(
'resize'
)(
res_crop
,
**
attrs
)
res_resize
=
get_relay_op
(
'resize'
)(
res_crop
,
**
attrs
)
out
=
_op
.
concatenate
([
out
,
res_resize
],
axis
=
0
)
if
out
else
res_resize
out
=
_op
.
concatenate
([
out
,
res_resize
],
axis
=
0
)
if
out
else
res_resize
return
out
return
out
return
_impl
return
_impl
...
@@ -598,7 +435,7 @@ def _check_numerics():
...
@@ -598,7 +435,7 @@ def _check_numerics():
def
_matmul
():
def
_matmul
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
channels
=
_infer_channels
(
inputs
[
1
],
params
,
not
attr
[
'transpose_b'
])
channels
=
_infer_channels
(
inputs
[
1
],
not
attr
[
'transpose_b'
])
if
attr
[
'transpose_a'
]:
if
attr
[
'transpose_a'
]:
inputs
[
0
]
=
_op
.
transpose
(
inputs
[
0
],
axes
=
(
1
,
0
))
inputs
[
0
]
=
_op
.
transpose
(
inputs
[
0
],
axes
=
(
1
,
0
))
if
not
attr
[
'transpose_b'
]:
if
not
attr
[
'transpose_b'
]:
...
@@ -615,15 +452,10 @@ def _batch_matmul():
...
@@ -615,15 +452,10 @@ def _batch_matmul():
adj_y
=
attr
[
'adj_y'
]
adj_y
=
attr
[
'adj_y'
]
input_x
=
_op
.
transpose
(
inputs
[
0
],
axes
=
[
0
,
2
,
1
])
if
adj_x
else
inputs
[
0
]
input_x
=
_op
.
transpose
(
inputs
[
0
],
axes
=
[
0
,
2
,
1
])
if
adj_x
else
inputs
[
0
]
input_y
=
_op
.
transpose
(
inputs
[
1
],
axes
=
[
0
,
2
,
1
])
if
not
adj_y
else
inputs
[
1
]
input_y
=
_op
.
transpose
(
inputs
[
1
],
axes
=
[
0
,
2
,
1
])
if
not
adj_y
else
inputs
[
1
]
ret
=
_
get_relay_op
(
'batch_matmul'
)(
input_x
,
input_y
)
ret
=
get_relay_op
(
'batch_matmul'
)(
input_x
,
input_y
)
return
ret
return
ret
return
_impl
return
_impl
def
_undef
():
def
_impl
(
inputs
,
attr
,
params
):
return
_sym
.
__undef__
()
return
_impl
def
_identity
():
def
_identity
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
return
inputs
[
0
]
return
inputs
[
0
]
...
@@ -985,7 +817,7 @@ def _stridedSlice():
...
@@ -985,7 +817,7 @@ def _stridedSlice():
if
begin_mask
or
end_mask
or
ellipsis_mask
or
new_axis_mask
or
shrink_axis_mask
:
if
begin_mask
or
end_mask
or
ellipsis_mask
or
new_axis_mask
or
shrink_axis_mask
:
begin
,
end
,
stride
,
fshape_indices
=
_transform_mask
(
stride_dim
,
ellipsis_mask
)
begin
,
end
,
stride
,
fshape_indices
=
_transform_mask
(
stride_dim
,
ellipsis_mask
)
out
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
end
,
strides
=
stride
)
out
=
_op
.
strided_slice
(
inputs
[
0
],
begin
=
begin
,
end
=
end
,
strides
=
stride
)
out_shape
=
_infer_shape
(
out
,
params
)
out_shape
=
_infer_shape
(
out
)
if
not
fshape_indices
:
if
not
fshape_indices
:
fshape_indices
=
range
(
len
(
out_shape
))
fshape_indices
=
range
(
len
(
out_shape
))
...
@@ -1178,8 +1010,8 @@ def _softplus():
...
@@ -1178,8 +1010,8 @@ def _softplus():
exp_out
=
AttrCvt
(
'exp'
)(
inputs
,
attr
)
exp_out
=
AttrCvt
(
'exp'
)(
inputs
,
attr
)
inputs
.
append
(
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
))
inputs
.
append
(
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
))
rh
=
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
rh
=
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
add_out
=
_
get_relay_op
(
'add'
)(
exp_out
,
rh
)
add_out
=
get_relay_op
(
'add'
)(
exp_out
,
rh
)
return
_
get_relay_op
(
'log'
)(
add_out
)
return
get_relay_op
(
'log'
)(
add_out
)
return
_impl
return
_impl
def
_topk
():
def
_topk
():
...
@@ -1200,7 +1032,7 @@ def _floordiv():
...
@@ -1200,7 +1032,7 @@ def _floordiv():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
div
=
AttrCvt
(
'divide'
)(
inputs
,
attr
)
div
=
AttrCvt
(
'divide'
)(
inputs
,
attr
)
return
_
get_relay_op
(
'floor'
)(
div
)
return
get_relay_op
(
'floor'
)(
div
)
return
_impl
return
_impl
def
_logical
(
name
):
def
_logical
(
name
):
...
@@ -1234,7 +1066,7 @@ def _space_to_batch_nd():
...
@@ -1234,7 +1066,7 @@ def _space_to_batch_nd():
axes
=
[
2
*
i
+
2
for
i
in
range
(
M
)]
+
[
0
]
+
[
2
*
i
+
1
for
i
in
range
(
M
)]
+
\
axes
=
[
2
*
i
+
2
for
i
in
range
(
M
)]
+
[
0
]
+
[
2
*
i
+
1
for
i
in
range
(
M
)]
+
\
list
(
range
(
1
+
2
*
M
,
1
+
2
*
M
+
remaining_shape_length
))
list
(
range
(
1
+
2
*
M
,
1
+
2
*
M
+
remaining_shape_length
))
permuted_reshaped_padded
=
tvm
.
relay
.
transpose
(
reshaped_padded
,
axes
=
axes
)
permuted_reshaped_padded
=
tvm
.
relay
.
transpose
(
reshaped_padded
,
axes
=
axes
)
permuted_reshaped_padded_shape
=
_infer_shape
(
permuted_reshaped_padded
,
params
)
permuted_reshaped_padded_shape
=
_infer_shape
(
permuted_reshaped_padded
)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
...
@@ -1277,7 +1109,7 @@ def _batch_to_space_nd():
...
@@ -1277,7 +1109,7 @@ def _batch_to_space_nd():
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
# input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape
=
_infer_shape
(
reshaped_permuted
,
params
)
reshaped_permuted_shape
=
_infer_shape
(
reshaped_permuted
)
cropped
=
reshaped_permuted
cropped
=
reshaped_permuted
for
axis
in
range
(
1
,
M
+
1
):
for
axis
in
range
(
1
,
M
+
1
):
crop
=
crops
[
axis
-
1
]
crop
=
crops
[
axis
-
1
]
...
@@ -1305,8 +1137,8 @@ def _log1p():
...
@@ -1305,8 +1137,8 @@ def _log1p():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
# op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
one
=
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
one
=
tvm
.
relay
.
const
(
1
,
attr
[
'T'
]
.
name
)
add_out
=
_
get_relay_op
(
'add'
)(
inputs
[
0
],
one
)
add_out
=
get_relay_op
(
'add'
)(
inputs
[
0
],
one
)
return
_
get_relay_op
(
'log'
)(
add_out
)
return
get_relay_op
(
'log'
)(
add_out
)
return
_impl
return
_impl
# compatible operators that do NOT require any conversion.
# compatible operators that do NOT require any conversion.
...
@@ -2399,7 +2231,7 @@ class GraphProto(object):
...
@@ -2399,7 +2231,7 @@ class GraphProto(object):
convert_map
=
convert_map
if
convert_map
else
_convert_map
convert_map
=
convert_map
if
convert_map
else
_convert_map
convert_map_rnn
=
_convert_map_rnn
convert_map_rnn
=
_convert_map_rnn
if
op_name
in
identity_list
:
if
op_name
in
identity_list
:
sym
=
_
get_relay_op
(
op_name
)(
*
inputs
,
**
attrs
)
sym
=
get_relay_op
(
op_name
)(
*
inputs
,
**
attrs
)
elif
op_name
in
convert_map
:
elif
op_name
in
convert_map
:
sym
=
convert_map
[
op_name
](
inputs
,
attrs
,
self
.
_params
)
sym
=
convert_map
[
op_name
](
inputs
,
attrs
,
self
.
_params
)
elif
op_name
in
convert_map_rnn
:
elif
op_name
in
convert_map_rnn
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment