Commit 5ae1a079 by Xingjian Shi Committed by Tianqi Chen

[TOPI] add binary broadcast (#456)

* add binary broadcast

* fix testing

* revise testing threshold
parent dd029c83
......@@ -2,6 +2,7 @@
"""Broadcast operators"""
from __future__ import absolute_import as _abs
import tvm
from .util import get_const_tuple, equal_const_int
def _get_bcast_info(original_shape, target_shape):
"""Get the broadcasting info.
......@@ -35,11 +36,9 @@ def _get_bcast_info(original_shape, target_shape):
original_shape = original_shape[::-1]
target_shape = target_shape[::-1]
for i in range(len(original_shape)):
if not isinstance(original_shape[i], tvm.expr.IntImm):
raise ValueError("Element of original_shape tuple should be IntImm")
if tvm.ir_pass.Equal(tvm.convert(target_shape[i]), original_shape[i]):
if equal_const_int(original_shape[i], target_shape[i]):
bcast_info[i] = 0
elif tvm.ir_pass.Equal(original_shape[i], tvm.convert(1)):
elif equal_const_int(original_shape[i], 1):
bcast_info[i] = 1
else:
raise ValueError("Original Shape: {} cannot be broadcast to {}"
......@@ -48,6 +47,38 @@ def _get_bcast_info(original_shape, target_shape):
return bcast_info
def _get_binary_op_bcast_shape(lhs_shape, rhs_shape):
"""Get the shape after binary broadcasting.
We will strictly follow the broadcasting rule in numpy.
Parameters
----------
lhs_shape : tuple
rhs_shape : tuple
Returns
-------
ret_shape : tuple
"""
ret_shape = []
if len(lhs_shape) > len(rhs_shape):
lhs_shape, rhs_shape = rhs_shape, lhs_shape
for ptr in range(len(rhs_shape)):
if ptr < len(lhs_shape):
l_val, r_val = lhs_shape[len(lhs_shape) - 1 - ptr], \
rhs_shape[len(rhs_shape) - 1 - ptr]
assert(l_val == 1 or r_val == 1 or l_val == r_val),\
"Shape is NOT broadcastable, lhs=%s, rhs=%s"\
%(str(lhs_shape), str(rhs_shape))
ret_shape.append(max(l_val, r_val))
else:
ret_shape.append(rhs_shape[len(rhs_shape) - 1 - ptr])
ret_shape = ret_shape[::-1]
return ret_shape
@tvm.tag_scope(tag="broadcast_to")
def broadcast_to(data, shape):
"""Broadcast the src to the target shape
......@@ -80,3 +111,133 @@ def broadcast_to(data, shape):
bcast_info,
*args), name=data.name + "_broadcast")
return ret
@tvm.tag_scope(tag="broadcast_binary_op")
def broadcast_binary_op(lhs, rhs, func, name="bop"):
    """Apply a binary function to two tensors with automatic broadcasting.

    The numpy broadcasting rule is followed.
    See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html

    Parameters
    ----------
    lhs : tvm.Tensor
        Left operand.
    rhs : tvm.Tensor
        Right operand.
    func : function
        Called as ``func(lhs_element, rhs_element)`` to produce each output
        element.
    name : str, optional
        Suffix of the output tensor name, which is built as
        ``lhs.name + "_" + rhs.name + "_" + name``.

    Returns
    -------
    ret : tvm.Tensor
        Result of ``tvm.compute`` over the broadcast shape.
    """
    out_shape = _get_binary_op_bcast_shape(get_const_tuple(lhs.shape),
                                           get_const_tuple(rhs.shape))
    lhs_info = _get_bcast_info(original_shape=lhs.shape, target_shape=out_shape)
    rhs_info = _get_bcast_info(original_shape=rhs.shape, target_shape=out_shape)

    def _select_indices(bcast_info, out_indices):
        """Map output indices to the indices used to read one operand."""
        picked = []
        for pos, out_idx in enumerate(out_indices):
            flag = bcast_info[pos]
            if flag == 0:
                # Axis is not broadcast: read at the output index.
                picked.append(out_idx)
            elif flag == 1:
                # Axis is broadcast from extent 1: always read element 0.
                picked.append(0)
            # Any other flag contributes no index for this operand.
        return tuple(picked)

    def _compute(*out_indices):
        lhs_elem = lhs[_select_indices(lhs_info, out_indices)]
        rhs_elem = rhs[_select_indices(rhs_info, out_indices)]
        return func(lhs_elem, rhs_elem)

    return tvm.compute([tvm.convert(dim) for dim in out_shape],
                       _compute,
                       name=lhs.name + "_" + rhs.name + "_" + name)
def broadcast_add(lhs, rhs):
    """Add two tensors element-wise with automatic broadcasting.

    Parameters
    ----------
    lhs : tvm.Tensor
        Left operand.
    rhs : tvm.Tensor
        Right operand.

    Returns
    -------
    ret : tvm.Tensor
        Whatever ``broadcast_binary_op`` returns for ``a + b``.
    """
    def _add(a, b):
        return a + b

    return broadcast_binary_op(lhs, rhs, _add, "add")
def broadcast_mul(lhs, rhs):
    """Multiply two tensors element-wise with automatic broadcasting.

    Parameters
    ----------
    lhs : tvm.Tensor
        Left operand.
    rhs : tvm.Tensor
        Right operand.

    Returns
    -------
    ret : tvm.Tensor
        Whatever ``broadcast_binary_op`` returns for ``a * b``.
    """
    def _mul(a, b):
        return a * b

    return broadcast_binary_op(lhs, rhs, _mul, "mul")
def broadcast_div(lhs, rhs):
    """Divide two tensors element-wise with automatic broadcasting.

    Parameters
    ----------
    lhs : tvm.Tensor
        Dividend.
    rhs : tvm.Tensor
        Divisor.

    Returns
    -------
    ret : tvm.Tensor
        Whatever ``broadcast_binary_op`` returns for ``a / b``.
    """
    def _div(a, b):
        return a / b

    return broadcast_binary_op(lhs, rhs, _div, "div")
def broadcast_sub(lhs, rhs):
    """Subtract two tensors element-wise with automatic broadcasting.

    Parameters
    ----------
    lhs : tvm.Tensor
        Minuend.
    rhs : tvm.Tensor
        Subtrahend.

    Returns
    -------
    ret : tvm.Tensor
        Whatever ``broadcast_binary_op`` returns for ``a - b``.
    """
    def _sub(a, b):
        return a - b

    return broadcast_binary_op(lhs, rhs, _sub, "sub")
def broadcast_maximum(lhs, rhs):
    """Element-wise maximum of two tensors with automatic broadcasting.

    Parameters
    ----------
    lhs : tvm.Tensor
        Left operand.
    rhs : tvm.Tensor
        Right operand.

    Returns
    -------
    ret : tvm.Tensor
        Whatever ``broadcast_binary_op`` returns when combining elements
        with ``tvm.max``.
    """
    return broadcast_binary_op(lhs, rhs, func=tvm.max, name="maximum")
def broadcast_minimum(lhs, rhs):
    """Element-wise minimum of two tensors with automatic broadcasting.

    Parameters
    ----------
    lhs : tvm.Tensor
        Left operand.
    rhs : tvm.Tensor
        Right operand.

    Returns
    -------
    ret : tvm.Tensor
        Whatever ``broadcast_binary_op`` returns when combining elements
        with ``tvm.min``.
    """
    return broadcast_binary_op(lhs, rhs, func=tvm.min, name="minimum")
......@@ -8,6 +8,6 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
from .broadcast import schedule_broadcast_to, schedule_broadcast_binary_op
from .softmax import schedule_softmax
from .elemwise import schedule_elemwise
......@@ -3,8 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
def _schedule_broadcast_to(op, sch):
data_in = op.input_tensors[0]
def _schedule_broadcast(op, sch):
data_out = op.output(0)
num_thread = 512
......@@ -47,7 +46,39 @@ def schedule_broadcast_to(outs):
if tensor.op.input_tensors:
traverse(tensor.op)
elif operator.tag == 'broadcast_to':
_schedule_broadcast_to(operator, sch)
_schedule_broadcast(operator, sch)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
traverse(outs[0].op)
return sch
def schedule_broadcast_binary_op(outs):
"""Schedule for broadcast_binary ops + element-wise ops.
Parameters
----------
outs: Array of Tensor
The computation graph description of broadcast_binary in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs])
def traverse(operator):
if operator.tag == 'ewise' or operator.tag == 'scale_shift':
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif operator.tag == 'broadcast_binary_op':
_schedule_broadcast(operator, sch)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
......
......@@ -51,7 +51,60 @@ def test_broadcast_to(in_shape, out_shape):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
    """Build a broadcast binary op for CUDA and compare it against numpy.

    Parameters
    ----------
    lhs_shape : tuple of int
        Shape of the left placeholder.
    rhs_shape : tuple of int
        Shape of the right placeholder.
    typ : str
        One of "add", "sub", "div", "mul", "maximum", "minimum".

    Raises
    ------
    NotImplementedError
        If `typ` is not one of the supported operation names.
    """
    # NOTE(review): TASK is a module-level name whose consumer is not visible
    # in this block -- presumably it labels the generated kernel; confirm.
    global TASK
    TASK = "bcast_binary_" + typ + "_lhs" +\
           "_".join([str(ele) for ele in lhs_shape]) +\
           "rhs" + "_".join([str(ele) for ele in rhs_shape])
    A = tvm.placeholder(shape=lhs_shape, name="A")
    B = tvm.placeholder(shape=rhs_shape, name="B")
    # One table pairs each TOPI operator with its numpy reference so the two
    # can never drift apart (the original kept two parallel if/elif chains).
    ops = {
        "add": (topi.broadcast_add, np.add),
        "sub": (topi.broadcast_sub, np.subtract),
        "div": (topi.broadcast_div, np.divide),
        "mul": (topi.broadcast_mul, np.multiply),
        "maximum": (topi.broadcast_maximum, np.maximum),
        "minimum": (topi.broadcast_minimum, np.minimum),
    }
    if typ not in ops:
        raise NotImplementedError
    topi_op, numpy_op = ops[typ]
    C = topi_op(A, B)
    s = topi.cuda.schedule_broadcast_binary_op(C)
    fcuda = tvm.build(s, [A, B, C], "cuda", name="broadcast_binary" + "_" + typ)

    lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
    # Each input is generated with the dtype of its own placeholder.
    rhs_npy = np.random.uniform(size=rhs_shape).astype(B.dtype)
    if typ == "div":
        # Keep the divisor away from zero so the quotient stays well-behaved.
        rhs_npy = np.abs(rhs_npy) + 0.001
    out_npy = numpy_op(lhs_npy, rhs_npy)

    lhs_nd = tvm.nd.array(lhs_npy, tvm.gpu())
    rhs_nd = tvm.nd.array(rhs_npy, tvm.gpu())
    # The output buffer must have the dtype of the computed tensor C.
    out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), tvm.gpu())
    for _ in range(2):
        fcuda(lhs_nd, rhs_nd, out_nd)
    # Device float arithmetic need not be bit-identical to numpy's, so compare
    # with the same explicit tolerance as verify_broadcast_binary_ele instead
    # of assert_allclose's default rtol=1e-7.
    np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
if __name__ == "__main__":
    # Plain broadcast_to cases: (input shape, target shape).
    for in_shape, out_shape in [
            ((1,), (10,)),
            ((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)),
            ((1, 128, 1, 32), (64, 128, 64, 32)),
    ]:
        test_broadcast_to(in_shape, out_shape)

    # Binary broadcast cases: (lhs shape, rhs shape, operation).
    for lhs_shape, rhs_shape, op_name in [
            ((5, 2, 3), (2, 1), "add"),
            ((5, 64, 128), (2, 5, 64, 1), "mul"),
            ((2, 3, 1, 32), (64, 32), "div"),
            ((1, 32), (64, 32), "sub"),
            ((32,), (64, 32), "maximum"),
            ((1, 2, 2, 1, 32), (64, 32), "minimum"),
    ]:
        test_broadcast_binary_op(lhs_shape, rhs_shape, typ=op_name)
......@@ -29,10 +29,75 @@ def verify_broadcast_to_ele(in_shape, out_shape):
check_device("metal")
def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
    """Check a broadcast binary op against numpy on each enabled device.

    Parameters
    ----------
    lhs_shape : tuple of int
        Shape of the left placeholder.
    rhs_shape : tuple of int
        Shape of the right placeholder.
    typ : str
        One of "add", "sub", "div", "mul", "maximum", "minimum".

    Raises
    ------
    NotImplementedError
        If `typ` is not one of the supported operation names.
    """
    # Build the logic and compile the function
    A = tvm.placeholder(shape=lhs_shape, name="A")
    B = tvm.placeholder(shape=rhs_shape, name="B")
    # One table pairs each TOPI operator with its numpy reference so the two
    # can never drift apart (the original kept two parallel if/elif chains).
    ops = {
        "add": (topi.broadcast_add, np.add),
        "sub": (topi.broadcast_sub, np.subtract),
        "div": (topi.broadcast_div, np.divide),
        "mul": (topi.broadcast_mul, np.multiply),
        "maximum": (topi.broadcast_maximum, np.maximum),
        "minimum": (topi.broadcast_minimum, np.minimum),
    }
    if typ not in ops:
        raise NotImplementedError
    topi_op, numpy_op = ops[typ]
    C = topi_op(A, B)
    s = topi.cuda.schedule_broadcast_binary_op(C)

    def check_device(device):
        """Build for `device`, run once and compare with the numpy result."""
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        # Look the context up by device name.  The previous expression fell
        # back to tvm.cl(0) for every non-cuda device, so "metal" was run
        # with an OpenCL context.
        ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
        lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
        # Each input is generated with the dtype of its own placeholder.
        rhs_npy = np.random.uniform(size=rhs_shape).astype(B.dtype)
        if typ == "div":
            # Keep the divisor away from zero so the quotient stays
            # well-behaved.
            rhs_npy = np.abs(rhs_npy) + 0.001
        out_npy = numpy_op(lhs_npy, rhs_npy)
        lhs_nd = tvm.nd.array(lhs_npy, ctx)
        rhs_nd = tvm.nd.array(rhs_npy, ctx)
        # The output buffer must have the dtype of the computed tensor C.
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
        foo(lhs_nd, rhs_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

    check_device("opencl")
    check_device("cuda")
    check_device("metal")
def test_broadcast_to():
    """Run verify_broadcast_to_ele over a fixed set of shape pairs."""
    # Each entry is (input shape, target shape).
    shape_pairs = [
        ((1,), (10,)),
        ((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)),
        ((1, 128, 1, 32), (64, 128, 64, 32)),
    ]
    for in_shape, out_shape in shape_pairs:
        verify_broadcast_to_ele(in_shape, out_shape)
def test_broadcast_binary():
    """Run verify_broadcast_binary_ele once for every supported operation."""
    # Each entry is (lhs shape, rhs shape, operation name).
    cases = [
        ((5, 2, 3), (2, 1), "add"),
        ((5, 64, 128), (2, 5, 64, 1), "mul"),
        ((2, 3, 1, 32), (64, 32), "div"),
        ((1, 32), (64, 32), "sub"),
        ((32,), (64, 32), "maximum"),
        ((1, 2, 2, 1, 32), (64, 32), "minimum"),
    ]
    for lhs_shape, rhs_shape, op_name in cases:
        verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ=op_name)
if __name__ == "__main__":
    # Run both test groups in order when the file is executed as a script.
    for run_group in (test_broadcast_to, test_broadcast_binary):
        run_group()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment