# test_topi_broadcast.py
"""Test code for broadcasting operators."""
from common import get_all_backend
import numpy as np
import tvm
import topi

def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
    """Check a broadcast-to operator against ``np.broadcast_to``.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder ``A``.
    out_shape : tuple of int
        Target shape passed to the operator.
    fbcast : callable
        TOPI operator under test, called as ``fbcast(A, out_shape)``.

    The kernel is built and run on every target from ``get_all_backend()``
    and additionally on ``"sdaccel"``; targets whose device context does not
    exist are skipped with a message.
    """
    # Build the logic and compile the function
    A = tvm.placeholder(shape=in_shape, name="A")
    B = fbcast(A, out_shape)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        foo = tvm.build(s, [A, B], device, name="broadcast_to")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.broadcast_to(data_npy, out_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        # Allocate the output buffer directly in the kernel's dtype instead of
        # creating a float64 temporary and converting it.
        out_nd = tvm.nd.array(np.empty(out_shape, dtype=B.dtype), ctx)
        foo(data_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for target in get_all_backend():
        check_device(target)
    check_device("sdaccel")


def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
                                ftopi, fnumpy,
                                lhs_min=-100, lhs_max=100,
                                rhs_min=-100, rhs_max=100,
                                dtype="float32"):
    """Check a broadcasting binary TOPI operator against a NumPy reference.

    Parameters
    ----------
    lhs_shape, rhs_shape : tuple of int or None
        Operand shapes. ``None`` makes that operand a scalar ``tvm.var``
        instead of a tensor placeholder.
    ftopi : callable
        TOPI operator under test, called as ``ftopi(A, B)``. Its ``__name__``
        is used in the compiled kernel's name.
    fnumpy : callable
        NumPy reference, called on the generated host operands.
    lhs_min, lhs_max, rhs_min, rhs_max : number
        Bounds passed to ``np.random.uniform`` when generating each operand.
    dtype : str
        Element dtype of both operands.

    When both operands are scalars, the result must itself be a TVM
    expression and nothing is compiled. Otherwise the kernel is built and run
    on every target from ``get_all_backend()`` and on ``"sdaccel"``; targets
    whose device context does not exist are skipped.
    """
    # Build the logic and compile the function
    A = (tvm.var("A", dtype=dtype) if lhs_shape is None
         else tvm.placeholder(shape=lhs_shape, name="A", dtype=dtype))
    B = (tvm.var("B", dtype=dtype) if rhs_shape is None
         else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
    C = ftopi(A, B)
    if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
        assert isinstance(C, tvm.expr.Expr)
        return

    # Scalar arguments of an integer dtype (signed or unsigned) are passed to
    # the kernel as Python ints rather than floats.
    is_int = dtype.startswith(("int", "uint"))

    def gen_operand(shape, low, high, ctx):
        """Return ``(host_value, kernel_argument)`` for one operand."""
        if shape is None:
            value = float(np.random.uniform(low=low, high=high))
            if is_int:
                value = int(value)
            # Scalars go to the kernel as plain Python numbers.
            return value, value
        value = np.random.uniform(low=low, high=high, size=shape).astype(dtype)
        return value, tvm.nd.array(value, ctx)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(C)
        foo = tvm.build(s, [A, B, C], device,
                        name="broadcast_binary_" + ftopi.__name__)
        lhs_npy, lhs_nd = gen_operand(lhs_shape, lhs_min, lhs_max, ctx)
        rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)
        out_npy = fnumpy(lhs_npy, rhs_npy)
        out_nd = tvm.nd.array(np.empty(out_npy.shape, dtype=C.dtype), ctx)
        foo(lhs_nd, rhs_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

    for target in get_all_backend():
        check_device(target)
    check_device("sdaccel")


def test_broadcast_to():
    """``topi.broadcast_to`` over rank-growing and size-1-axis expansions."""
    verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
    verify_broadcast_to_ele((), (10,), topi.broadcast_to)
    verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to)
    verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)

def test_add():
    """Broadcast add: 0-d with 0-d, and rank-3 with rank-2."""
    verify_broadcast_binary_ele(
        (), (), topi.add, np.add)
    verify_broadcast_binary_ele(
        (5, 2, 3), (2, 1), topi.add, np.add)

def test_subtract():
    """Broadcast subtract, including tensor-scalar and scalar-scalar forms."""
    # (lhs_shape, rhs_shape); None means a scalar tvm.var operand.
    shape_pairs = [
        ((5, 2, 3), ()),
        ((5, 2, 3), None),
        (None, None),
        ((1, 32), (64, 32)),
    ]
    for lhs_shape, rhs_shape in shape_pairs:
        verify_broadcast_binary_ele(
            lhs_shape, rhs_shape, topi.subtract, np.subtract)

def test_multiply():
    """Broadcast multiply where the rhs adds a leading axis and a size-1 tail."""
    verify_broadcast_binary_ele(lhs_shape=(5, 64, 128),
                                rhs_shape=(2, 5, 64, 1),
                                ftopi=topi.multiply,
                                fnumpy=np.multiply)

def test_divide():
    """Broadcast divide with scalar and tensor operands.

    ``rhs_min`` is kept positive so the divisor is never zero.
    """
    verify_broadcast_binary_ele(
        None, (10,), topi.divide, np.divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (), None, topi.divide, np.divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)

def test_maximum_minmum():
    """Broadcast element-wise maximum and minimum."""
    cases = (
        ((32,), (64, 32), topi.maximum, np.maximum),
        ((1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum),
    )
    for lhs_shape, rhs_shape, op_topi, op_numpy in cases:
        verify_broadcast_binary_ele(lhs_shape, rhs_shape, op_topi, op_numpy)

def test_power():
    """Broadcast power with a strictly positive base and exponent in (0, 2)."""
    # A positive base keeps fractional exponents real-valued.
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)


def test_mod():
    """Broadcast integer modulo.

    After int32 truncation the lhs is non-negative and the rhs is >= 1. This
    avoids a zero divisor and keeps the result independent of the remainder's
    sign convention.
    """
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")


def test_cmp():
    """Broadcast comparison operators, with results cast to int8."""
    # Explicitly cast each comparison result to int8 so the output buffer has
    # a concrete dtype. The wrapper names become part of the kernel name.
    def greater(x, y):
        return topi.greater(x, y).astype("int8")
    def less(x, y):
        return topi.less(x, y).astype("int8")
    def equal(x, y):
        return topi.equal(x, y).astype("int8")
    def not_equal(x, y):
        return topi.not_equal(x, y).astype("int8")
    def greater_equal(x, y):
        return topi.greater_equal(x, y).astype("int8")
    def less_equal(x, y):
        return topi.less_equal(x, y).astype("int8")
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), greater, np.greater)
    verify_broadcast_binary_ele(
        (2, 1, 2), (2, 3, 1), less, np.less)
    # Narrow integer ranges make equal operands (ties) likely, which exercises
    # the ==, !=, >= and <= boundaries.
    verify_broadcast_binary_ele(
        (2, 1, 2), (2, 3, 1), equal, np.equal,
        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
    verify_broadcast_binary_ele(
        (2, 1, 2), (2, 3, 1), not_equal, np.not_equal,
        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
    verify_broadcast_binary_ele(
        (7, 1, 5), (7, 3, 1), greater_equal, np.greater_equal,
        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
    verify_broadcast_binary_ele(
        (7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')


def test_shift():
    """Broadcast bitwise shifts on int32 and int8 operands.

    Shift counts are drawn from ``[rhs_min, rhs_max)`` and truncated to int,
    so they always stay below the bit width of the dtype. Shifting by the
    full width or more is not well defined: LLVM ``shl`` yields poison,
    while NumPy returns 0.
    """
    # int32: shift counts 0..31
    verify_broadcast_binary_ele(
        (2, 1, 2), None, topi.right_shift, np.right_shift,
        dtype="int32", rhs_min=0, rhs_max=32)

    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
        dtype="int32", rhs_min=0, rhs_max=32)

    # int8: shift counts 0..7 (previously up to 31, which exceeded the width)
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
        dtype="int8", rhs_min=0, rhs_max=8)


if __name__ == "__main__":
    # Run every test in this module when executed as a script.
    test_add()
    test_shift()
    test_cmp()
    test_mod()
    test_subtract()
    test_multiply()
    test_divide()
    test_maximum_minmum()
    test_power()
    test_broadcast_to()