Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
46657ed1
Commit
46657ed1
authored
Oct 06, 2017
by
Leyuan Wang
Committed by
Tianqi Chen
Oct 06, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Conv2d modified for better performance (#516)
* conv2d tweaked for better end-to-end performance * syntax changed
parent
13970eba
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
16 deletions
+25
-16
topi/python/topi/cuda/conv2d_nchw.py
+25
-16
No files found.
topi/python/topi/cuda/conv2d_nchw.py
View file @
46657ed1
...
...
@@ -66,9 +66,20 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
def
conv2d_56_64_128
(
s
,
temp
,
temp_R
,
temp_S
,
Filter_S
,
Out
,
Out_L
,
flag
):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if
util
.
get_const_int
(
Filter_S
.
shape
[
0
])
==
util
.
get_const_int
(
Filter_S
.
shape
[
1
]):
mark
=
util
.
get_const_int
(
Out
.
shape
[
2
])
*
util
.
get_const_int
(
Out
.
shape
[
3
])
num_thread_x
=
0
if
mark
%
8
==
0
and
mark
%
7
==
0
:
num_thread_x
=
8
num_thread_y
=
8
vthread_x
=
7
else
:
for
i
in
range
(
5
,
mark
):
if
mark
%
i
==
0
and
num_thread_x
==
0
:
vthread_x
=
i
mark
=
mark
//
i
if
mark
%
i
==
0
and
vthread_x
>
0
:
num_thread_x
=
i
break
num_thread_y
=
8
vthread_y
=
2
ifactor
=
8
...
...
@@ -80,20 +91,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
thread_yz
=
tvm
.
thread_axis
((
0
,
vthread_y
),
"vthread"
,
name
=
"vy"
)
i
,
oc
,
h
,
w
=
s
[
Out
]
.
op
.
axis
oh
,
ih
=
s
[
Out
]
.
split
(
h
,
nparts
=
vthread_x
)
w
=
s
[
Out
]
.
fuse
(
ih
,
w
)
w
=
s
[
Out
]
.
fuse
(
h
,
w
)
ow
,
iw
=
s
[
Out
]
.
split
(
w
,
factor
=
num_thread_x
*
vthread_x
)
ooc
,
ioc
=
s
[
Out
]
.
split
(
oc
,
factor
=
num_thread_y
*
vthread_y
)
o
w
,
iw
=
s
[
Out
]
.
split
(
w
,
factor
=
num_
thread_x
)
o
iw
,
iiw
=
s
[
Out
]
.
split
(
iw
,
nparts
=
v
thread_x
)
oioc
,
iioc
=
s
[
Out
]
.
split
(
ioc
,
nparts
=
vthread_y
)
s
[
Out
]
.
reorder
(
i
,
ooc
,
o
h
,
oioc
,
ow
,
iioc
,
iw
)
s
[
Out
]
.
bind
(
iw
,
thread_x
)
s
[
Out
]
.
reorder
(
i
,
ooc
,
o
w
,
oioc
,
oiw
,
iioc
,
i
iw
)
s
[
Out
]
.
bind
(
i
i
w
,
thread_x
)
s
[
Out
]
.
bind
(
iioc
,
thread_y
)
s
[
Out
]
.
bind
(
ow
,
thread_xz
)
s
[
Out
]
.
bind
(
o
i
w
,
thread_xz
)
s
[
Out
]
.
bind
(
oioc
,
thread_yz
)
s
[
Out
]
.
bind
(
o
h
,
block_x
)
s
[
Out
]
.
bind
(
o
w
,
block_x
)
s
[
Out
]
.
bind
(
ooc
,
block_y
)
s
[
Out_L
]
.
compute_at
(
s
[
Out
],
iw
)
s
[
Out_L
]
.
compute_at
(
s
[
Out
],
i
i
w
)
# schedule Out_L local write
i
,
oc
,
h
,
w
=
s
[
Out_L
]
.
op
.
axis
...
...
@@ -260,9 +271,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
else
:
# scheduler params
vthread_x
=
min
(
8
,
util
.
get_const_int
(
Out
.
shape
[
2
])
)
vthread_x
=
util
.
get_const_int
(
Out
.
shape
[
2
]
)
num_thread_x
=
16
num_thread_y
=
min
(
8
,
util
.
get_const_int
(
Out
.
shape
[
3
])
)
num_thread_y
=
util
.
get_const_int
(
Out
.
shape
[
3
]
)
ofactor
=
8
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
((
0
,
num_thread_x
),
"threadIdx.x"
)
...
...
@@ -271,12 +282,10 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
i
,
oc
,
h
,
w
=
s
[
Out
]
.
op
.
axis
ooc
,
ioc
=
s
[
Out
]
.
split
(
oc
,
factor
=
num_thread_x
)
oh
,
ih
=
s
[
Out
]
.
split
(
h
,
factor
=
vthread_x
)
ow
,
iw
=
s
[
Out
]
.
split
(
w
,
factor
=
num_thread_y
)
s
[
Out
]
.
reorder
(
i
,
ooc
,
oh
,
ih
,
ow
,
iw
,
ioc
)
s
[
Out
]
.
reorder
(
i
,
ooc
,
h
,
w
,
ioc
)
s
[
Out
]
.
bind
(
ioc
,
thread_x
)
s
[
Out
]
.
bind
(
i
w
,
thread_y
)
s
[
Out
]
.
bind
(
i
h
,
thread_xz
)
s
[
Out
]
.
bind
(
w
,
thread_y
)
s
[
Out
]
.
bind
(
h
,
thread_xz
)
s
[
Out
]
.
bind
(
ooc
,
block_x
)
s
[
Out_L
]
.
compute_at
(
s
[
Out
],
ioc
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment