Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
78a0f47b
Commit
78a0f47b
authored
5 years ago
by
hlu1
Committed by
Tianqi Chen
5 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARM] Fix concat (#3061)
parent
24fe04f8
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
40 additions
and
2 deletions
+40
-2
python/tvm/relay/op/_transform.py
+2
-1
python/tvm/relay/op/op.py
+7
-0
topi/python/topi/arm_cpu/injective.py
+29
-0
topi/tests/python/test_topi_transform.py
+2
-1
No files found.
python/tvm/relay/op/_transform.py
View file @
78a0f47b
...
...
@@ -23,6 +23,7 @@ from .op import schedule_injective, OpPattern
schedule_injective
=
_reg
.
schedule_injective
schedule_broadcast
=
_reg
.
schedule_injective
schedule_concatenate
=
_reg
.
schedule_concatenate
_reg
.
register_schedule
(
"collapse_sum_like"
,
_schedule_reduce
)
...
...
@@ -46,7 +47,7 @@ _reg.register_schedule("take", schedule_injective)
_reg
.
register_schedule
(
"transpose"
,
schedule_injective
)
_reg
.
register_schedule
(
"where"
,
schedule_broadcast
)
_reg
.
register_schedule
(
"stack"
,
schedule_injective
)
_reg
.
register_schedule
(
"concatenate"
,
schedule_
injectiv
e
)
_reg
.
register_schedule
(
"concatenate"
,
schedule_
concatenat
e
)
_reg
.
register_schedule
(
"_contrib_reverse_reshape"
,
schedule_injective
)
_reg
.
register_schedule
(
"gather_nd"
,
schedule_injective
)
...
...
This diff is collapsed.
Click to expand it.
python/tvm/relay/op/op.py
View file @
78a0f47b
...
...
@@ -219,6 +219,13 @@ def schedule_injective(attrs, outputs, target):
with
target
:
return
topi
.
generic
.
schedule_injective
(
outputs
)
def
schedule_concatenate
(
attrs
,
outputs
,
target
):
"""Generic schedule for concatinate."""
with
target
:
return
topi
.
generic
.
schedule_concatenate
(
outputs
)
__DEBUG_COUNTER__
=
0
def
debug
(
expr
,
debug_func
=
None
):
...
...
This diff is collapsed.
Click to expand it.
topi/python/topi/arm_cpu/injective.py
View file @
78a0f47b
...
...
@@ -51,3 +51,32 @@ def schedule_injective(outs):
elif
len
(
s
[
x
]
.
op
.
axis
)
>=
2
:
s
[
x
]
.
parallel
(
s
[
x
]
.
op
.
axis
[
0
])
return
s
@generic.schedule_concatenate.register
([
"arm_cpu"
])
def
schedule_concatenate
(
outs
):
"""Schedule for concatenate op.
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs
=
[
outs
]
if
isinstance
(
outs
,
tvm
.
tensor
.
Tensor
)
else
outs
s
=
tvm
.
create_schedule
([
x
.
op
for
x
in
outs
])
x
=
outs
[
0
]
tvm
.
schedule
.
AutoInlineInjective
(
s
)
if
len
(
s
[
x
]
.
op
.
axis
)
>=
4
:
fused
=
s
[
x
]
.
fuse
(
s
[
x
]
.
op
.
axis
[
0
],
s
[
x
]
.
op
.
axis
[
1
],
s
[
x
]
.
op
.
axis
[
2
])
s
[
x
]
.
parallel
(
fused
)
elif
len
(
s
[
x
]
.
op
.
axis
)
>=
3
:
fused
=
s
[
x
]
.
fuse
(
s
[
x
]
.
op
.
axis
[
0
],
s
[
x
]
.
op
.
axis
[
1
])
s
[
x
]
.
parallel
(
fused
)
elif
len
(
s
[
x
]
.
op
.
axis
)
>=
2
:
s
[
x
]
.
parallel
(
s
[
x
]
.
op
.
axis
[
0
])
return
s
This diff is collapsed.
Click to expand it.
topi/tests/python/test_topi_transform.py
View file @
78a0f47b
...
...
@@ -127,7 +127,7 @@ def verify_concatenate(shapes, axis):
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_
injectiv
e
(
out_tensor
)
s
=
topi
.
generic
.
schedule_
concatenat
e
(
out_tensor
)
foo
=
tvm
.
build
(
s
,
tensor_l
+
[
out_tensor
],
device
,
name
=
"concatenate"
)
data_npys
=
[
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
tensor_l
[
0
]
.
dtype
)
for
shape
in
shapes
]
...
...
@@ -476,6 +476,7 @@ def test_concatenate():
(
12
,
6
,
7
,
3
),
(
8
,
6
,
7
,
3
),
(
2
,
6
,
7
,
3
)],
0
)
verify_concatenate
([(
1
,
14400
),
(
1
,
2400
),
(
1
,
640
),
(
1
,
240
)],
1
)
def
test_stack
():
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment