Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b236565e
Unverified
Commit
b236565e
authored
Apr 10, 2020
by
Samuel
Committed by
GitHub
Apr 11, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PYTORCH]Repeat, Reciprocal & Reshape Op support (#5280)
parent
0d1babce
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
117 additions
and
0 deletions
+117
-0
python/tvm/relay/frontend/pytorch.py
+42
-0
tests/python/frontend/pytorch/test_forward.py
+75
-0
No files found.
python/tvm/relay/frontend/pytorch.py
View file @
b236565e
...
...
@@ -154,6 +154,34 @@ def _select():
return
_op
.
transform
.
take
(
data
,
index
,
axis
=
dim
)
return
_impl
def
_reciprocal
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
return
_expr
.
const
(
1.0
)
/
data
return
_impl
def
_repeat
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
reps
=
_get_dims
(
inputs
[
1
])
return
_op
.
transform
.
tile
(
data
,
reps
=
reps
)
return
_impl
def
_repeat_interleave
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
if
isinstance
(
inputs
[
1
],
int
):
repeats
=
inputs
[
1
]
axis
=
inputs
[
2
]
else
:
msg
=
"Only repeat with one value as repeat is currently supported."
raise
AssertionError
(
msg
)
if
axis
is
None
:
# Flatten the data if no axis is given from torch
data
=
_op
.
transform
.
reshape
(
data
,
[
-
1
])
axis
=
0
return
_op
.
transform
.
repeat
(
data
,
repeats
=
repeats
,
axis
=
axis
)
return
_impl
def
_ones
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
...
...
@@ -675,6 +703,16 @@ def _view():
return
_op
.
transform
.
reshape
(
data
,
new_shape
)
return
_impl
def
_reshape
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
if
isinstance
(
inputs
[
1
],
list
):
new_shape
=
inputs
[
1
]
else
:
new_shape
=
_infer_shape
(
inputs
[
1
])
return
_op
.
transform
.
reshape
(
data
,
new_shape
)
return
_impl
def
_clone
():
def
_impl
(
inputs
,
input_types
):
data
=
inputs
[
0
]
...
...
@@ -1082,6 +1120,9 @@ _convert_map = {
"aten::div_"
:
_elemwise
(
"divide"
),
"aten::ones"
:
_ones
(),
"aten::zeros"
:
_zeros
(),
"aten::reciprocal"
:
_reciprocal
(),
"aten::repeat"
:
_repeat
(),
"aten::repeat_interleave"
:
_repeat_interleave
(),
"aten::to"
:
_to
(),
"aten::squeeze"
:
_squeeze
(),
"aten::unsqueeze"
:
_unsqueeze
(),
...
...
@@ -1122,6 +1163,7 @@ _convert_map = {
"aten::addmm"
:
_dense
(),
"aten::size"
:
_size
(),
"aten::view"
:
_view
(),
"aten::reshape"
:
_reshape
(),
"aten::clone"
:
_clone
(),
"aten::log_softmax"
:
_log_softmax
(),
"aten::sigmoid"
:
_sigmoid
(),
...
...
tests/python/frontend/pytorch/test_forward.py
View file @
b236565e
...
...
@@ -293,6 +293,61 @@ def test_forward_multiply():
verify_model
(
Multiply3
()
.
float
()
.
eval
(),
input_data
=
input_data
)
verify_model
(
Multiply4
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_reciprocal
():
torch
.
set_grad_enabled
(
False
)
input_shape
=
[
2
,
1
,
10
,
1
,
10
]
class
Reciprocal1
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
reciprocal
()
input_data
=
torch
.
rand
(
input_shape
)
.
float
()
verify_model
(
Reciprocal1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_repeat
():
torch
.
set_grad_enabled
(
False
)
input_shape
=
[
1
,
3
]
class
Repeat1
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
repeat
(
1
,
1
)
class
Repeat2
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
repeat
(
4
,
2
)
class
Repeat3
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
repeat
(
4
,
2
,
1
)
input_data
=
torch
.
rand
(
input_shape
)
.
float
()
verify_model
(
Repeat1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
verify_model
(
Repeat2
()
.
float
()
.
eval
(),
input_data
=
input_data
)
verify_model
(
Repeat3
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_repeat_interleave
():
torch
.
set_grad_enabled
(
False
)
input_shape
=
[
2
,
2
,
3
]
class
RepeatInterleave1
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
repeat_interleave
(
2
)
class
RepeatInterleave2
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
repeat_interleave
(
3
,
dim
=
0
)
class
RepeatInterleave3
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
repeat_interleave
(
2
,
dim
=
1
)
class
RepeatInterleave4
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
repeat_interleave
(
4
,
dim
=
2
)
input_data
=
torch
.
rand
(
input_shape
)
.
float
()
verify_model
(
RepeatInterleave1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
verify_model
(
RepeatInterleave2
()
.
float
()
.
eval
(),
input_data
=
input_data
)
verify_model
(
RepeatInterleave3
()
.
float
()
.
eval
(),
input_data
=
input_data
)
verify_model
(
RepeatInterleave4
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_unsqueeze
():
torch
.
set_grad_enabled
(
False
)
input_shape
=
[
10
,
10
]
...
...
@@ -600,6 +655,22 @@ def test_forward_layernorm():
init_weight
(
ln
.
eval
())
verify_model
(
ln
.
eval
(),
input_data
=
inp
)
def
test_forward_reshape
():
torch
.
set_grad_enabled
(
False
)
input_shape
=
[
2
,
1
,
10
,
1
,
10
]
new_shape
=
[
2
,
1
,
10
,
10
]
class
Reshape1
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
reshape
(
new_shape
)
class
Reshape2
(
Module
):
def
forward
(
self
,
*
args
):
return
args
[
0
]
.
reshape
([
-
1
])
input_data
=
torch
.
rand
(
input_shape
)
.
float
()
verify_model
(
Reshape1
()
.
float
()
.
eval
(),
input_data
=
input_data
)
verify_model
(
Reshape2
()
.
float
()
.
eval
(),
input_data
=
input_data
)
def
test_forward_transpose
():
torch
.
set_grad_enabled
(
False
)
input_shape
=
[
1
,
3
,
10
,
10
]
...
...
@@ -1151,6 +1222,10 @@ if __name__ == "__main__":
test_forward_add
()
test_forward_subtract
()
test_forward_multiply
()
test_forward_reshape
()
test_forward_reciprocal
()
test_forward_repeat
()
test_forward_repeat_interleave
()
test_forward_squeeze
()
test_forward_unsqueeze
()
test_forward_concatenate
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment