"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
import topi
from . import op as _reg
from ._reduce import _schedule_reduce
from .op import schedule_injective, OpPattern

# Module-level schedule aliases. Broadcast ops have no dedicated schedule
# here: they share the injective one, so both names refer to the same object.
schedule_broadcast = schedule_injective = _reg.schedule_injective


# Schedule registrations, applied in listing order. collapse_sum_like uses the
# reduce schedule; every other op uses the injective schedule, either directly
# or through the schedule_broadcast alias.
for _op_name, _op_schedule in (
        ("collapse_sum_like", _schedule_reduce),
        ("broadcast_to", schedule_broadcast),
        ("broadcast_to_like", schedule_broadcast),
        ("expand_dims", schedule_broadcast),
        ("squeeze", schedule_injective),
        ("reshape", schedule_injective),
        ("reshape_like", schedule_injective),
        ("full", schedule_injective),
        ("full_like", schedule_injective),
        ("arange", schedule_injective),
        ("cast", schedule_injective),
        ("strided_slice", schedule_injective),
        ("slice_like", schedule_injective),
        ("split", schedule_injective),
        ("take", schedule_injective),
        ("transpose", schedule_injective),
        ("where", schedule_broadcast),
        ("_contrib_reverse_reshape", schedule_injective),
):
    _reg.register_schedule(_op_name, _op_schedule)
# Remove the loop temporaries so the module namespace is left unchanged.
del _op_name, _op_schedule

# layout_transform
# Registers the shared injective schedule for the op and declares its
# pattern as OpPattern.INJECTIVE.
# NOTE(review): the pattern is presumably what operator fusion consults --
# confirm against the registry in .op.
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)

# concatenate
@_reg.register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
    """Compute definition for the "concatenate" operator.

    Parameters
    ----------
    attrs : object
        Operator attributes; only ``attrs.axis`` is read and it is forwarded
        as the ``axis`` argument of ``topi.concatenate``.
    inputs : sequence
        Operands handed to ``topi.concatenate`` as-is (presumably the input
        tensors -- confirm against the registry's calling convention).
    output_type : object
        Unused by this compute.
    target : object
        Unused by this compute.

    Returns
    -------
    list
        A single-element list holding the result of ``topi.concatenate``.
    """
    joined = topi.concatenate(inputs, axis=attrs.axis)
    return [joined]

# Registers the shared injective schedule for concatenate and declares its
# pattern as OpPattern.INJECTIVE.
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_pattern("concatenate", OpPattern.INJECTIVE)