Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2dec0510
Commit
2dec0510
authored
Jul 08, 2017
by
Yuwei HU
Committed by
Tianqi Chen
Jul 08, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TEST][TOPI] of depthwise_conv2d (#230)
* test of depthwise_conv2d * fix nose test error * python3 fix
parent
55ba9cb8
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
121 additions
and
9 deletions
+121
-9
topi/python/topi/cuda/depthwise_conv2d_map.py
+1
-1
topi/python/topi/nn/conv.py
+2
-2
topi/recipe/conv/depthwise_conv2d_map_test.py
+6
-6
topi/tests/python/test_topi_depthwise_conv2d_map.py
+112
-0
No files found.
topi/python/topi/cuda/depthwise_conv2d_map.py
View file @
2dec0510
...
...
@@ -4,7 +4,7 @@ import tvm
from
..nn.util
import
get_const_tuple
def
schedule_depthwise_conv2d_map
(
op
):
"""Schedule for depthwise_conv2d
map ops.
"""Schedule for depthwise_conv2d map ops.
This include scale-shift and relu.
...
...
topi/python/topi/nn/conv.py
View file @
2dec0510
...
...
@@ -43,8 +43,8 @@ def depthwise_conv2d(Input, Filter, Stride, padding):
# calculate output shape
if
padding
==
'VALID'
:
out_channel
=
in_channel
*
channel_multiplier
out_height
=
(
in_height
-
filter_height
)
/
stride_h
+
1
out_width
=
(
in_width
-
filter_width
)
/
stride_w
+
1
out_height
=
(
in_height
-
filter_height
)
/
/
stride_h
+
1
out_width
=
(
in_width
-
filter_width
)
/
/
stride_w
+
1
pad_along_height
=
0
pad_along_width
=
0
if
padding
==
'SAME'
:
...
...
topi/recipe/conv/depthwise_conv2d_map_test.py
View file @
2dec0510
...
...
@@ -78,15 +78,15 @@ def test_depthwise_conv2d_map():
index_w
=
pad_left_scipy
-
pad_left_tvm
for
i
in
range
(
batch
):
for
j
in
range
(
out_channel
):
depthwise_conv2d_scipy
[
i
,
j
,:,:]
=
signal
.
convolve2d
(
input_np
[
i
,
j
//
channel_multiplier
,:,:],
np
.
rot90
(
filter_np
[
j
//
channel_multiplier
,
j
%
channel_multiplier
,:,:],
2
),
mode
=
'same'
)[
index_h
:
in_height
:
stride_h
,
index_w
:
in_width
:
stride_w
]
depthwise_conv2d_scipy
[
i
,
j
,:,:]
=
signal
.
convolve2d
(
input_np
[
i
,
j
//
channel_multiplier
,:,:],
np
.
rot90
(
filter_np
[
j
//
channel_multiplier
,
j
%
channel_multiplier
,:,:],
2
),
mode
=
'same'
)[
index_h
:
in_height
:
stride_h
,
index_w
:
in_width
:
stride_w
]
if
padding
==
'VALID'
:
for
i
in
range
(
batch
):
for
j
in
range
(
out_channel
):
depthwise_conv2d_scipy
[
i
,
j
,:,:]
=
signal
.
convolve2d
(
input_np
[
i
,
j
//
channel_multiplier
,:,:],
np
.
rot90
(
filter_np
[
j
//
channel_multiplier
,
j
%
channel_multiplier
,:,:],
2
),
mode
=
'valid'
)[
0
:(
in_height
-
filter_height
+
1
):
stride_h
,
0
:(
in_width
-
filter_height
+
1
):
stride_w
]
depthwise_conv2d_scipy
[
i
,
j
,:,:]
=
signal
.
convolve2d
(
input_np
[
i
,
j
//
channel_multiplier
,:,:],
np
.
rot90
(
filter_np
[
j
//
channel_multiplier
,
j
%
channel_multiplier
,:,:],
2
),
mode
=
'valid'
)[
0
:(
in_height
-
filter_height
+
1
):
stride_h
,
0
:(
in_width
-
filter_height
+
1
):
stride_w
]
for
c
in
range
(
out_channel
):
scale_shift_scipy
[:,
c
,:,:]
=
depthwise_conv2d_scipy
[:,
c
,:,:]
*
scale_np
[
c
]
+
shift_np
[
c
]
relu_scipy
[:,:,:,:]
=
np
.
maximum
(
scale_shift_scipy
[:,:,:,:],
0
)
...
...
topi/tests/python/test_topi_depthwise_conv2d_map.py
0 → 100644
View file @
2dec0510
import
tvm
import
topi
import
numpy
as
np
from
scipy
import
signal
from
topi.nn.util
import
get_const_tuple
from
topi.cuda.depthwise_conv2d_map
import
schedule_depthwise_conv2d_map
def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier,
                                       filter_height, stride_h, padding):
    """Check depthwise_conv2d, scale_shift and relu against a scipy reference.

    Three kernels are built and compared: depthwise_conv2d alone,
    depthwise_conv2d + scale_shift, and depthwise_conv2d + scale_shift + relu.
    Input and filter are square: the width, filter width and horizontal stride
    are taken from their vertical counterparts.

    Parameters
    ----------
    batch : int
        Batch size (first axis of the input placeholder).
    in_channel : int
        Number of input channels (second axis of the input placeholder).
    in_height : int
        Input height; also used as the input width.
    channel_multiplier : int
        Second axis of the filter placeholder; Scale and Shift have
        ``in_channel * channel_multiplier`` entries.
    filter_height : int
        Filter height; also used as the filter width.
    stride_h : int
        Vertical stride; also used as the horizontal stride.
    padding : str
        'SAME' or 'VALID'; forwarded to ``topi.nn.depthwise_conv2d`` and used to
        pick the matching scipy reference computation.
    """
    in_width = in_height
    filter_channel = in_channel
    filter_width = filter_height
    stride_w = stride_h

    # placeholder
    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
    Filter = tvm.placeholder(
        (filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
    Stride = tvm.nd.array(np.array([stride_h, stride_w]))
    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')

    # declare
    DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding)
    ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
    Relu = topi.nn.relu(ScaleShift)

    # schedule
    s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
    s2 = schedule_depthwise_conv2d_map(ScaleShift.op)
    s3 = schedule_depthwise_conv2d_map(Relu.op)

    def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
        """Compute the three reference outputs with scipy/numpy.

        Returns the tuple (depthwise_conv2d, scale_shift, relu), each of shape
        (batch, out_channel, out_height, out_width).
        """
        out_shape = get_const_tuple(DepthwiseConv2d.shape)
        out_channel = out_shape[1]
        out_height = out_shape[2]
        out_width = out_shape[3]
        full_shape = (batch, out_channel, out_height, out_width)
        depthwise_conv2d_scipy = np.zeros(full_shape, dtype=DepthwiseConv2d.dtype)
        scale_shift_scipy = np.zeros(full_shape, dtype=ScaleShift.dtype)
        relu_scipy = np.zeros(full_shape, dtype=Relu.dtype)

        if padding == 'SAME':
            # Total padding is clamped at zero with the builtin max();
            # np.max(x, 0) would treat the 0 as an axis and never clamp.
            pad_along_height = max(
                (out_height - 1) * stride_h + filter_height - in_height, 0)
            pad_along_width = max(
                (out_width - 1) * stride_w + filter_width - in_width, 0)
            pad_top_tvm = int(np.ceil(float(pad_along_height) / 2))
            pad_left_tvm = int(np.ceil(float(pad_along_width) / 2))
            pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2))
            pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2))
            # Offset between scipy's 'same' output and the padding computed above.
            index_h = pad_top_scipy - pad_top_tvm
            index_w = pad_left_scipy - pad_left_tvm
            for i in range(batch):
                for j in range(out_channel):
                    # Rotating the kernel by 180 degrees makes convolve2d
                    # compute a cross-correlation.
                    kernel = np.rot90(
                        filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2)
                    depthwise_conv2d_scipy[i, j, :, :] = signal.convolve2d(
                        input_np[i, j // channel_multiplier, :, :], kernel, mode='same'
                    )[index_h:in_height:stride_h, index_w:in_width:stride_w]

        if padding == 'VALID':
            for i in range(batch):
                for j in range(out_channel):
                    kernel = np.rot90(
                        filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2)
                    # The column extent uses filter_width (not filter_height).
                    depthwise_conv2d_scipy[i, j, :, :] = signal.convolve2d(
                        input_np[i, j // channel_multiplier, :, :], kernel, mode='valid'
                    )[0:(in_height - filter_height + 1):stride_h,
                      0:(in_width - filter_width + 1):stride_w]

        for c in range(out_channel):
            scale_shift_scipy[:, c, :, :] = \
                depthwise_conv2d_scipy[:, c, :, :] * scale_np[c] + shift_np[c]
        relu_scipy[:, :, :, :] = np.maximum(scale_shift_scipy[:, :, :, :], 0)
        return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy

    def check_device(device):
        """Build and run the three kernels on `device`, then compare with scipy."""
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.context(device, 0)

        # build the kernels
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)

        # prepare data
        input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
        filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
        shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
        scale_tvm = tvm.nd.array(scale_np, ctx)
        shift_tvm = tvm.nd.array(shift_np, ctx)
        depthwise_conv2d_tvm = tvm.nd.array(
            np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),
                     dtype=DepthwiseConv2d.dtype), ctx)
        scale_shift_tvm = tvm.nd.array(
            np.zeros(shape=get_const_tuple(ScaleShift.shape),
                     dtype=ScaleShift.dtype), ctx)
        relu_tvm = tvm.nd.array(
            np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)

        # The kernels are launched through time_evaluator; the measured cost
        # is not inspected, only the output buffers are.
        # launch kernel 1 (depthwise_conv2d)
        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
        timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
        # launch kernel 2 (depthwise_conv2d + scale_shift)
        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
        timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm)
        # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
        timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm)

        # correctness with scipy
        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(
            input_np, filter_np, scale_np, shift_np)
        np.testing.assert_allclose(
            depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        np.testing.assert_allclose(
            scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
        np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)

    check_device("opencl")
    check_device("cuda")
    check_device("metal")
def test_depthwise_conv2d_map():
    """Run the depthwise_conv2d map check over every workload and padding mode."""
    # (batch, in_channel, in_height, channel_multiplier, filter_height, stride_h)
    workloads = (
        (1, 728, 64, 1, 3, 1),
        (1, 728, 32, 1, 3, 1),
        (4, 256, 64, 2, 5, 2),
        (4, 256, 32, 2, 5, 2),
    )
    # All "SAME" cases run first, then all "VALID" cases.
    for padding_mode in ("SAME", "VALID"):
        for workload in workloads:
            depthwise_conv2d_map_with_workload(*(workload + (padding_mode,)))
# Allow running this test file directly as a script, without a test runner.
if __name__ == "__main__":
    test_depthwise_conv2d_map()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment