Commit 833855e7 by Tianqi Chen Committed by GitHub

[TOPI] Fix reduction fusion with injective input (#475)

parent 203b8188
...@@ -34,6 +34,7 @@ def _schedule_reduce(op, sch): ...@@ -34,6 +34,7 @@ def _schedule_reduce(op, sch):
# Bind the axes to threads and blocks # Bind the axes to threads and blocks
sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x) sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
sch[data_out].set_store_predicate(thread_x.equal(0))
sch[data_out].bind(outer_in, thread_y) sch[data_out].bind(outer_in, thread_y)
sch[data_out].bind(bx, block_x) sch[data_out].bind(bx, block_x)
else: else:
...@@ -57,17 +58,22 @@ def schedule_reduce(outs): ...@@ -57,17 +58,22 @@ def schedule_reduce(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs]) sch = tvm.create_schedule([x.op for x in outs])
def traverse(operator):
def traverse_before_reduce(operator):
if tag.is_injective(operator.tag): if tag.is_injective(operator.tag):
if operator not in sch.outputs: sch[operator].compute_inline()
sch[operator].compute_inline() else:
for tensor in operator.input_tensors: raise RuntimeError("Unsupported operator: %s" % operator.tag)
if tensor.op.input_tensors:
traverse(tensor.op) def traverse_after_reduce(operator):
if tag.is_broadcast(operator.tag):
raise RuntimeError("Not yet support ewise after reduce")
elif operator.tag == 'comm_reduce': elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch) _schedule_reduce(operator, sch)
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
else: else:
raise RuntimeError("Unsupported operator: %s" % operator.tag) raise RuntimeError("Unsupported operator: %s" % operator.tag)
traverse(outs[0].op) traverse_after_reduce(outs[0].op)
return sch return sch
...@@ -7,12 +7,13 @@ import topi ...@@ -7,12 +7,13 @@ import topi
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
# Build the logic and compile the function # Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
A1 = topi.exp(A)
if type == "sum": if type == "sum":
B = topi.sum(A, axis=axis, keepdims=keepdims) B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif type == "max": elif type == "max":
B = topi.max(A, axis=axis, keepdims=keepdims) B = topi.max(A1, axis=axis, keepdims=keepdims)
elif type == "min": elif type == "min":
B = topi.min(A, axis=axis, keepdims=keepdims) B = topi.min(A1, axis=axis, keepdims=keepdims)
else: else:
raise NotImplementedError raise NotImplementedError
s = topi.cuda.schedule_reduce(B) s = topi.cuda.schedule_reduce(B)
...@@ -23,15 +24,15 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): ...@@ -23,15 +24,15 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
return return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="sum") foo = tvm.build(s, [A, B], device, name="sum")
# Test # Test
in_npy = np.random.normal(size=in_shape).astype(np.float32) in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.exp(in_npy)
if type == "sum": if type == "sum":
out_npy = in_npy.sum(axis=axis, keepdims=keepdims) out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "max": elif type == "max":
out_npy = in_npy.max(axis=axis, keepdims=keepdims) out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
elif type == "min": elif type == "min":
out_npy = in_npy.min(axis=axis, keepdims=keepdims) out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
else: else:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment