Commit 51f8327f by Xingjian Shi, committed by Tianqi Chen

Use ewise schedule for broadcasting (#460)

parent d6007a24
@@ -8,6 +8,6 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
 from .reduction import schedule_reduce
-from .broadcast import schedule_broadcast_to, schedule_broadcast_binary_op
+from .broadcast import schedule_broadcast
 from .softmax import schedule_softmax
 from .elemwise import schedule_elemwise
@@ -3,27 +3,11 @@
 from __future__ import absolute_import as _abs
 import tvm
-def _schedule_broadcast(op, sch):
-    data_out = op.output(0)
-    num_thread = 512
-    block_x = tvm.thread_axis("blockIdx.x")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-    xo, vi = sch[data_out].split(sch[data_out].op.axis[len(sch[data_out].op.axis) - 1],
-                                 factor=4)
-    sch[data_out].vectorize(vi)
-    fused_axis = sch[data_out].fuse(*[sch[data_out].op.axis[i]
-                                      for i in range(len(sch[data_out].op.axis) - 1)] + [xo])
-    bx, tx = sch[data_out].split(fused_axis, factor=num_thread)
-    sch[data_out].bind(bx, block_x)
-    sch[data_out].bind(tx, thread_x)
-    return sch
-
-def schedule_broadcast_to(outs):
-    """Schedule for broadcast_to ops + element-wise ops.
+from .elemwise import _schedule_elemwise
+
+def schedule_broadcast(outs):
+    """Schedule for broadcasting ops (broadcast_to + broadcast binary) + element-wise ops.
 
     Parameters
     ----------
@@ -45,40 +29,8 @@ def schedule_broadcast_to(outs):
             for tensor in operator.input_tensors:
                 if tensor.op.input_tensors:
                     traverse(tensor.op)
-        elif operator.tag == 'broadcast_to':
-            _schedule_broadcast(operator, sch)
-        else:
-            raise RuntimeError("Unsupported operator: %s" % operator.tag)
-
-    traverse(outs[0].op)
-    return sch
-
-def schedule_broadcast_binary_op(outs):
-    """Schedule for broadcast_binary ops + element-wise ops.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of broadcast_binary in the format
-        of an array of tensors.
-
-    Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    sch = tvm.create_schedule([x.op for x in outs])
-
-    def traverse(operator):
-        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
-            if operator not in sch.outputs:
-                sch[operator].compute_inline()
-            for tensor in operator.input_tensors:
-                if tensor.op.input_tensors:
-                    traverse(tensor.op)
-        elif operator.tag == 'broadcast_binary_op':
-            _schedule_broadcast(operator, sch)
+        elif operator.tag == 'broadcast_to' or operator.tag == 'broadcast_binary_op':
+            _schedule_elemwise(operator, sch)
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
......
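For readability, here is a sketch of the unified schedule as it reads once the hunks above are applied. It is reconstructed from the diff lines; the module preamble, blank lines, and the full docstring are assumed rather than shown.

# Reconstructed sketch of the CUDA broadcast schedule after this change.
from __future__ import absolute_import as _abs
import tvm

from .elemwise import _schedule_elemwise


def schedule_broadcast(outs):
    """Schedule for broadcasting ops (broadcast_to + broadcast binary) + element-wise ops."""
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    sch = tvm.create_schedule([x.op for x in outs])

    def traverse(operator):
        # Inline purely element-wise producers and keep walking their inputs.
        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
            if operator not in sch.outputs:
                sch[operator].compute_inline()
            for tensor in operator.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)
        # Both broadcast_to and broadcast binary ops now reuse the element-wise schedule.
        elif operator.tag == 'broadcast_to' or operator.tag == 'broadcast_binary_op':
            _schedule_elemwise(operator, sch)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

    traverse(outs[0].op)
    return sch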
@@ -2,6 +2,17 @@
 """Schedule for element wise operator"""
 import tvm
 
+def _schedule_elemwise(op, sch):
+    x = op.output(0)
+    fused = sch[x].fuse(*sch[x].op.axis)
+    num_thread = 512
+    bx, tx = sch[x].split(fused, factor=num_thread)
+    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
+    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+    return sch
+
 def schedule_elemwise(outs):
     """Schedule for element wise op.
@@ -20,12 +31,4 @@ def schedule_elemwise(outs):
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
-    x = outs[0]
-    fused = s[x].fuse(*x.op.axis)
-    num_thread = 64
-    bx, tx = s[x].split(fused, factor=num_thread)
-    s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-    s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
-    return s
+    return _schedule_elemwise(outs[0].op, s)
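The shared helper is small; assembled from the two hunks above, the element-wise module now reads roughly as follows. The outs normalization line is an assumption (it sits outside the shown context) and is written to match the pattern used in broadcast.py.

"""Schedule for element wise operator"""
import tvm


def _schedule_elemwise(op, sch):
    # Fuse all output axes into one, then split into 512-wide chunks
    # bound to blockIdx.x / threadIdx.x (the old in-place schedule used 64 threads).
    x = op.output(0)
    fused = sch[x].fuse(*sch[x].op.axis)
    num_thread = 512
    bx, tx = sch[x].split(fused, factor=num_thread)
    sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch


def schedule_elemwise(outs):
    """Schedule for element wise op."""
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs  # assumed, outside the shown hunk
    s = tvm.create_schedule([x.op for x in outs])
    tvm.schedule.AutoInlineInjective(s)
    return _schedule_elemwise(outs[0].op, s)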
@@ -38,7 +38,7 @@ def test_broadcast_to(in_shape, out_shape):
     # Build the logic and compile the function
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.broadcast_to(A, out_shape)
-    s = topi.cuda.schedule_broadcast_to(B)
+    s = topi.cuda.schedule_broadcast(B)
     fcuda = tvm.build(s, [A, B], "cuda", name="broadcast_to")
     data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
@@ -72,7 +72,7 @@ def test_broadcast_binary_op(lhs_shape, rhs_shape, typ="add"):
         C = topi.broadcast_minimum(A, B)
     else:
         raise NotImplementedError
-    s = topi.cuda.schedule_broadcast_binary_op(C)
+    s = topi.cuda.schedule_broadcast(C)
     fcuda = tvm.build(s, [A, B, C], "cuda", name="broadcast_binary" + "_" + typ)
     lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
......
@@ -8,7 +8,7 @@ def verify_broadcast_to_ele(in_shape, out_shape):
     # Build the logic and compile the function
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.broadcast_to(A, out_shape)
-    s = topi.cuda.schedule_broadcast_to(B)
+    s = topi.cuda.schedule_broadcast(B)
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
@@ -47,7 +47,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
         C = topi.broadcast_minimum(A, B)
     else:
         raise NotImplementedError
-    s = topi.cuda.schedule_broadcast_binary_op(C)
+    s = topi.cuda.schedule_broadcast(C)
     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
......
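A minimal end-to-end usage sketch of the renamed entry point, modeled on the tests above. The shapes, the use of topi.broadcast_add, and the NumPy check are illustrative assumptions rather than lines taken from this diff.

import numpy as np
import tvm
import topi

lhs_shape, rhs_shape = (2, 1, 3), (2, 4, 3)
A = tvm.placeholder(shape=lhs_shape, name="A")
B = tvm.placeholder(shape=rhs_shape, name="B")
C = topi.broadcast_add(A, B)          # assumed broadcast binary op, analogous to broadcast_minimum above
s = topi.cuda.schedule_broadcast(C)   # one schedule now covers broadcast_to and broadcast binary ops
fcuda = tvm.build(s, [A, B, C], "cuda", name="broadcast_add")

ctx = tvm.gpu(0)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
rhs_npy = np.random.uniform(size=rhs_shape).astype(B.dtype)
out_npy = lhs_npy + rhs_npy
lhs_nd = tvm.nd.array(lhs_npy, ctx)
rhs_nd = tvm.nd.array(rhs_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
fcuda(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1e-5)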