Commit 78a0f47b by hlu1 Committed by Tianqi Chen

[ARM] Fix concat (#3061)

parent 24fe04f8
...@@ -23,6 +23,7 @@ from .op import schedule_injective, OpPattern ...@@ -23,6 +23,7 @@ from .op import schedule_injective, OpPattern
schedule_injective = _reg.schedule_injective schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective
schedule_concatenate = _reg.schedule_concatenate
_reg.register_schedule("collapse_sum_like", _schedule_reduce) _reg.register_schedule("collapse_sum_like", _schedule_reduce)
...@@ -46,7 +47,7 @@ _reg.register_schedule("take", schedule_injective) ...@@ -46,7 +47,7 @@ _reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast) _reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("stack", schedule_injective) _reg.register_schedule("stack", schedule_injective)
_reg.register_schedule("concatenate", schedule_injective) _reg.register_schedule("concatenate", schedule_concatenate)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective)
......
...@@ -219,6 +219,13 @@ def schedule_injective(attrs, outputs, target): ...@@ -219,6 +219,13 @@ def schedule_injective(attrs, outputs, target):
with target: with target:
return topi.generic.schedule_injective(outputs) return topi.generic.schedule_injective(outputs)
def schedule_concatenate(attrs, outputs, target):
    """Generic schedule for the concatenate operator.

    Enters the target's scope and dispatches to the target-specific
    ``topi.generic.schedule_concatenate`` implementation.
    """
    with target:
        sched = topi.generic.schedule_concatenate(outputs)
    return sched
__DEBUG_COUNTER__ = 0 __DEBUG_COUNTER__ = 0
def debug(expr, debug_func=None): def debug(expr, debug_func=None):
......
...@@ -51,3 +51,32 @@ def schedule_injective(outs): ...@@ -51,3 +51,32 @@ def schedule_injective(outs):
elif len(s[x].op.axis) >= 2: elif len(s[x].op.axis) >= 2:
s[x].parallel(s[x].op.axis[0]) s[x].parallel(s[x].op.axis[0])
return s return s
@generic.schedule_concatenate.register(["arm_cpu"])
def schedule_concatenate(outs):
    """Schedule for concatenate op.
    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of reduce in the format
        of an array of tensors.
    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    if isinstance(outs, tvm.tensor.Tensor):
        outs = [outs]
    sch = tvm.create_schedule([t.op for t in outs])
    out = outs[0]
    tvm.schedule.AutoInlineInjective(sch)
    # Parallelize over as many outer axes as the output rank allows:
    # fuse the leading axes so the parallel loop is large enough.
    ndim = len(sch[out].op.axis)
    if ndim >= 4:
        fused_axis = sch[out].fuse(*sch[out].op.axis[:3])
        sch[out].parallel(fused_axis)
    elif ndim == 3:
        fused_axis = sch[out].fuse(*sch[out].op.axis[:2])
        sch[out].parallel(fused_axis)
    elif ndim == 2:
        sch[out].parallel(sch[out].op.axis[0])
    return sch
...@@ -127,7 +127,7 @@ def verify_concatenate(shapes, axis): ...@@ -127,7 +127,7 @@ def verify_concatenate(shapes, axis):
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor) s = topi.generic.schedule_concatenate(out_tensor)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate") foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes] data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
...@@ -476,6 +476,7 @@ def test_concatenate(): ...@@ -476,6 +476,7 @@ def test_concatenate():
(12, 6, 7, 3), (12, 6, 7, 3),
(8, 6, 7, 3), (8, 6, 7, 3),
(2, 6, 7, 3)], 0) (2, 6, 7, 3)], 0)
verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1)
def test_stack(): def test_stack():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment