Commit 833855e7 by Tianqi Chen Committed by GitHub

[TOPI] Fix reduction fusion with injective input (#475)

parent 203b8188
......@@ -34,6 +34,7 @@ def _schedule_reduce(op, sch):
# Bind the axes to threads and blocks
sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
sch[data_out].set_store_predicate(thread_x.equal(0))
sch[data_out].bind(outer_in, thread_y)
sch[data_out].bind(bx, block_x)
else:
......@@ -57,17 +58,22 @@ def schedule_reduce(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs])
def traverse(operator):
def traverse_before_reduce(operator):
if tag.is_injective(operator.tag):
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
def traverse_after_reduce(operator):
if tag.is_broadcast(operator.tag):
raise RuntimeError("Not yet support ewise after reduce")
elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch)
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
traverse(outs[0].op)
traverse_after_reduce(outs[0].op)
return sch
......@@ -7,12 +7,13 @@ import topi
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
A1 = topi.exp(A)
if type == "sum":
B = topi.sum(A, axis=axis, keepdims=keepdims)
B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif type == "max":
B = topi.max(A, axis=axis, keepdims=keepdims)
B = topi.max(A1, axis=axis, keepdims=keepdims)
elif type == "min":
B = topi.min(A, axis=axis, keepdims=keepdims)
B = topi.min(A1, axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
s = topi.cuda.schedule_reduce(B)
......@@ -23,15 +24,15 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="sum")
# Test
in_npy = np.random.normal(size=in_shape).astype(np.float32)
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.exp(in_npy)
if type == "sum":
out_npy = in_npy.sum(axis=axis, keepdims=keepdims)
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "max":
out_npy = in_npy.max(axis=axis, keepdims=keepdims)
out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
elif type == "min":
out_npy = in_npy.min(axis=axis, keepdims=keepdims)
out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment