"""Backend compiler related feature registration (schedules, compute rules and op patterns) for transform ops."""
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
4
import topi
5
from . import op as _reg
6
from ._reduce import _schedule_reduce
7
from .op import schedule_injective, OpPattern
8

# Broadcast-style ops are scheduled with the same function as injective ops;
# ``schedule_broadcast`` is simply another name for that schedule.
schedule_broadcast = _reg.schedule_injective


# Reduction-like op: uses the reduce schedule imported from ``._reduce``.
_reg.register_schedule("collapse_sum_like", _schedule_reduce)

# Broadcasting ops.
_reg.register_schedule("broadcast_to", schedule_broadcast)
_reg.register_schedule("broadcast_to_like", schedule_broadcast)
_reg.register_schedule("expand_dims", schedule_broadcast)

# Shape-manipulation ops.
_reg.register_schedule("squeeze", schedule_injective)
_reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective)

# Constant-fill ops.
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)

# Element type conversion.
_reg.register_schedule("cast", schedule_injective)

# Slicing / splitting ops.
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)

# Gather / permutation ops.
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)

# Element-wise select; registered with the broadcast alias.
_reg.register_schedule("where", schedule_broadcast)

_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)

# layout_transform: besides its schedule, an explicit INJECTIVE op pattern is
# registered. NOTE(review): the pattern is presumably consumed by operator
# fusion, and the other ops above presumably declare theirs elsewhere — confirm.
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)

# --- concatenate -----------------------------------------------------------
@_reg.register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
    """Compute rule for the ``concatenate`` op.

    Joins ``inputs`` along ``attrs.axis`` using ``topi.concatenate``.
    ``output_type`` and ``target`` belong to the compute-function signature
    but are not used here.

    Returns a single-element list holding the ``topi.concatenate`` result.
    """
    join_axis = attrs.axis
    joined = topi.concatenate(inputs, axis=join_axis)
    return [joined]


# concatenate is scheduled injectively and tagged with the INJECTIVE pattern.
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_pattern("concatenate", OpPattern.INJECTIVE)