"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
import topi
from . import op as _reg
from ._reduce import _schedule_reduce
from .op import schedule_injective, OpPattern

# Schedule aliases used by the registrations that follow.
# NOTE: ``schedule_injective`` is also imported from ``.op`` at the top of the
# file; rebinding it from ``_reg`` (the same ``op`` module) keeps the name
# explicit here. Broadcast-style ops reuse the generic injective schedule.
schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective

# Schedule registrations for transform ops. These run once at import time;
# the order of the calls is kept exactly as originally written.

# collapse_sum_like is scheduled with the reduction schedule.
_reg.register_schedule("collapse_sum_like", _schedule_reduce)

# Ops registered with the broadcast alias (the injective schedule).
_reg.register_schedule("broadcast_to_like", schedule_broadcast)
_reg.register_schedule("expand_dims", schedule_broadcast)

# Ops registered with the generic injective schedule.
_reg.register_schedule("squeeze", schedule_injective)
_reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)

# where is registered with the broadcast alias.
_reg.register_schedule("where", schedule_broadcast)

# layout_transform: registered with the generic injective schedule and
# tagged with the INJECTIVE op pattern (both registrations run at import time).
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)

# concatenate
@_reg.register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
    """Compute rule registered for the ``concatenate`` op.

    Delegates to ``topi.concatenate``, joining ``inputs`` along the axis
    stored in ``attrs.axis``. ``output_type`` and ``target`` are not used
    here; they are part of the compute-callback signature.

    Returns a one-element list containing the result of
    ``topi.concatenate``.
    """
    # NOTE(review): presumably ``inputs`` is the list of input tensors
    # supplied by the registry — confirm against the caller.
    joined = topi.concatenate(inputs, axis=attrs.axis)
    return [joined]

# concatenate: uses the generic injective schedule and is tagged with the
# INJECTIVE op pattern; its compute rule is concatenate_compute.
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_pattern("concatenate", OpPattern.INJECTIVE)