Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8c5078c9
Commit
8c5078c9
authored
Jul 26, 2018
by
Leyuan Wang
Committed by
Tianqi Chen
Jul 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fixed bugs for conv2d (#1465)
parent
b0ef376a
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
13 deletions
+24
-13
topi/python/topi/cuda/conv2d_nchw.py
+19
-9
topi/python/topi/intel_graphics/conv2d.py
+4
-3
topi/tests/python/test_topi_conv2d_nchw.py
+1
-1
No files found.
topi/python/topi/cuda/conv2d_nchw.py
View file @
8c5078c9
...
...
@@ -13,14 +13,16 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
hfactor
=
2
if
flag
>=
96
:
hfactor
=
4
max_threads
=
int
(
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
)
ow_size
=
util
.
get_const_int
(
Out
.
shape
[
3
])
num_thread
=
ow_size
*
hfactor
num_thread
=
min
(
max_threads
,
ow_size
*
hfactor
)
vthread
=
ofactor
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
((
0
,
num_thread
),
"threadIdx.x"
)
thread_xz
=
tvm
.
thread_axis
((
0
,
vthread
),
"vthread"
,
name
=
"vx"
)
i
,
oc
,
h
,
w
=
s
[
Out
]
.
op
.
axis
if
ow_size
*
hfactor
==
num_thread
:
ooc
,
ioc
=
s
[
Out
]
.
split
(
oc
,
factor
=
vthread
)
oh
,
ih
=
s
[
Out
]
.
split
(
h
,
factor
=
hfactor
)
s
[
Out
]
.
reorder
(
ooc
,
oh
,
ioc
,
ih
,
w
)
...
...
@@ -30,6 +32,10 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
s
[
Out
]
.
bind
(
w
,
thread_x
)
s
[
Out
]
.
bind
(
ioc
,
thread_xz
)
s
[
Out
]
.
bind
(
oc
,
block_x
)
else
:
ow
,
w
=
s
[
Out
]
.
split
(
w
,
factor
=
num_thread
)
s
[
Out
]
.
bind
(
w
,
thread_x
)
s
[
Out
]
.
bind
(
ow
,
block_x
)
s
[
Out_L
]
.
compute_at
(
s
[
Out
],
w
)
...
...
@@ -40,7 +46,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
s
[
temp_S
]
.
compute_at
(
s
[
Out_L
],
ic
)
s
[
Filter_S
]
.
compute_at
(
s
[
Out_L
],
w
)
num_thread1
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num
_threads
num_thread1
=
max
_threads
thread_xx
=
tvm
.
thread_axis
((
0
,
num_thread1
),
"threadIdx.x"
)
block_xx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
...
...
@@ -59,6 +65,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
h
=
s
[
temp_S
]
.
fuse
(
h
,
ow
)
_
,
tx
=
s
[
temp_S
]
.
split
(
h
,
factor
=
num_thread
)
s
[
temp_S
]
.
bind
(
tx
,
thread_x
)
if
num_thread
<
max_threads
:
s
[
temp_S
]
.
vectorize
(
iw
)
#schedule Filter_S shared mem load
...
...
@@ -250,12 +257,13 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
def
conv2d_14_256_256
(
s
,
temp
,
temp_R
,
temp_S
,
Filter
,
Filter_S
,
Out
,
Out_L
):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
max_threads
=
int
(
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
)
if
util
.
get_const_int
(
Filter
.
shape
[
0
])
+
util
.
get_const_int
(
Filter
.
shape
[
1
])
<=
768
:
# scheduler params
vthread_x
=
util
.
get_const_int
(
Out
.
shape
[
3
])
num_thread_x
=
64
ofactor
=
8
if
util
.
get_const_int
(
Filter
.
shape
[
3
])
==
1
:
if
util
.
get_const_int
(
Filter
.
shape
[
3
])
==
1
and
vthread_x
*
5
<=
max_threads
:
ofactor
=
64
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
((
0
,
num_thread_x
),
"threadIdx.x"
)
...
...
@@ -295,9 +303,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
else
:
# scheduler params
vthread_x
=
util
.
get_const_int
(
Out
.
shape
[
2
]
)
vthread_x
=
min
(
8
,
util
.
get_const_int
(
Out
.
shape
[
2
])
)
num_thread_x
=
16
num_thread_y
=
util
.
get_const_int
(
Out
.
shape
[
3
]
)
num_thread_y
=
min
(
max_threads
//
num_thread_x
,
util
.
get_const_int
(
Out
.
shape
[
3
])
)
ofactor
=
8
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
((
0
,
num_thread_x
),
"threadIdx.x"
)
...
...
@@ -305,11 +313,13 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
thread_xz
=
tvm
.
thread_axis
((
0
,
vthread_x
),
"vthread"
,
name
=
"vx"
)
i
,
oc
,
h
,
w
=
s
[
Out
]
.
op
.
axis
ow
,
iw
=
s
[
Out
]
.
split
(
w
,
factor
=
num_thread_y
)
oh
,
ih
=
s
[
Out
]
.
split
(
h
,
factor
=
vthread_x
)
ooc
,
ioc
=
s
[
Out
]
.
split
(
oc
,
factor
=
num_thread_x
)
s
[
Out
]
.
reorder
(
i
,
ooc
,
h
,
w
,
ioc
)
s
[
Out
]
.
reorder
(
i
,
ooc
,
oh
,
ih
,
ow
,
i
w
,
ioc
)
s
[
Out
]
.
bind
(
ioc
,
thread_x
)
s
[
Out
]
.
bind
(
w
,
thread_y
)
s
[
Out
]
.
bind
(
h
,
thread_xz
)
s
[
Out
]
.
bind
(
i
w
,
thread_y
)
s
[
Out
]
.
bind
(
i
h
,
thread_xz
)
s
[
Out
]
.
bind
(
ooc
,
block_x
)
s
[
Out_L
]
.
compute_at
(
s
[
Out
],
ioc
)
...
...
@@ -323,7 +333,7 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
s
[
temp_S
]
.
compute_at
(
s
[
Out_L
],
oic
)
s
[
Filter_S
]
.
compute_at
(
s
[
Out_L
],
oic
)
num_thread
=
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num
_threads
num_thread
=
max
_threads
thread_xx
=
tvm
.
thread_axis
((
0
,
num_thread
),
"threadIdx.x"
)
block_xx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
...
...
topi/python/topi/intel_graphics/conv2d.py
View file @
8c5078c9
# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
, too-many-arguments, too-many-locals, too-many-statements, no-member, too-many-branches
"""conv2d schedule on Intel Graphics"""
from
__future__
import
absolute_import
as
_abs
...
...
@@ -57,7 +57,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return
sym
.
contrib
.
conv2d_NCHWc
(
*
copy_inputs
,
**
new_attrs
)
@conv2d_NCHWc.register
([
"intel_graphics"
])
def
_decl_conv2d
(
data
,
kernel
,
num_filter
,
kernel_size
,
stride
,
padding
,
out_dtype
=
'float32'
):
def
_decl_conv2d
(
data
,
kernel
,
num_filter
,
kernel_size
,
stride
,
padding
,
layout
,
\
out_layout
,
out_dtype
=
'float32'
):
"""Conv2D operator for Intel Graphics backend.
Parameters
...
...
@@ -96,7 +97,7 @@ def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, out_dty
return
_decl_cl_spatialpack_NCHWc
(
data
,
kernel
,
stride
,
padding
,
out_dtype
)
@generic.schedule_conv2d_NCHWc.register
([
"intel_graphics"
])
def
schedule_conv2d_NCHWc
(
num_filter
,
kernel_size
,
stride
,
padding
,
outs
):
def
schedule_conv2d_NCHWc
(
num_filter
,
kernel_size
,
stride
,
padding
,
layout
,
out_layout
,
outs
):
"""Schedule for conv2d_nchw for Intel Graphics
Parameters
...
...
topi/tests/python/test_topi_conv2d_nchw.py
View file @
8c5078c9
...
...
@@ -74,7 +74,7 @@ def test_conv2d_nchw():
verify_conv2d_nchw
(
1
,
256
,
14
,
512
,
3
,
2
,
1
)
verify_conv2d_nchw
(
1
,
256
,
14
,
512
,
1
,
2
,
0
)
verify_conv2d_nchw
(
1
,
512
,
7
,
512
,
3
,
1
,
1
)
# ResNet
50 workloads
# ResNet50 workloads
verify_conv2d_nchw
(
1
,
64
,
56
,
256
,
1
,
1
,
0
)
verify_conv2d_nchw
(
1
,
256
,
56
,
64
,
1
,
1
,
0
)
verify_conv2d_nchw
(
1
,
256
,
56
,
128
,
1
,
2
,
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment