Commit f863bfdc by Tianqi Chen Committed by GitHub

[TOPI] Fix reduce fusion with more levels (#477)

parent 262345fa
......@@ -60,8 +60,12 @@ def schedule_reduce(outs):
sch = tvm.create_schedule([x.op for x in outs])
def traverse_before_reduce(operator):
if tag.is_injective(operator.tag):
if isinstance(operator, tvm.tensor.PlaceholderOp):
return
elif tag.is_injective(operator.tag):
sch[operator].compute_inline()
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
......
......@@ -7,7 +7,7 @@ import topi
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
A1 = topi.exp(A)
A1 = topi.sqrt(topi.exp(A))
if type == "sum":
B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif type == "max":
......@@ -26,7 +26,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
foo = tvm.build(s, [A, B], device, name="sum")
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.exp(in_npy)
in_npy_map = np.sqrt(np.exp(in_npy))
if type == "sum":
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "max":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment