# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Test code for broadcasting operators."""

import numpy as np
import tvm
import topi

from common import get_all_backend


def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
    """Check ``fbcast(A, out_shape)`` against ``np.broadcast_to`` on each backend.

    Parameters
    ----------
    in_shape : tuple
        Shape of the input placeholder ``A``.
    out_shape : tuple
        Target shape passed to ``fbcast``.
    fbcast : callable
        TOPI broadcast operator under test, e.g. ``topi.broadcast_to``.
    """
    # Build the logic and compile the function
    A = tvm.placeholder(shape=in_shape, name="A")
    B = fbcast(A, out_shape)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        foo = tvm.build(s, [A, B], device, name="broadcast_to")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.broadcast_to(data_npy, out_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for target in get_all_backend():
        check_device(target)
    # sdaccel is also tried explicitly (presumably it is not part of
    # get_all_backend() -- confirm against common.py).
    check_device("sdaccel")


def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
                                ftopi, fnumpy,
                                lhs_min=-100, lhs_max=100,
                                rhs_min=-100, rhs_max=100,
                                dtype="float32"):
    """Check a broadcasting binary TOPI operator against its NumPy reference.

    Parameters
    ----------
    lhs_shape, rhs_shape : tuple or None
        Shapes of the two operands. ``None`` makes that operand a scalar
        ``tvm.var`` instead of a placeholder tensor.
    ftopi : callable
        TOPI operator under test, called as ``ftopi(A, B)``.
    fnumpy : callable
        NumPy reference, called on the generated input data.
    lhs_min, lhs_max, rhs_min, rhs_max : number
        Bounds of the uniform distributions the operands are drawn from.
    dtype : str
        Data type of both operands.
    """
    # Build the logic and compile the function
    A = (tvm.var("A", dtype=dtype) if lhs_shape is None
         else tvm.placeholder(shape=lhs_shape, name="A", dtype=dtype))
    B = (tvm.var("B", dtype=dtype) if rhs_shape is None
         else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
    C = ftopi(A, B)
    if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
        # Two scalar operands yield a plain expression; there is nothing to build.
        assert isinstance(C, tvm.expr.Expr)
        return

    def gen_operand(shape, low, high, ctx):
        """Return (numpy value, value passed to the compiled function)."""
        if shape is None:
            npy = float(np.random.uniform(low=low, high=high))
            if dtype.startswith('int'):
                npy = int(npy)
            nd = npy
        else:
            npy = np.random.uniform(low=low, high=high,
                                    size=shape).astype(dtype)
            nd = tvm.nd.array(npy, ctx)
        return npy, nd

    def near_rounding_boundary(lhs, rhs):
        """Mask of quotients ``lhs / rhs`` within 1e-6 of X.0 or X.5."""
        frac = np.abs(np.fmod(np.divide(lhs, rhs), 1))
        return np.logical_or(np.abs(frac - 0.5) < 1e-6, frac < 1e-6)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(C)
        foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + ftopi.__name__)

        lhs_npy, lhs_nd = gen_operand(lhs_shape, lhs_min, lhs_max, ctx)
        rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)

        if fnumpy == np.floor_divide:
            # avoid check too close to X.5 and X.0
            # FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018))
            # However the result is somehow incorrect - need to further investigate.
            # And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b))
            mask = near_rounding_boundary(lhs_npy, rhs_npy)
            if mask.any():
                if np.shape(mask) == np.shape(lhs_npy):
                    # lhs already has the broadcast shape: shift only the
                    # offending quotients by roughly 1e-3.
                    lhs_npy = (lhs_npy + mask * 1e-3 * rhs_npy).astype(dtype)
                    lhs_nd = (tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None
                              else lhs_npy.item())
                else:
                    # lhs is smaller than the broadcast result (e.g. a scalar
                    # lhs against a tensor rhs). Nudging it element-wise would
                    # change its shape, so redraw it instead.
                    for _ in range(100):
                        lhs_npy, lhs_nd = gen_operand(lhs_shape, lhs_min, lhs_max, ctx)
                        if not near_rounding_boundary(lhs_npy, rhs_npy).any():
                            break

        # Computed after any adjustment above so the reference matches the inputs.
        out_npy = fnumpy(lhs_npy, rhs_npy)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
        foo(lhs_nd, rhs_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

    for target in get_all_backend():
        check_device(target)
    # sdaccel is also tried explicitly (presumably it is not part of
    # get_all_backend() -- confirm against common.py).
    check_device("sdaccel")


def test_broadcast_to():
    """Broadcast tensors (including a 0-d one) to larger shapes."""
    verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
    verify_broadcast_to_ele((), (10,), topi.broadcast_to)
    verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to)
    verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)


def test_add():
    """Add two 0-d tensors, then tensors of different rank."""
    verify_broadcast_binary_ele(
        (), (), topi.add, np.add)
    verify_broadcast_binary_ele(
        (5, 2, 3), (2, 1), topi.add, np.add)


def test_subtract():
    """Subtract with 0-d tensors, scalar vars, and broadcast tensors."""
    verify_broadcast_binary_ele(
        (5, 2, 3), (), topi.subtract, np.subtract)
    verify_broadcast_binary_ele(
        (5, 2, 3), None, topi.subtract, np.subtract)
    verify_broadcast_binary_ele(
        None, None, topi.subtract, np.subtract)
    verify_broadcast_binary_ele(
        (1, 32), (64, 32), topi.subtract, np.subtract)


def test_multiply():
    """Multiply tensors whose ranks differ."""
    verify_broadcast_binary_ele(
        (5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)


def test_divide():
    """Divide with scalar and tensor operands; rhs_min keeps divisors positive."""
    verify_broadcast_binary_ele(
        None, (10,), topi.divide, np.divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (), None, topi.divide, np.divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)


def test_floor_divide():
    """Floor-divide with scalar and tensor operands; divisors kept positive."""
    verify_broadcast_binary_ele(
        None, (10,), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001)
    verify_broadcast_binary_ele(
        (2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001)


def test_maximum_minmum():
    """Element-wise maximum and minimum with rank-changing broadcasts."""
    verify_broadcast_binary_ele(
        (32,), (64, 32), topi.maximum, np.maximum)
    verify_broadcast_binary_ele(
        (1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)


def test_power():
    """Power with positive bases and exponents in (0.001, 2)."""
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)


def test_mod():
    """Integer mod with non-negative dividends and divisors >= 1."""
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")


def test_floor_mod():
    """Floor modulo for int32 and float32 operands.

    The reference is ``np.mod`` (floor remainder, sign follows the divisor).
    ``np.fmod`` is the truncated remainder and only agrees with floor modulo
    when both operands are non-negative, as they happen to be here.
    """
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.floor_mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
    verify_broadcast_binary_ele(
        (3, 4, 5), (3, 4, 5), topi.floor_mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="float32")


def test_cmp():
    """Check the element-wise comparison operators against NumPy."""
    # Explicitly specify the output type by casting each boolean result to int8.
    def greater(x, y):
        return topi.greater(x, y).astype("int8")

    def less(x, y):
        return topi.less(x, y).astype("int8")

    def equal(x, y):
        return topi.equal(x, y).astype("int8")

    def not_equal(x, y):
        return topi.not_equal(x, y).astype("int8")

    def greater_equal(x, y):
        return topi.greater_equal(x, y).astype("int8")

    def less_equal(x, y):
        return topi.less_equal(x, y).astype("int8")

    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), greater, np.greater)
    verify_broadcast_binary_ele(
        (2, 1, 2), (2, 3, 1), less, np.less)
    # Narrow integer ranges so that equal operands actually occur.
    verify_broadcast_binary_ele(
        (2, 1, 2), (2, 3, 1), equal, np.equal,
        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
    verify_broadcast_binary_ele(
        (2, 1, 2), (2, 3, 1), not_equal, np.not_equal,
        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
    verify_broadcast_binary_ele(
        (7, 1, 5), (7, 3, 1), greater_equal, np.greater_equal,
        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
    verify_broadcast_binary_ele(
        (7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')


def test_shift():
    """Check right/left shifts; shift amounts are drawn from [0, 32)."""
    # Right shift of a tensor by a scalar var.
    verify_broadcast_binary_ele(
        (2, 1, 2), None, topi.right_shift, np.right_shift,
        dtype="int32", rhs_min=0, rhs_max=32)

    # Left shift with tensor broadcasting, for int32 and int8.
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
        dtype="int32", rhs_min=0, rhs_max=32)

    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
        dtype="int8", rhs_min=0, rhs_max=32)


def test_logical_single_ele():
    """Compare topi.logical_not against np.logical_not on boolean tensors."""
    def test_apply(
            func,
            name,
            f_numpy,
            indata,
            dtype="bool",
    ):
        # Build the logic and compile the function. A is always a placeholder
        # tensor here, so there is no scalar-expression shortcut to take.
        A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype)
        B = func(A)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                s = topi.generic.schedule_broadcast(B)
            foo = tvm.build(s, [A, B], device, name=name)

            data_npy = indata.astype(A.dtype)
            data_nd = tvm.nd.array(data_npy, ctx)

            out_npy = f_numpy(indata)
            out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx)
            foo(data_nd, out_nd)
            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

        for device in get_all_backend():
            check_device(device)

    test_apply(topi.logical_not, "logical_not", np.logical_not, np.array([True, False, 0, 1]))
    test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3))


def test_logical_binary_ele():
    """Check that logical_and / logical_or on scalar vars yield expressions."""
    def test_apply(
            func,
            name,
            f_numpy,
            lhs,
            rhs,
            dtype="bool",
    ):
        # Build the logic and compile the function
        A = (tvm.var("A", dtype=dtype))
        B = (tvm.var("B", dtype=dtype))
        C = func(A, B)
        # NOTE(review): A and B are always scalar vars, so this branch always
        # returns. As written, check_device below and the lhs/rhs data are
        # never exercised; building placeholders from lhs/rhs would be needed
        # to run the array cases on a device.
        if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
            assert isinstance(C, tvm.expr.Expr)
            return

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                s = topi.generic.schedule_broadcast(C)
            foo = tvm.build(s, [A, B, C], device, name=name)

            lhs_nd = tvm.nd.array(lhs, ctx)
            rhs_nd = tvm.nd.array(rhs, ctx)

            out_npy = f_numpy(lhs, rhs)
            out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
            foo(lhs_nd, rhs_nd, out_nd)
            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

        for device in get_all_backend():
            check_device(device)

    test_apply(topi.logical_and, "logical_and", np.logical_and, True, False)
    test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
    test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
    test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])


if __name__ == "__main__":
    # Run every test directly when the file is executed as a script.
    test_add()
    test_shift()
    test_cmp()
    test_mod()
    test_floor_mod()
    test_subtract()
    test_multiply()
    test_divide()
    test_floor_divide()
    test_maximum_minmum()
    test_power()
    test_broadcast_to()
    test_logical_single_ele()
    test_logical_binary_ele()