Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3f7cce3b
Commit
3f7cce3b
authored
Jun 26, 2018
by
Tianqi Chen
Committed by
GitHub
Jun 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[SCHEDULE] Fix schedule for big array (#1340)
parent
2f77a127
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
48 additions
and
12 deletions
+48
-12
topi/python/topi/cuda/extern.py
+2
-10
topi/python/topi/cuda/injective.py
+19
-2
topi/python/topi/util.py
+22
-0
topi/tests/python/test_topi_relu.py
+5
-0
No files found.
topi/python/topi/cuda/extern.py
View file @
3f7cce3b
...
@@ -2,15 +2,7 @@
...
@@ -2,15 +2,7 @@
"""Schedule for cudnn and miopen extern op"""
"""Schedule for cudnn and miopen extern op"""
import
tvm
import
tvm
from
..
import
generic
from
..
import
generic
from
.injective
import
_schedule_injective
def
_schedule_output
(
op
,
sch
):
x
=
op
.
output
(
0
)
fused
=
sch
[
x
]
.
fuse
(
*
sch
[
x
]
.
op
.
axis
)
num_thread
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
bx
,
tx
=
sch
[
x
]
.
split
(
fused
,
factor
=
num_thread
)
sch
[
x
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
sch
[
x
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
return
sch
@generic.schedule_extern.register
([
"cuda"
,
"gpu"
])
@generic.schedule_extern.register
([
"cuda"
,
"gpu"
])
...
@@ -36,5 +28,5 @@ def schedule_extern(outs):
...
@@ -36,5 +28,5 @@ def schedule_extern(outs):
for
out
in
outs
:
for
out
in
outs
:
if
isinstance
(
out
.
op
,
tvm
.
tensor
.
ExternOp
):
if
isinstance
(
out
.
op
,
tvm
.
tensor
.
ExternOp
):
continue
continue
_schedule_
output
(
out
.
op
,
s
)
_schedule_
injective
(
out
.
op
,
s
)
return
s
return
s
topi/python/topi/cuda/injective.py
View file @
3f7cce3b
# pylint: disable=invalid-name, unused-variable,
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
"""Schedule for composition of injective operator"""
import
tvm
import
tvm
from
..
import
generic
from
..
import
generic
,
util
def
_schedule_injective
(
op
,
sch
):
def
_schedule_injective
(
op
,
sch
):
x
=
op
.
output
(
0
)
x
=
op
.
output
(
0
)
fused
=
sch
[
x
]
.
fuse
(
*
sch
[
x
]
.
op
.
axis
)
fused
=
sch
[
x
]
.
fuse
(
*
sch
[
x
]
.
op
.
axis
)
num_thread
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
num_thread
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
bx
,
tx
=
sch
[
x
]
.
split
(
fused
,
factor
=
num_thread
)
max_block
=
256
try
:
const_size
=
util
.
get_const_int
(
util
.
prod
(
x
.
shape
))
max_block
=
256
need_block_split
=
const_size
>
max_block
*
num_thread
except
ValueError
:
need_block_split
=
False
if
need_block_split
:
xo
,
xi
=
sch
[
x
]
.
split
(
fused
,
factor
=
num_thread
*
max_block
)
bx
,
tx
=
sch
[
x
]
.
split
(
xi
,
factor
=
num_thread
)
sch
[
x
]
.
reorder
(
bx
,
tx
,
xo
)
sch
[
x
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
sch
[
x
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
sch
[
x
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
sch
[
x
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
else
:
bx
,
tx
=
sch
[
x
]
.
split
(
fused
,
factor
=
num_thread
)
sch
[
x
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
sch
[
x
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
return
sch
return
sch
...
...
topi/python/topi/util.py
View file @
3f7cce3b
...
@@ -2,6 +2,28 @@
...
@@ -2,6 +2,28 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
tvm
import
tvm
def
prod
(
x
):
"""Get the product of every items in the tuple.
Parameters
----------
x: tuple
Input tuple
Returns
-------
value : Expr
The result value
"""
if
not
x
:
return
tvm
.
const
(
1
,
"int32"
)
res
=
x
[
0
]
for
i
in
range
(
1
,
len
(
x
)):
res
=
res
*
x
[
i
]
return
res
def
get_const_int
(
expr
):
def
get_const_int
(
expr
):
"""Verifies expr is integer and get the constant value.
"""Verifies expr is integer and get the constant value.
...
...
topi/tests/python/test_topi_relu.py
View file @
3f7cce3b
...
@@ -71,6 +71,10 @@ def verify_prelu(x, w):
...
@@ -71,6 +71,10 @@ def verify_prelu(x, w):
def
test_relu
():
def
test_relu
():
verify_relu
(
10
,
128
)
verify_relu
(
10
,
128
)
def
test_schedule_big_array
():
verify_relu
(
1024
*
100
,
512
)
def
test_leaky_relu
():
def
test_leaky_relu
():
verify_leaky_relu
(
100
,
0.1
)
verify_leaky_relu
(
100
,
0.1
)
...
@@ -78,6 +82,7 @@ def test_prelu():
...
@@ -78,6 +82,7 @@ def test_prelu():
verify_prelu
((
1
,
3
,
2
,
2
),
(
3
,))
verify_prelu
((
1
,
3
,
2
,
2
),
(
3
,))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_schedule_big_array
()
test_relu
()
test_relu
()
test_leaky_relu
()
test_leaky_relu
()
test_prelu
()
test_prelu
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment