Commit 51f8327f by Xingjian Shi, committed by Tianqi Chen

Use ewise schedule for broadcasting (#460)

parent d6007a24
...
@@ -8,6 +8,6 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
 from .reduction import schedule_reduce
-from .broadcast import schedule_broadcast_to, schedule_broadcast_binary_op
+from .broadcast import schedule_broadcast
 from .softmax import schedule_softmax
 from .elemwise import schedule_elemwise
...
@@ -3,27 +3,11 @@
 from __future__ import absolute_import as _abs
 import tvm
-def _schedule_broadcast(op, sch):
-    data_out = op.output(0)
-    num_thread = 512
-    block_x = tvm.thread_axis("blockIdx.x")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-    xo, vi = sch[data_out].split(sch[data_out].op.axis[len(sch[data_out].op.axis) - 1],
-                                 factor=4)
-    sch[data_out].vectorize(vi)
-    fused_axis = sch[data_out].fuse(*[sch[data_out].op.axis[i]
-                                      for i in range(len(sch[data_out].op.axis) - 1)] + [xo])
-    bx, tx = sch[data_out].split(fused_axis, factor=num_thread)
-    sch[data_out].bind(bx, block_x)
-    sch[data_out].bind(tx, thread_x)
-    return sch
-def schedule_broadcast_to(outs):
-    """Schedule for broadcast_to ops + element-wise ops.
+from .elemwise import _schedule_elemwise
+def schedule_broadcast(outs):
+    """Schedule for broadcasting ops (broadcast_to + broadcast binary) + element-wise ops.
     Parameters
     ----------
...
@@ -45,40 +29,8 @@ def schedule_broadcast_to(outs):
             for tensor in operator.input_tensors:
                 if tensor.op.input_tensors:
                     traverse(tensor.op)
-        elif operator.tag == 'broadcast_to':
-            _schedule_broadcast(operator, sch)
-        else:
-            raise RuntimeError("Unsupported operator: %s" % operator.tag)
-    traverse(outs[0].op)
-    return sch
-def schedule_broadcast_binary_op(outs):
-    """Schedule for broadcast_binary ops + element-wise ops.
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of broadcast_binary in the format
-        of an array of tensors.
-    Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    sch = tvm.create_schedule([x.op for x in outs])
-    def traverse(operator):
-        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
-            if operator not in sch.outputs:
-                sch[operator].compute_inline()
-            for tensor in operator.input_tensors:
-                if tensor.op.input_tensors:
-                    traverse(tensor.op)
-        elif operator.tag == 'broadcast_binary_op':
-            _schedule_broadcast(operator, sch)
+        elif operator.tag == 'broadcast_to' or operator.tag == 'broadcast_binary_op':
+            _schedule_elemwise(operator, sch)
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
...
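For context, a minimal usage sketch of the unified entry point, mirroring the tests further down (the shapes are arbitrary, and topi.broadcast_add is assumed to carry the 'broadcast_binary_op' tag handled above):

    import tvm
    import topi

    # Hypothetical shapes, chosen only to exercise broadcasting.
    A = tvm.placeholder((5, 2, 3), name="A")
    B = tvm.placeholder((2, 1), name="B")
    C = topi.broadcast_add(A, B)                  # broadcast binary op
    s = topi.cuda.schedule_broadcast(C)           # one schedule call now covers both broadcast kinds
    fcuda = tvm.build(s, [A, B, C], "cuda", name="broadcast_add")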
...
@@ -2,6 +2,17 @@
 """Schedule for element wise operator"""
 import tvm
+def _schedule_elemwise(op, sch):
+    x = op.output(0)
+    fused = sch[x].fuse(*sch[x].op.axis)
+    num_thread = 512
+    bx, tx = sch[x].split(fused, factor=num_thread)
+    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
+    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+    return sch
 def schedule_elemwise(outs):
     """Schedule for element wise op.
...
@@ -20,12 +31,4 @@ def schedule_elemwise(outs):
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
-    x = outs[0]
-    fused = s[x].fuse(*x.op.axis)
-    num_thread = 64
-    bx, tx = s[x].split(fused, factor=num_thread)
-    s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-    s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
-    return s
+    return _schedule_elemwise(outs[0].op, s)
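The shared _schedule_elemwise helper is a plain fuse/split/bind schedule. A hand-written sketch of the equivalent schedule for a broadcast output (shapes arbitrary; the 512-thread split factor mirrors the helper above):

    import tvm
    import topi

    A = tvm.placeholder((64, 128), name="A")
    B = topi.broadcast_to(A, (4, 64, 128))
    sch = tvm.create_schedule(B.op)

    num_thread = 512
    fused = sch[B].fuse(*sch[B].op.axis)              # collapse all output axes into one loop
    bx, tx = sch[B].split(fused, factor=num_thread)   # outer chunks to CUDA blocks, inner to threads
    sch[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[B].bind(tx, tvm.thread_axis("threadIdx.x"))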
...
@@ -38,7 +38,7 @@ def test_broadcast_to(in_shape, out_shape):
     # Build the logic and compile the function
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.broadcast_to(A, out_shape)
-    s = topi.cuda.schedule_broadcast_to(B)
+    s = topi.cuda.schedule_broadcast(B)
     fcuda = tvm.build(s, [A, B], "cuda", name="broadcast_to")
     data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
...
@@ -72,7 +72,7 @@ def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
         C = topi.broadcast_minimum(A, B)
     else:
         raise NotImplementedError
-    s = topi.cuda.schedule_broadcast_binary_op(C)
+    s = topi.cuda.schedule_broadcast(C)
     fcuda = tvm.build(s, [A, B, C], "cuda", name="broadcast_binary" + "_" + typ)
     lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
...
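A hedged sketch of how the compiled broadcast-binary kernel above could be checked against numpy (lhs_shape, rhs_shape, A, B, C and fcuda come from the surrounding test; the exact check in the truncated part of the test may differ):

    ctx = tvm.gpu(0)
    lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
    rhs_npy = np.random.uniform(size=rhs_shape).astype(B.dtype)
    out_npy = lhs_npy + rhs_npy                   # numpy broadcasting as the reference for typ="add"
    lhs_nd = tvm.nd.array(lhs_npy, ctx)
    rhs_nd = tvm.nd.array(rhs_npy, ctx)
    out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
    fcuda(lhs_nd, rhs_nd, out_nd)
    np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1e-5)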
...
@@ -8,7 +8,7 @@ def verify_broadcast_to_ele(in_shape, out_shape):
     # Build the logic and compile the function
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.broadcast_to(A, out_shape)
-    s = topi.cuda.schedule_broadcast_to(B)
+    s = topi.cuda.schedule_broadcast(B)
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
...
@@ -47,7 +47,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
         C = topi.broadcast_minimum(A, B)
     else:
         raise NotImplementedError
-    s = topi.cuda.schedule_broadcast_binary_op(C)
+    s = topi.cuda.schedule_broadcast(C)
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
...