Commit f863bfdc by Tianqi Chen Committed by GitHub

[TOPI] Fix reduce fusion with more levels (#477)

parent 262345fa
...@@ -60,8 +60,12 @@ def schedule_reduce(outs): ...@@ -60,8 +60,12 @@ def schedule_reduce(outs):
sch = tvm.create_schedule([x.op for x in outs]) sch = tvm.create_schedule([x.op for x in outs])
def traverse_before_reduce(operator): def traverse_before_reduce(operator):
if tag.is_injective(operator.tag): if isinstance(operator, tvm.tensor.PlaceholderOp):
return
elif tag.is_injective(operator.tag):
sch[operator].compute_inline() sch[operator].compute_inline()
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
else: else:
raise RuntimeError("Unsupported operator: %s" % operator.tag) raise RuntimeError("Unsupported operator: %s" % operator.tag)
......
...@@ -7,7 +7,7 @@ import topi ...@@ -7,7 +7,7 @@ import topi
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
# Build the logic and compile the function # Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
A1 = topi.exp(A) A1 = topi.sqrt(topi.exp(A))
if type == "sum": if type == "sum":
B = topi.sum(A1, axis=axis, keepdims=keepdims) B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif type == "max": elif type == "max":
...@@ -26,7 +26,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): ...@@ -26,7 +26,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
foo = tvm.build(s, [A, B], device, name="sum") foo = tvm.build(s, [A, B], device, name="sum")
# Test # Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32) in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.exp(in_npy) in_npy_map = np.sqrt(np.exp(in_npy))
if type == "sum": if type == "sum":
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "max": elif type == "max":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment