Project: wenyuanbo / tic
Commit 75d53777, authored Sep 24, 2017 by Yuwei HU, committed by Tianqi Chen Sep 23, 2017
add pool (#478)
Parent: f863bfdc
Showing 4 changed files with 168 additions and 45 deletions:

topi/python/topi/cuda/__init__.py       (+1, -1)
topi/python/topi/cuda/pooling.py        (+55, -1)
topi/python/topi/nn/pooling.py          (+61, -42)
topi/tests/python/test_topi_pooling.py  (+51, -1)
topi/python/topi/cuda/__init__.py

@@ -11,4 +11,4 @@ from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
 from .dense import schedule_dense
-from .pooling import schedule_global_pool
+from .pooling import schedule_pool, schedule_global_pool
topi/python/topi/cuda/pooling.py

@@ -56,7 +56,7 @@ def schedule_global_pool(outs):
                 if tensor.op.input_tensors:
                     traverse(tensor.op)
         # schedule global_pool
-        elif 'global_pool' in OP.tag:
+        elif OP.tag.startswith('global_pool'):
             Pool = OP.output(0)
             _schedule(Pool)
         else:

@@ -64,3 +64,57 @@ def schedule_global_pool(outs):
     traverse(outs[0].op)
     return s
+
+def schedule_pool(outs):
+    """Schedule for pool.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of pool
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for pool.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _schedule(PaddedInput, Pool):
+        s[PaddedInput].compute_inline()
+        num_thread = 512
+        if Pool.op in s.outputs:
+            Out = Pool
+            OL = s.cache_write(Pool, "local")
+        else:
+            Out = outs[0].op.output(0)
+            s[Pool].set_scope("local")
+        fused = s[Out].fuse(*s[Out].op.axis)
+        bx, tx = s[Out].split(fused, factor=num_thread)
+        s[Out].bind(bx, tvm.thread_axis("blockIdx.x"))
+        s[Out].bind(tx, tvm.thread_axis("threadIdx.x"))
+        if Pool.op in s.outputs:
+            s[OL].compute_at(s[Out], tx)
+        else:
+            s[Pool].compute_at(s[Out], tx)
+
+    def traverse(OP):
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule pool
+        elif OP.tag.startswith('pool'):
+            PaddedInput = OP.input_tensors[0]
+            Pool = OP.output(0)
+            _schedule(PaddedInput, Pool)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+    traverse(outs[0].op)
+    return s
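The heart of _schedule above is the standard TVM GPU pattern: inline the padding stage, fuse every output axis into one, split the fused axis by the thread count, and bind the two halves to the CUDA grid. Below is a minimal stand-alone sketch of that fuse/split/bind pattern on a toy elementwise op, assuming the 2017-era tvm API this commit targets; the names n, A, B and the doubling compute are illustrative, not part of the commit.

import tvm

# Minimal sketch (assumed pre-0.4 tvm API) of the fuse -> split -> bind
# pattern that _schedule applies to the pool output.
n = tvm.var("n")
A = tvm.placeholder((n, n), name="A")
B = tvm.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")

s = tvm.create_schedule(B.op)
fused = s[B].fuse(*s[B].op.axis)        # collapse both axes into one
bx, tx = s[B].split(fused, factor=512)  # same factor as num_thread above
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

f = tvm.build(s, [A, B], "cuda")        # lowers to a single 1-D CUDA kernel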
topi/python/topi/nn/pooling.py

@@ -6,84 +6,103 @@ from .util import get_pad_tuple
 from .. import util
 from .. import tag
 
-def max_pool(data, kernel, stride, padding):
-    """Perform max pooling on the data
+def global_pool(data, pool_type):
+    """Perform global pooling on the data
 
     Parameters
     ----------
     data : tvm.Tensor
         4-D with shape [batch, channel, in_height, in_width]
 
-    kernel : list/tuple of two ints
-        Kernel size, or [kernel_height, kernel_width]
-
-    stride : list/tuple of two ints
-        Stride size, or [stride_height, stride_width]
-
-    paddding : list/tuple of two ints
-        Pad size, or [pad_height, pad_width]
+    pool_type : str
+        Pool type, 'max' or 'avg'
 
     Returns
     -------
     output : tvm.Tensor
-        4-D with shape [batch, channel, out_height, out_width]
+        4-D with shape [batch, channel, 1, 1]
     """
     assert len(data.shape) == 4, "only support 4-dim pooling"
-    assert len(stride) == 2, "only support 2-dim stride"
-    kernel_height, kernel_width = kernel
-    stride_height, stride_width = stride
     batch, channel, height, width = data.shape
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_height, kernel_width))
-    pad_before = [0, 0, pad_top, pad_left]
-    pad_after = [0, 0, pad_down, pad_right]
-    temp = pad(data, pad_before, pad_after, name="pad_temp", pad_value=tvm.min_value("float32"))
-    out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
-    out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
-    dheight = tvm.reduce_axis((0, kernel_height))
-    dwidth = tvm.reduce_axis((0, kernel_width))
+    dheight = tvm.reduce_axis((0, height))
+    dwidth = tvm.reduce_axis((0, width))
 
-    return tvm.compute((batch, channel, out_height, out_width),
-        lambda i, c, h, w: tvm.max(temp[i, c, h*stride_height+dheight, w*stride_width+dwidth], axis=[dheight, dwidth]),
-        tag="max_pool")
+    if pool_type == 'max':
+        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+            tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
+            tag="global_pool_max")
+    elif pool_type == 'avg':
+        tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+            tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
+            tag="global_pool_sum")
+        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
+            tsum[n, c, h, w] / (height*width), \
+            tag=tag.ELEMWISE)
+    else:
+        raise ValueError("Pool type should be 'avg' or 'max'.")
 
-def global_pool(data, pool_type):
-    """Perform global pooling on the data
+def pool(data, kernel, stride, padding, pool_type):
+    """Perform pooling on the data
 
     Parameters
     ----------
     data : tvm.Tensor
         4-D with shape [batch, channel, in_height, in_width]
 
+    kernel : list/tuple of two ints
+        Kernel size, [kernel_height, kernel_width]
+
+    stride : list/tuple of two ints
+        Stride size, [stride_height, stride_width]
+
+    paddding : list/tuple of two ints
+        Pad size, [pad_height, pad_width]
+
     pool_type : str
         Pool type, 'max' or 'avg'
 
     Returns
     -------
     output : tvm.Tensor
-        4-D with shape [batch, channel, 1, 1]
+        4-D with shape [batch, channel, out_height, out_width]
    """
     assert len(data.shape) == 4, "only support 4-dim pooling"
+    assert len(stride) == 2, "only support 2-dim stride"
+    kernel_height, kernel_width = kernel
+    stride_height, stride_width = stride
     batch, channel, height, width = data.shape
-    dheight = tvm.reduce_axis((0, height))
-    dwidth = tvm.reduce_axis((0, width))
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_height, kernel_width))
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down, pad_right]
+    out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
+    out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
+    dheight = tvm.reduce_axis((0, kernel_height))
+    dwidth = tvm.reduce_axis((0, kernel_width))
 
     if pool_type == 'max':
-        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
-            tvm.max(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
-            tag="global_pool_max")
+        temp = pad(data, pad_before, pad_after, name="pad_temp", \
+            pad_value=tvm.min_value(data.dtype))
+        return tvm.compute((batch, channel, out_height, out_width), \
+            lambda n, c, h, w: \
+            tvm.max(temp[n, c, h*stride_height+dheight, w*stride_width+dwidth], \
+                axis=[dheight, dwidth]), \
+            tag="pool_max")
     elif pool_type == 'avg':
-        tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
-            tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
-            tag="global_pool_sum")
-        return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
-            tsum[n, c, h, w] / (height*width), \
-            tag=tag.ELEMWISE)
+        temp = pad(data, pad_before, pad_after, name="pad_temp", \
+            pad_value=tvm.const(0.).astype(data.dtype))
+        tsum = tvm.compute((batch, channel, out_height, out_width), \
+            lambda n, c, h, w: \
+            tvm.sum(temp[n, c, h*stride_height+dheight, w*stride_width+dwidth], \
+                axis=[dheight, dwidth]), \
+            tag="pool_avg")
+        return tvm.compute((batch, channel, out_height, out_width), \
+            lambda n, c, h, w: \
+            tsum[n, c, h, w] / (kernel_height*kernel_width), \
+            tag=tag.ELEMWISE)
     else:
         raise ValueError("Pool type should be 'avg' or 'max'.")
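Both pool and the removed max_pool compute the output extent with the usual pad-then-slide arithmetic. As a quick sanity check of that formula, here is a small Python helper (the name out_dim is ours, purely illustrative) evaluated on the two configurations the new tests below exercise:

def out_dim(size, kernel, stride, pad_before, pad_after):
    # mirrors util.simplify((size - kernel + pad_before + pad_after) // stride + 1)
    return (size - kernel + pad_before + pad_after) // stride + 1

print(out_dim(32, 2, 2, 0, 0))  # 16: a 32x32 input, 2x2 kernel, stride 2, no padding
print(out_dim(31, 3, 3, 1, 1))  # 11: a 31x31 input, 3x3 kernel, stride 3, padding 1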
topi/tests/python/test_topi_pooling.py

@@ -4,6 +4,55 @@ import tvm
 import topi
 from topi.util import get_const_tuple
 
+def verify_pool(n, ic, ih, kh, sh, padding, pool_type):
+    iw = ih
+    kw = kh
+    sw = sh
+    ph, pw = padding
+    A = tvm.placeholder((n, ic, ih, iw), name='A')
+    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type)
+    B = topi.nn.relu(B)
+    s = topi.cuda.schedule_pool(B)
+    dtype = A.dtype
+
+    a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
+    pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
+    no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
+    pad_np[np.ix_(*no_zero)] = a_np
+    _, oc, oh, ow = get_const_tuple(B.shape)
+    b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
+
+    if pool_type == 'avg':
+        for i in range(oh):
+            for j in range(ow):
+                b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
+    elif pool_type == 'max':
+        for i in range(oh):
+            for j in range(ow):
+                b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
+    b_np = np.maximum(b_np, 0.0)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
+        f = tvm.build(s, [A, B], device)
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+def test_pool():
+    verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg')
+    verify_pool(1, 256, 31, 3, 3, [1, 1], 'avg')
+    verify_pool(1, 256, 32, 2, 2, [0, 0], 'max')
+    verify_pool(1, 256, 31, 3, 3, [1, 1], 'max')
+
 def verify_global_pool(n, c, h, w, pool_type):
     A = tvm.placeholder((n, c, h, w), name='A')
     B = topi.nn.global_pool(A, pool_type=pool_type)

@@ -24,7 +73,7 @@ def verify_global_pool(n, c, h, w, pool_type):
     ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-    f = tvm.build(s, [A, B], device, name="global_avg_pool")
+    f = tvm.build(s, [A, B], device)
     f(a, b)
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

@@ -39,4 +88,5 @@ def test_global_pool():
 if __name__ == "__main__":
+    test_pool()
     test_global_pool()
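Taken together, the flow the new test exercises is: declare a placeholder, build the pool compute with topi.nn.pool, schedule it with topi.cuda.schedule_pool, then compile and run. A condensed sketch of that flow, assuming the same 2017-era TVM/TOPI API the test uses and a CUDA-enabled build (the 1x16x32x32 shape is arbitrary, not from the commit):

import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

# End-to-end sketch of the workflow added by this commit.
A = tvm.placeholder((1, 16, 32, 32), name='A')
B = topi.nn.pool(A, kernel=[2, 2], stride=[2, 2], padding=[0, 0], pool_type='max')
s = topi.cuda.schedule_pool(B)

if tvm.module.enabled("cuda"):
    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=(1, 16, 32, 32)).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
    f = tvm.build(s, [A, B], "cuda")
    f(a, b)  # b now holds the 1x16x16x16 max-pooled result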