Commit 5ae1a079
Authored Sep 18, 2017 by Xingjian Shi; committed Sep 17, 2017 by Tianqi Chen

[TOPI] add binary broadcast (#456)

* add binary broadcast
* fix testing
* revise testing threshold

Parent: dd029c83

Showing 5 changed files with 318 additions and 8 deletions:

* topi/python/topi/broadcast.py (+165, -4)
* topi/python/topi/cuda/__init__.py (+1, -1)
* topi/python/topi/cuda/broadcast.py (+34, -3)
* topi/recipe/broadcast/test_broadcast_map.py (+53, -0)
* topi/tests/python/test_topi_broadcast.py (+65, -0)

topi/python/topi/broadcast.py

@@ -2,6 +2,7 @@
 """Broadcast operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from .util import get_const_tuple, equal_const_int


 def _get_bcast_info(original_shape, target_shape):
     """Get the broadcasting info.
@@ -35,11 +36,9 @@ def _get_bcast_info(original_shape, target_shape):
     original_shape = original_shape[::-1]
     target_shape = target_shape[::-1]
     for i in range(len(original_shape)):
         if not isinstance(original_shape[i], tvm.expr.IntImm):
             raise ValueError("Element of original_shape tuple should be IntImm")
-        if tvm.ir_pass.Equal(tvm.convert(target_shape[i]), original_shape[i]):
+        if equal_const_int(original_shape[i], target_shape[i]):
             bcast_info[i] = 0
-        elif tvm.ir_pass.Equal(original_shape[i], tvm.convert(1)):
+        elif equal_const_int(original_shape[i], 1):
             bcast_info[i] = 1
         else:
             raise ValueError("Original Shape: {} cannot be broadcast to {}"
@@ -48,6 +47,38 @@ def _get_bcast_info(original_shape, target_shape):
     return bcast_info


+def _get_binary_op_bcast_shape(lhs_shape, rhs_shape):
+    """Get the shape after binary broadcasting.
+    We will strictly follow the broadcasting rule in numpy.
+
+    Parameters
+    ----------
+    lhs_shape : tuple
+    rhs_shape : tuple
+
+    Returns
+    -------
+    ret_shape : tuple
+    """
+    ret_shape = []
+    if len(lhs_shape) > len(rhs_shape):
+        lhs_shape, rhs_shape = rhs_shape, lhs_shape
+    for ptr in range(len(rhs_shape)):
+        if ptr < len(lhs_shape):
+            l_val, r_val = lhs_shape[len(lhs_shape) - 1 - ptr], \
+                rhs_shape[len(rhs_shape) - 1 - ptr]
+            assert (l_val == 1 or r_val == 1 or l_val == r_val), \
+                "Shape is NOT broadcastable, lhs=%s, rhs=%s" \
+                % (str(lhs_shape), str(rhs_shape))
+            ret_shape.append(max(l_val, r_val))
+        else:
+            ret_shape.append(rhs_shape[len(rhs_shape) - 1 - ptr])
+    ret_shape = ret_shape[::-1]
+    return ret_shape
+
+
 @tvm.tag_scope(tag="broadcast_to")
 def broadcast_to(data, shape):
     """Broadcast the src to the target shape
@@ -80,3 +111,133 @@ def broadcast_to(data, shape):
                                            bcast_info, *args),
                       name=data.name + "_broadcast")
     return ret
+
+
+@tvm.tag_scope(tag="broadcast_binary_op")
+def broadcast_binary_op(lhs, rhs, func, name="bop"):
+    """Binary operands that will automatically broadcast the inputs.
+    We follow the numpy broadcasting rule.
+    See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+    rhs : tvm.Tensor
+    func : function
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    def _inner_arg_eval(lhs, rhs, lhs_bcast_info, rhs_bcast_info, func, *args):
+        lhs_indices = []
+        rhs_indices = []
+        for i in range(len(args)):
+            if lhs_bcast_info[i] == 0:
+                lhs_indices.append(args[i])
+            elif lhs_bcast_info[i] == 1:
+                lhs_indices.append(0)
+            if rhs_bcast_info[i] == 0:
+                rhs_indices.append(args[i])
+            elif rhs_bcast_info[i] == 1:
+                rhs_indices.append(0)
+        return func(lhs[tuple(lhs_indices)], rhs[tuple(rhs_indices)])
+
+    ret_shape = _get_binary_op_bcast_shape(get_const_tuple(lhs.shape),
+                                           get_const_tuple(rhs.shape))
+    lhs_bcast_info = _get_bcast_info(original_shape=lhs.shape,
+                                     target_shape=ret_shape)
+    rhs_bcast_info = _get_bcast_info(original_shape=rhs.shape,
+                                     target_shape=ret_shape)
+    ret = tvm.compute([tvm.convert(ele) for ele in ret_shape],
+                      lambda *args: _inner_arg_eval(lhs, rhs, lhs_bcast_info,
+                                                    rhs_bcast_info, func, *args),
+                      name=lhs.name + "_" + rhs.name + "_" + name)
+    return ret
+
+
+def broadcast_add(lhs, rhs):
+    """Binary addition with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+    rhs : tvm.Tensor
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return broadcast_binary_op(lhs, rhs, lambda a, b: a + b, "add")
+
+
+def broadcast_mul(lhs, rhs):
+    """Binary multiplication with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+    rhs : tvm.Tensor
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return broadcast_binary_op(lhs, rhs, lambda a, b: a * b, "mul")
+
+
+def broadcast_div(lhs, rhs):
+    """Binary division with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+    rhs : tvm.Tensor
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return broadcast_binary_op(lhs, rhs, lambda a, b: a / b, "div")
+
+
+def broadcast_sub(lhs, rhs):
+    """Binary subtraction with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+    rhs : tvm.Tensor
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return broadcast_binary_op(lhs, rhs, lambda a, b: a - b, "sub")
+
+
+def broadcast_maximum(lhs, rhs):
+    """Take element-wise maximum of two tensors with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+    rhs : tvm.Tensor
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return broadcast_binary_op(lhs, rhs, tvm.max, "maximum")
+
+
+def broadcast_minimum(lhs, rhs):
+    """Take element-wise minimum of two tensors with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+    rhs : tvm.Tensor
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return broadcast_binary_op(lhs, rhs, tvm.min, "minimum")
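
As a quick illustration of the numpy rule that _get_binary_op_bcast_shape implements (a sketch for orientation only, not part of the diff; NumPy is used here just to cross-check the shape arithmetic):

import numpy as np

# Right-align the two shapes; a dimension of size 1 stretches to match the
# other side, and missing leading dimensions are treated as 1:
# (5, 64, 128) vs (2, 5, 64, 1)  ->  (2, 5, 64, 128)
assert np.broadcast(np.empty((5, 64, 128)),
                    np.empty((2, 5, 64, 1))).shape == (2, 5, 64, 128)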

topi/python/topi/cuda/__init__.py

@@ -8,6 +8,6 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
 from .reduction import schedule_reduce
-from .broadcast import schedule_broadcast_to
+from .broadcast import schedule_broadcast_to, schedule_broadcast_binary_op
 from .softmax import schedule_softmax
 from .elemwise import schedule_elemwise

topi/python/topi/cuda/broadcast.py

@@ -3,8 +3,7 @@
 from __future__ import absolute_import as _abs
 import tvm


-def _schedule_broadcast_to(op, sch):
-    data_in = op.input_tensors[0]
+def _schedule_broadcast(op, sch):
     data_out = op.output(0)
     num_thread = 512
@@ -47,7 +46,39 @@ def schedule_broadcast_to(outs):
             if tensor.op.input_tensors:
                 traverse(tensor.op)
         elif operator.tag == 'broadcast_to':
-            _schedule_broadcast_to(operator, sch)
+            _schedule_broadcast(operator, sch)
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
     traverse(outs[0].op)
     return sch


+def schedule_broadcast_binary_op(outs):
+    """Schedule for broadcast_binary ops + element-wise ops.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of broadcast_binary in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    sch = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(operator):
+        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
+            if operator not in sch.outputs:
+                sch[operator].compute_inline()
+            for tensor in operator.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        elif operator.tag == 'broadcast_binary_op':
+            _schedule_broadcast(operator, sch)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % operator.tag)
...
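
Taken together with broadcast.py, the intended flow is: build the broadcast compute, schedule it, then compile. A minimal sketch under this era's TVM 0.x API (the shapes are hypothetical; the calls mirror the tests below):

import tvm
import topi

A = tvm.placeholder(shape=(64, 1), name="A")
B = tvm.placeholder(shape=(1, 32), name="B")
C = topi.broadcast_add(A, B)                   # op tagged 'broadcast_binary_op'
s = topi.cuda.schedule_broadcast_binary_op(C)  # also inlines fused 'ewise' ops
fcuda = tvm.build(s, [A, B, C], "cuda", name="broadcast_binary_add")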

topi/recipe/broadcast/test_broadcast_map.py

@@ -51,7 +51,60 @@ def test_broadcast_to(in_shape, out_shape):
     np.testing.assert_allclose(out_nd.asnumpy(), out_npy)


+def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
+    global TASK
+    TASK = "bcast_binary_" + typ + "_lhs" + \
+           "_".join([str(ele) for ele in lhs_shape]) + \
+           "rhs" + "_".join([str(ele) for ele in rhs_shape])
+    A = tvm.placeholder(shape=lhs_shape, name="A")
+    B = tvm.placeholder(shape=rhs_shape, name="B")
+    if typ == "add":
+        C = topi.broadcast_add(A, B)
+    elif typ == "sub":
+        C = topi.broadcast_sub(A, B)
+    elif typ == "div":
+        C = topi.broadcast_div(A, B)
+    elif typ == "mul":
+        C = topi.broadcast_mul(A, B)
+    elif typ == "maximum":
+        C = topi.broadcast_maximum(A, B)
+    elif typ == "minimum":
+        C = topi.broadcast_minimum(A, B)
+    else:
+        raise NotImplementedError
+    s = topi.cuda.schedule_broadcast_binary_op(C)
+    fcuda = tvm.build(s, [A, B, C], "cuda", name="broadcast_binary" + "_" + typ)
+    lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
+    rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
+    if typ == "add":
+        out_npy = lhs_npy + rhs_npy
+    elif typ == "sub":
+        out_npy = lhs_npy - rhs_npy
+    elif typ == "div":
+        rhs_npy = np.abs(rhs_npy) + 0.001
+        out_npy = lhs_npy / rhs_npy
+    elif typ == "mul":
+        out_npy = lhs_npy * rhs_npy
+    elif typ == "maximum":
+        out_npy = np.maximum(lhs_npy, rhs_npy)
+    elif typ == "minimum":
+        out_npy = np.minimum(lhs_npy, rhs_npy)
+    lhs_nd = tvm.nd.array(lhs_npy, tvm.gpu())
+    rhs_nd = tvm.nd.array(rhs_npy, tvm.gpu())
+    out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), tvm.gpu())
+    for _ in range(2):
+        fcuda(lhs_nd, rhs_nd, out_nd)
+    np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+
 if __name__ == "__main__":
     test_broadcast_to((1,), (10,))
     test_broadcast_to((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
     test_broadcast_to((1, 128, 1, 32), (64, 128, 64, 32))
+    test_broadcast_binary_op((5, 2, 3), (2, 1), typ="add")
+    test_broadcast_binary_op((5, 64, 128), (2, 5, 64, 1), typ="mul")
+    test_broadcast_binary_op((2, 3, 1, 32), (64, 32), typ="div")
+    test_broadcast_binary_op((1, 32), (64, 32), typ="sub")
+    test_broadcast_binary_op((32,), (64, 32), typ="maximum")
+    test_broadcast_binary_op((1, 2, 2, 1, 32), (64, 32), typ="minimum")
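
For reference, the TASK string built above joins shape elements with "_" but puts no separator before "rhs"; a standalone rendering of that expression (hypothetical inputs):

typ, lhs_shape, rhs_shape = "add", (5, 2, 3), (2, 1)
task = "bcast_binary_" + typ + "_lhs" + \
       "_".join([str(ele) for ele in lhs_shape]) + \
       "rhs" + "_".join([str(ele) for ele in rhs_shape])
print(task)  # bcast_binary_add_lhs5_2_3rhs2_1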

topi/tests/python/test_topi_broadcast.py

@@ -29,10 +29,75 @@ def verify_broadcast_to_ele(in_shape, out_shape):
     check_device("metal")


+def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
+    # Build the logic and compile the function
+    A = tvm.placeholder(shape=lhs_shape, name="A")
+    B = tvm.placeholder(shape=rhs_shape, name="B")
+    if typ == "add":
+        C = topi.broadcast_add(A, B)
+    elif typ == "sub":
+        C = topi.broadcast_sub(A, B)
+    elif typ == "div":
+        C = topi.broadcast_div(A, B)
+    elif typ == "mul":
+        C = topi.broadcast_mul(A, B)
+    elif typ == "maximum":
+        C = topi.broadcast_maximum(A, B)
+    elif typ == "minimum":
+        C = topi.broadcast_minimum(A, B)
+    else:
+        raise NotImplementedError
+    s = topi.cuda.schedule_broadcast_binary_op(C)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
+        lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
+        rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
+        if typ == "add":
+            out_npy = lhs_npy + rhs_npy
+        elif typ == "sub":
+            out_npy = lhs_npy - rhs_npy
+        elif typ == "div":
+            rhs_npy = np.abs(rhs_npy) + 0.001
+            out_npy = lhs_npy / rhs_npy
+        elif typ == "mul":
+            out_npy = lhs_npy * rhs_npy
+        elif typ == "maximum":
+            out_npy = np.maximum(lhs_npy, rhs_npy)
+        elif typ == "minimum":
+            out_npy = np.minimum(lhs_npy, rhs_npy)
+        else:
+            raise NotImplementedError
+        lhs_nd = tvm.nd.array(lhs_npy, ctx)
+        rhs_nd = tvm.nd.array(rhs_npy, ctx)
+        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
+        for _ in range(1):
+            foo(lhs_nd, rhs_nd, out_nd)
+        np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
+
+    check_device("opencl")
+    check_device("cuda")
+    check_device("metal")
+
+
 def test_broadcast_to():
     verify_broadcast_to_ele((1,), (10,))
     verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
     verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32))


+def test_broadcast_binary():
+    verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add")
+    verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul")
+    verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div")
+    verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub")
+    verify_broadcast_binary_ele((32,), (64, 32), typ="maximum")
+    verify_broadcast_binary_ele((1, 2, 2, 1, 32), (64, 32), typ="minimum")
+
+
 if __name__ == "__main__":
     test_broadcast_to()
+    test_broadcast_binary()
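
On the "revise testing threshold" item in the commit message: the div case rewrites the denominator as np.abs(rhs_npy) + 0.001 because float32 division by near-zero values amplifies rounding error beyond any fixed tolerance, which would make assert_allclose(rtol=1E-4, atol=1E-4) flaky. A small numeric illustration (hypothetical values, not from the diff):

import numpy as np

x = np.float32(1.0)
d = np.float32(1e-7)
# A ~1e-7 perturbation of the denominator halves the quotient,
# a 50% relative error that no reasonable rtol absorbs.
print(x / d, x / (d + np.float32(1e-7)))    # ~1e7 vs ~5e6
# Shifting the denominator to abs(d) + 0.001 bounds the quotient instead.
print(x / (np.abs(d) + np.float32(0.001)))  # ~1e3, stable under rounding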