#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
from .op import register_compute, register_schedule, register_pattern
from .op import schedule_injective, OpPattern

# Element-wise and broadcasting ops both reuse the generic injective
# schedule; the two aliases only document intent at each registration.
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective

# Unary math ops, ``copy``, and the basic arithmetic binary ops.
for _op_name in ("log", "exp", "sqrt", "sigmoid", "floor", "ceil",
                 "trunc", "round", "abs", "tanh", "negative", "copy",
                 "add", "subtract", "multiply", "divide"):
    register_schedule(_op_name, schedule_broadcast)

register_schedule("power", schedule_injective)

# Modulo and comparison ops.
for _op_name in ("mod", "equal", "not_equal", "less", "less_equal",
                 "greater", "greater_equal"):
    register_schedule(_op_name, schedule_broadcast)

# Min/max and shift ops are registered with ``schedule_injective`` directly
# (the same object ``schedule_broadcast`` aliases above).
for _op_name in ("maximum", "minimum", "right_shift", "left_shift"):
    register_schedule(_op_name, schedule_injective)

# Keep the loop variable out of the module namespace.
del _op_name

# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type, target):
    """Compute rule for ``zeros``.

    The op takes no tensor inputs. The result is
    ``topi.full(shape, dtype, 0.0)``, where shape and dtype are read from
    ``output_type``.
    """
    assert not inputs
    shape = output_type.shape
    dtype = output_type.dtype
    zero = 0.0
    return [topi.full(shape, dtype, zero)]


register_schedule("zeros", schedule_broadcast)
register_pattern("zeros", OpPattern.ELEMWISE)

# zeros_like
@register_compute("zeros_like")
def zeros_like_compute(attrs, inputs, output_type, target):
    """Compute rule for ``zeros_like``.

    The op expects exactly one input tensor. The result is
    ``topi.full_like(input, 0.0)``, a tensor modeled on that input and
    filled with zero.
    """
    assert len(inputs) == 1
    data = inputs[0]
    zero = 0.0
    return [topi.full_like(data, zero)]


register_schedule("zeros_like", schedule_broadcast)

# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type, target):
    """Compute rule for ``ones``.

    The op takes no tensor inputs. The result is
    ``topi.full(shape, dtype, 1.0)``, where shape and dtype are read from
    ``output_type``.
    """
    assert not inputs
    shape = output_type.shape
    dtype = output_type.dtype
    one = 1.0
    return [topi.full(shape, dtype, one)]


register_schedule("ones", schedule_broadcast)
register_pattern("ones", OpPattern.ELEMWISE)

# ones_like
@register_compute("ones_like")
def ones_like_compute(attrs, inputs, output_type, target):
    """Compute rule for ``ones_like``.

    The op expects exactly one input tensor. The result is
    ``topi.full_like(input, 1.0)``, a tensor modeled on that input and
    filled with one.

    Named ``*_compute`` to match the other compute functions in this module.
    """
    assert len(inputs) == 1
    return [topi.full_like(inputs[0], 1.0)]


# Backward-compatible alias for the function's previous module-level name.
ones_like = ones_like_compute

register_schedule("ones_like", schedule_broadcast)

# clip
@register_compute("clip")
def clip_compute(attrs, inputs, output_type, target):
    """Compute rule for ``clip``.

    The op expects exactly one input tensor. It applies ``topi.clip`` to that
    tensor, with the bounds taken from ``attrs.a_min`` and ``attrs.a_max``.
    """
    assert len(inputs) == 1
    data = inputs[0]
    lower = attrs.a_min
    upper = attrs.a_max
    return [topi.clip(data, lower, upper)]


register_schedule("clip", schedule_elemwise)