Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c9f9a3f9
Commit
c9f9a3f9
authored
Aug 08, 2018
by
Siju
Committed by
Tianqi Chen
Aug 08, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
l2normalization operator support for tensorflow (#1528)
parent
7ea06e6e
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
3 deletions
+59
-3
nnvm/python/nnvm/frontend/tensorflow.py
+21
-2
nnvm/tests/python/frontend/tensorflow/test_forward.py
+38
-1
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
c9f9a3f9
...
@@ -434,6 +434,21 @@ def _lrn():
...
@@ -434,6 +434,21 @@ def _lrn():
return
AttrCvt
(
op_name
=
'lrn'
)(
new_inputs
,
attr_new
)
return
AttrCvt
(
op_name
=
'lrn'
)(
new_inputs
,
attr_new
)
return
_impl
return
_impl
def
_sum
():
def
_impl
(
inputs
,
attr
,
params
):
axis
=
params
.
pop
(
inputs
[
1
]
.
list_output_names
()[
0
])
.
asnumpy
()
return
AttrCvt
(
op_name
=
'sum'
,
extras
=
{
'axis'
:
axis
},
transforms
=
{
'keep_dims'
:
'keepdims'
},
ignores
=
[
'name'
,
'Tidx'
])(
inputs
[
0
],
attr
)
return
_impl
def
_square
():
def
_impl
(
inputs
,
attr
,
params
):
return
_sym
.
elemwise_mul
(
inputs
[
0
],
inputs
[
0
])
return
_impl
def
_gather_v2
():
def
_gather_v2
():
"Tensorflow now support only gatherv2"
"Tensorflow now support only gatherv2"
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
...
@@ -651,13 +666,17 @@ _convert_map = {
...
@@ -651,13 +666,17 @@ _convert_map = {
'Identity'
:
_identity
(),
'Identity'
:
_identity
(),
'MatMul'
:
_matmul
(),
'MatMul'
:
_matmul
(),
'MaxPool'
:
_pooling
(
'max_pool'
),
'MaxPool'
:
_pooling
(
'max_pool'
),
'Add'
:
_elemwise
(
'add'
),
'Sub'
:
_elemwise
(
'sub'
),
'Mul'
:
_elemwise
(
'mul'
),
'Mul'
:
_elemwise
(
'mul'
),
'Maximum'
:
_elemwise
(
'max'
),
'Minimum'
:
_elemwise
(
'min'
),
'Sum'
:
_sum
(),
'Square'
:
_square
(),
'Relu'
:
AttrCvt
(
'relu'
),
'Relu'
:
AttrCvt
(
'relu'
),
'Reshape'
:
_reshape
(),
'Reshape'
:
_reshape
(),
'ResizeBilinear'
:
_resize_bilinear
(),
'ResizeBilinear'
:
_resize_bilinear
(),
'Softmax'
:
AttrCvt
(
'softmax'
,
{
'axis'
:
(
'axis'
,
1
)}),
'Softmax'
:
AttrCvt
(
'softmax'
,
{
'axis'
:
(
'axis'
,
1
)}),
'Sub'
:
_elemwise
(
'sub'
),
'Add'
:
_elemwise
(
'add'
),
'Rsqrt'
:
_rsqrt
(),
'Rsqrt'
:
_rsqrt
(),
'Squeeze'
:
_squeeze
(),
'Squeeze'
:
_squeeze
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
'FusedBatchNorm'
:
_fused_batch_norm
(),
...
...
nnvm/tests/python/frontend/tensorflow/test_forward.py
View file @
c9f9a3f9
...
@@ -12,6 +12,7 @@ import tensorflow as tf
...
@@ -12,6 +12,7 @@ import tensorflow as tf
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
graph_util
from
tensorflow.python.framework
import
graph_util
from
tensorflow.python.ops
import
nn_ops
from
tensorflow.python.ops
import
nn_ops
from
tensorflow.python.ops
import
nn
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
gen_array_ops
from
tensorflow.python.ops
import
gen_array_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
math_ops
...
@@ -948,7 +949,6 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
...
@@ -948,7 +949,6 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
sess
,
sess
,
sess
.
graph
.
as_graph_def
(
add_shapes
=
True
),
sess
.
graph
.
as_graph_def
(
add_shapes
=
True
),
[
'lrn'
],)
[
'lrn'
],)
tf_output
=
run_tf_graph
(
sess
,
inp_array
,
'lrn0_data:0'
,
'lrn:0'
)
tf_output
=
run_tf_graph
(
sess
,
inp_array
,
'lrn0_data:0'
,
'lrn:0'
)
tvm_output
=
run_tvm_graph
(
graph_def
,
tvm_output
=
run_tvm_graph
(
graph_def
,
inp_array
,
inp_array
,
...
@@ -959,6 +959,42 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
...
@@ -959,6 +959,42 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
def
test_forward_lrn
():
def
test_forward_lrn
():
_test_lrn
((
1
,
3
,
20
,
20
),
3
,
1
,
1.0
,
1.0
,
0.5
)
_test_lrn
((
1
,
3
,
20
,
20
),
3
,
1
,
1.0
,
1.0
,
0.5
)
#######################################################################
# l2_normalize
# ------------
def
_test_l2_normalize
(
ishape
,
eps
,
axis
):
""" testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
inp_array
=
np
.
random
.
uniform
(
size
=
ishape
)
.
astype
(
np
.
float32
)
inp_array
.
fill
(
1
)
with
tf
.
Graph
()
.
as_default
():
in1
=
tf
.
placeholder
(
shape
=
inp_array
.
shape
,
dtype
=
inp_array
.
dtype
,
name
=
"Placeholder"
)
nn
.
l2_normalize
(
in1
,
axis
=
axis
,
epsilon
=
eps
,
name
=
None
,
dim
=
None
)
with
tf
.
Session
()
as
sess
:
graph_def
=
tf
.
graph_util
.
convert_variables_to_constants
(
sess
,
sess
.
graph
.
as_graph_def
(
add_shapes
=
True
),
[
'l2_normalize'
],
)
tf_output
=
run_tf_graph
(
sess
,
inp_array
,
'Placeholder:0'
,
'Placeholder:0'
)
tvm_output
=
run_tvm_graph
(
graph_def
,
inp_array
,
"Placeholder"
,
tf_output
.
shape
,
tf_output
.
dtype
)
np
.
testing
.
assert_allclose
(
tf_output
,
tvm_output
,
atol
=
1e-3
,
rtol
=
1e-3
)
sess
.
close
()
def
test_forward_l2_normalize
():
_test_l2_normalize
((
1
,
3
,
20
,
20
),
0.001
,
(
0
,))
#######################################################################
# Main
# Main
# ----
# ----
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
@@ -981,3 +1017,4 @@ if __name__ == '__main__':
...
@@ -981,3 +1017,4 @@ if __name__ == '__main__':
test_forward_gather
()
test_forward_gather
()
test_forward_ptb
()
test_forward_ptb
()
test_forward_lrn
()
test_forward_lrn
()
test_forward_l2_normalize
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment