# test_topi_transform.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20
"""Test code for broadcasting operators."""
import numpy as np
import tvm
import topi
21
import topi.testing
22
from tvm.contrib.nvcc import have_fp16
23

24 25
from common import get_all_backend

26 27 28 29
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    """Check topi.expand_dims against a numpy reshape on every backend.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    out_shape : tuple of int
        Shape of the expected output; the numpy reference is the random
        input reshaped to this shape.
    axis : int
        Axis argument forwarded to topi.expand_dims.
    num_newaxis : int
        Number of new axes forwarded to topi.expand_dims.
    """
    data = tvm.placeholder(shape=in_shape, name="A")
    expanded = topi.expand_dims(data, axis, num_newaxis)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_broadcast(expanded)
        kernel = tvm.build(sched, [data, expanded], target, name="expand_dims")
        src_np = np.random.uniform(size=in_shape).astype(data.dtype)
        expected = src_np.reshape(out_shape)
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.array(np.empty(out_shape).astype(expanded.dtype), ctx)
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)
47 48


49 50 51 52 53 54 55 56
def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
    """Check topi.reinterpret against numpy's ``ndarray.view`` on every backend.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder (the output buffer uses the same shape).
    in_dtype : str
        Dtype of the input placeholder.
    out_dtype : str
        Dtype the input bits are reinterpreted as.
    generator : callable
        Called with ``in_shape``; its result is cast to ``in_dtype`` and used
        as the input data.
    """
    src = tvm.placeholder(shape=in_shape, name="A", dtype=in_dtype)
    viewed = topi.reinterpret(src, out_dtype)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        # Only query the compute capability for float16 input on CUDA.
        wants_cuda_fp16 = in_dtype == "float16" and target == 'cuda'
        if wants_cuda_fp16 and not have_fp16(ctx.compute_version):
            print("Skip because %s does not have fp16 support" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_elemwise(viewed)
        kernel = tvm.build(sched, [src, viewed], target, name="reinterpret")
        src_np = generator(in_shape).astype(in_dtype)
        expected = src_np.view(viewed.dtype)
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.array(np.empty(in_shape).astype(viewed.dtype), ctx)
        kernel(src_nd, dst_nd)
        # Bit-exact comparison: a reinterpret must not alter any value.
        np.testing.assert_equal(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)


75
def verify_transpose(in_shape, axes):
    """Check topi.transpose against ``ndarray.transpose`` on every backend.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    axes : tuple of int or None
        Axis permutation forwarded to both topi.transpose and numpy.
    """
    data = tvm.placeholder(shape=in_shape, name="A")
    permuted = topi.transpose(data, axes)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(permuted)
        kernel = tvm.build(sched, [data, permuted], target, name="transpose")
        # Sequential values give every element a distinct value.
        src_np = np.arange(np.prod(in_shape)).reshape(in_shape).astype(data.dtype)
        expected = src_np.transpose(axes)
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=permuted.dtype)
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

97

98 99 100 101
def verify_reshape(src_shape, dst_shape):
    """Check topi.reshape against ``np.reshape`` on every backend.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the input placeholder.
    dst_shape : tuple of int
        Target shape passed to topi.reshape and to numpy.
    """
    data = tvm.placeholder(shape=src_shape, name="A")
    reshaped = topi.reshape(data, dst_shape)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(reshaped)
        kernel = tvm.build(sched, [data, reshaped], target, name="reshape")
        src_np = np.random.normal(size=src_shape).astype(data.dtype)
        expected = np.reshape(src_np, newshape=dst_shape)
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=reshaped.dtype)
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

120

121 122 123 124
def verify_squeeze(src_shape, axis):
    """Check topi.squeeze against ``np.squeeze`` on every backend.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the input placeholder.
    axis : int, tuple of int or None
        Axis argument forwarded to both topi.squeeze and numpy.
    """
    data = tvm.placeholder(shape=src_shape, name="A")
    squeezed = topi.squeeze(data, axis=axis)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(squeezed)

        kernel = tvm.build(sched, [data, squeezed], target, name="squeeze")
        src_np = np.random.normal(size=src_shape).astype(data.dtype)
        expected = np.squeeze(src_np, axis=axis)
        src_nd = tvm.nd.array(src_np, ctx)
        # The output buffer takes its shape from the numpy reference.
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=squeezed.dtype)
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)
144

145 146 147 148 149 150
def verify_concatenate(shapes, axis):
    """Check topi.concatenate against ``np.concatenate`` on every backend.

    Parameters
    ----------
    shapes : list of tuple of int
        One shape per input placeholder.
    axis : int
        Axis along which the inputs are joined.
    """
    inputs = [tvm.placeholder(shape, name="A" + str(idx))
              for idx, shape in enumerate(shapes)]
    joined = topi.concatenate(a_tuple=inputs, axis=axis)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_concatenate(joined)

        kernel = tvm.build(sched, inputs + [joined], target, name="concatenate")
        # All random inputs use the dtype of the first placeholder.
        parts_np = [np.random.normal(size=shape).astype(inputs[0].dtype) for shape in shapes]
        expected = np.concatenate(parts_np, axis=axis)
        parts_nd = [tvm.nd.array(part, ctx) for part in parts_np]
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=joined.dtype)
        kernel(*(parts_nd + [dst_nd]))
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
def verify_stack(shapes, axis):
    """Check topi.stack against ``np.stack`` on every backend.

    Parameters
    ----------
    shapes : list of tuple of int
        One shape per input placeholder.
    axis : int
        Axis argument forwarded to both topi.stack and numpy.
    """
    inputs = [tvm.placeholder(shape, name="A" + str(idx))
              for idx, shape in enumerate(shapes)]
    stacked = topi.stack(inputs, axis)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_broadcast(stacked)

        kernel = tvm.build(sched, inputs + [stacked], target, name="stack")
        # All random inputs use the dtype of the first placeholder.
        parts_np = [np.random.normal(size=shape).astype(inputs[0].dtype) for shape in shapes]
        expected = np.stack(parts_np, axis=axis)
        parts_nd = [tvm.nd.array(part, ctx) for part in parts_np]
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=stacked.dtype)
        kernel(*(parts_nd + [dst_nd]))
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

195 196 197 198 199

def verify_split(src_shape, indices_or_sections, axis):
    """Check topi.split against ``np.split`` on every backend.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the input placeholder.
    indices_or_sections : int or list of int
        Forwarded unchanged to both topi.split and numpy.
    axis : int
        Axis along which the input is split.
    """
    data = tvm.placeholder(shape=src_shape, name="A")
    pieces = topi.split(data, indices_or_sections, axis=axis)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(pieces)

        kernel = tvm.build(sched, [data] + list(pieces), target, name="split")
        src_np = np.random.normal(size=src_shape).astype(data.dtype)
        expected_parts = np.split(src_np, indices_or_sections, axis=axis)
        src_nd = tvm.nd.array(src_np, ctx)
        # One output buffer per numpy piece, all typed like the first TOPI output.
        dst_nds = [tvm.nd.empty(part.shape, ctx=ctx, dtype=pieces[0].dtype) for part in expected_parts]
        kernel(*([src_nd] + dst_nds))
        for got, want in zip(dst_nds, expected_parts):
            tvm.testing.assert_allclose(got.asnumpy(), want)

    for backend in get_all_backend():
        _run(backend)

220

221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249
def verify_expand_like(in_shape, out_shape, axis):
    """Check topi.expand_like against a numpy expand-and-tile reference on llvm.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the tensor to expand.
    out_shape : tuple of int
        Shape of the "shape-like" tensor, and the expected output shape.
    axis : list of int
        Axes (negative values allowed) at which new dimensions are inserted.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = tvm.placeholder(shape=out_shape, name="B")
    C = topi.expand_like(A, B, axis)
    sched = tvm.create_schedule([C.op])

    def _run(target):
        if not tvm.module.enabled(target):
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)

        ctx = tvm.context(target, 0)
        kernel = tvm.build(sched, [A, B, C], target, name="expand_like")
        src_np = np.random.uniform(size=in_shape).astype(A.dtype)
        src_nd = tvm.nd.array(src_np, ctx)

        # Normalise negative axes, then insert them in ascending order.
        rank = len(out_shape)
        axes = sorted(ax if ax >= 0 else ax + rank for ax in axis)
        expected = src_np
        for ax in axes:
            expected = np.expand_dims(expected, ax).astype(A.dtype)
        # Repeat the data along each new axis up to the extent in out_shape.
        for ax in axes:
            expected = np.concatenate([expected]*out_shape[ax], axis=ax).astype(A.dtype)
        assert expected.shape == out_shape

        like_nd = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
        dst_nd = tvm.nd.array(np.zeros(out_shape).astype(A.dtype), ctx)
        kernel(src_nd, like_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    # This check runs on the llvm target only.
    for backend in ["llvm"]:
        _run(backend)

255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272
def verify_flip(in_shape, axis):
    """Check ``topi.flip(A, axis) + 1`` against ``np.flip(x, axis) + 1``.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    axis : int
        Axis to reverse; negative values are used by the callers as well.
    """
    data = tvm.placeholder(shape=in_shape, name="A")
    flipped = topi.flip(data, axis) + 1

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(flipped)

        kernel = tvm.build(sched, [data, flipped], target, name="reverse")
        src_np = np.random.uniform(size=in_shape).astype(data.dtype)
        expected = np.flip(src_np, axis) + 1
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=data.dtype)
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    # Fixed target list rather than get_all_backend().
    for backend in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
        _run(backend)
277

278
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
    """Check topi.take against ``np.take`` on a fixed list of targets.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the float32 source tensor (filled with 0..N-1).
    indices_src : nested list of int
        Indices to gather; converted to an int32 numpy array.
    axis : int or None
        Axis to take along; when None the keyword is not passed to topi.take.
    mode : str
        Out-of-range handling mode. "fast" is compared against numpy's
        "raise" mode; other modes are passed to numpy unchanged.
    """
    value_dtype = "float32"
    index_dtype = "int32"
    index_np = np.array(indices_src, dtype=index_dtype)
    A = tvm.placeholder(shape=src_shape, dtype=value_dtype, name="A")
    indices = tvm.placeholder(shape=index_np.shape, dtype=index_dtype, name="indices")
    take_kwargs = {"a": A, "indices": indices, "mode": mode}
    if axis is not None:
        take_kwargs["axis"] = axis
    taken = topi.take(**take_kwargs)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(taken)

        kernel = tvm.build(sched, [A] + [indices] + [taken] , target, name="take")
        num_elems = 1
        for extent in src_shape:
            num_elems = num_elems * extent
        src_np = np.arange(num_elems, dtype=value_dtype).reshape((src_shape))

        # numpy has no "fast" mode, so "raise" is used as the reference for it.
        np_mode = "raise" if mode == "fast" else mode
        # np.take's own default for axis is None, so it can be passed through as-is.
        expected = np.take(src_np, index_np, axis=axis, mode=np_mode)
        src_nd = tvm.nd.array(src_np, ctx)
        index_nd = tvm.nd.array(index_np, ctx)
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=value_dtype)
        kernel(src_nd, index_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        _run(backend)

319
def verify_strided_slice(in_shape, begin, end, strides=None):
    """Check ``topi.strided_slice(...) + 1`` against the python reference.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    begin : list of int
        Start index per sliced axis.
    end : list of int
        Stop index per sliced axis.
    strides : list of int, optional
        Step per sliced axis. When omitted, a unit stride is used for every
        entry of ``begin`` (previously hard-coded to three entries, which
        only matched 3-element ``begin``/``end`` lists).
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    # Size the default from ``begin`` so the helper is not tied to 3-D slices.
    strides = [1] * len(begin) if strides is None else strides
    B = topi.strided_slice(A, begin, end, strides) + 1

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)

        foo = tvm.build(s, [A, B], device, name="stride_slice")
        x_np = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = topi.testing.strided_slice_python(
            x_np, begin, end, strides) + 1
        data_nd = tvm.nd.array(x_np, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    # Fixed target list rather than get_all_backend().
    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)

345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
def verify_gather_nd(src_shape, indices_src, indices_dtype):
    """Check topi.gather_nd against topi.testing.gather_nd_python.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the float32 source tensor (filled with 0..N-1).
    indices_src : nested list of numbers
        Index data; converted to a numpy array of ``indices_dtype``.
    indices_dtype : str
        Dtype of the index tensor.
    """
    src_dtype = "float32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
    indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
    out_tensor = topi.gather_nd(a=A, indices=indices)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(out_tensor)

        # Name the kernel after the operator under test; it was previously
        # built as "take", a leftover from the helper this was copied from.
        func = tvm.build(s, [A, indices, out_tensor], device, name="gather_nd")
        shape_size = 1
        for extent in src_shape:
            shape_size = shape_size * extent
        data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
        out_npys = topi.testing.gather_nd_python(data_npy, indices_src)

        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
        func(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device in get_all_backend():
        check_device(device)

377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406
def verify_arange(start, stop, step):
    """Check topi.arange against ``np.arange`` on every backend.

    ``start`` and ``step`` may each be None, in which case that argument is
    left out of both the TOPI call and the numpy call.

    Parameters
    ----------
    start : number or None
        Start of the range.
    stop : number
        End of the range.
    step : number or None
        Step of the range.
    """
    if start is None:
        if step is None:
            A = topi.arange(stop)
            a_np = np.arange(stop)
        else:
            A = topi.arange(stop, step=step)
            a_np = np.arange(stop, step=step)
    elif step is None:
        A = topi.arange(start, stop)
        a_np = np.arange(start, stop)
    else:
        A = topi.arange(start, stop, step)
        a_np = np.arange(start, stop, step)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(A)
        kernel = tvm.build(sched, [A], target, name="arange")
        # The op has no inputs; only the float32 output buffer is passed.
        result_nd = tvm.nd.empty(a_np.shape, dtype='float32', ctx=ctx)
        kernel(result_nd)
        tvm.testing.assert_allclose(result_nd.asnumpy(), a_np)

    for backend in get_all_backend():
        _run(backend)

407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
def verify_repeat(in_shape, repeats, axis):
    """Check topi.repeat against ``np.repeat`` on every backend.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    repeats : int
        Repetition count forwarded to both implementations.
    axis : int
        Axis along which elements are repeated.
    """
    data = tvm.placeholder(shape=in_shape, name="A")
    repeated = topi.repeat(data, repeats, axis)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_broadcast(repeated)
        kernel = tvm.build(sched, [data, repeated], target, name="repeat")
        src_np = np.random.uniform(size=in_shape).astype(data.dtype)
        expected = np.repeat(src_np, repeats, axis)
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.array(np.empty(expected.shape).astype(repeated.dtype), ctx)
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

def verify_tile(in_shape, reps):
    """Check topi.tile against ``np.tile`` on every backend.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    reps : tuple of int
        Repetition counts forwarded to both implementations.
    """
    data = tvm.placeholder(shape=in_shape, name="A")
    tiled = topi.tile(data, reps)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_broadcast(tiled)
        kernel = tvm.build(sched, [data, tiled], target, name="tile")
        src_np = np.random.uniform(size=in_shape).astype(data.dtype)
        expected = np.tile(src_np, reps)
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.array(np.empty(expected.shape).astype(tiled.dtype), ctx)
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
def verify_where(in_shape):
    """Check topi.where against ``np.where`` on every backend.

    The condition and both branches share ``in_shape`` and the condition
    placeholder's dtype.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the condition, both inputs and the output.
    """
    cond = tvm.placeholder(shape=in_shape, name="cond")
    dtype = cond.dtype
    lhs = tvm.placeholder(shape=in_shape, name="A")
    rhs = tvm.placeholder(shape=in_shape, name="B")
    selected = topi.where(cond, lhs, rhs)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_broadcast(selected)
        kernel = tvm.build(sched, [cond, lhs, rhs, selected], target, name="where")
        cond_np = np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype)
        lhs_np = np.random.uniform(size=in_shape).astype(dtype)
        rhs_np = np.random.uniform(size=in_shape).astype(dtype)
        expected = np.where(cond_np, lhs_np, rhs_np)
        device_args = [tvm.nd.array(arr, ctx) for arr in (cond_np, lhs_np, rhs_np)]
        dst_nd = tvm.nd.array(np.empty(expected.shape).astype(selected.dtype), ctx)
        kernel(*(device_args + [dst_nd]))
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504
def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
    """Check topi.transform.one_hot against topi.testing.one_hot.

    Parameters
    ----------
    indices_shape : tuple of int
        Shape of the int32 index placeholder.
    depth : int
        Depth forwarded to both implementations; also the exclusive upper
        bound of the random indices.
    on_value, off_value : number
        Wrapped in tvm.const of ``dtype`` for TOPI and passed raw to the reference.
    axis : int
        Axis argument forwarded to both implementations.
    dtype : str
        Output dtype.
    """
    indices = tvm.placeholder(shape=indices_shape, name="indices", dtype="int32")
    on_const = tvm.const(on_value, dtype)
    off_const = tvm.const(off_value, dtype)
    encoded = topi.transform.one_hot(indices, on_const, off_const, depth, axis, dtype)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(encoded)
        kernel = tvm.build(sched, [indices, encoded], target, name="one_hot")
        idx_np = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype)
        expected = topi.testing.one_hot(idx_np, on_value, off_value, depth, axis, dtype)
        idx_nd = tvm.nd.array(idx_np, ctx)
        dst_nd = tvm.nd.array(np.empty(expected.shape).astype(encoded.dtype), ctx)
        kernel(idx_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)

505 506 507 508 509 510 511
def test_strided_slice():
    """Run verify_strided_slice on a (3, 4, 3) input with several slice specs."""
    shape = (3, 4, 3)
    explicit_stride_cases = [
        ([0, 0, 0], [4, -5, 4], [1, -1, 2]),
        ([1, 1, 0], [4, 4, 3], [2, 1, 1]),
        ([1, -1, 0], [4, -5, 3], [2, -1, 1]),
        ([1, 0, 0], [2, 2, 3], [1, 1, 2]),
        ([1, -1, 0], [2, -3, 3], [1, -1, 1]),
    ]
    for begin, end, strides in explicit_stride_cases:
        verify_strided_slice(shape, begin, end, strides)
    # Strides omitted: exercises the helper's default.
    verify_strided_slice(shape, [1, 1, 0], [4, 4, 3])
512

513 514
def test_expand_dims():
    """Run verify_expand_dims for a trailing and a negative-axis insertion."""
    cases = [
        ((3, 10), (3, 10, 1, 1), 2, 2),
        ((3, 10), (1, 3, 10), -3, 1),
    ]
    for in_shape, out_shape, axis, num_newaxis in cases:
        verify_expand_dims(in_shape, out_shape, axis, num_newaxis)


518 519 520 521 522 523 524 525 526 527 528 529 530
def test_reinterpret():
    """Run verify_reinterpret for float->int, signed->unsigned and unsigned->signed.

    Each generator receives the input shape and returns random data that
    verify_reinterpret then casts to the input dtype.
    """
    verify_reinterpret((1000,), "float32", "int32",
                       lambda shape: np.random.randn(*shape) * 1000)
    verify_reinterpret((1000,), "float16", "int16",
                       lambda shape: np.random.randn(*shape) * 100)
    verify_reinterpret((1000,), "int16", "uint16",
                       lambda shape: np.random.randint(-1000, 1000, size=shape))
    # The upper bound does not fit in a 32-bit integer, so request int64
    # explicitly: where numpy's default integer is 32-bit, randint would
    # otherwise raise "high is out of bounds". The verbatim duplicate of this
    # case that used to follow has been removed.
    verify_reinterpret((1000,), "uint32", "int32",
                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape,
                                                       dtype="int64"))


531 532 533 534
def test_transpose():
    """Run verify_transpose for explicit permutations and for ``axes=None``."""
    cases = [
        ((3, 10, 2), (1, 0, 2)),
        ((3, 10, 5), (2, 0, 1)),
        ((3, 10), None),
    ]
    for in_shape, axes in cases:
        verify_transpose(in_shape, axes)
535 536


537 538 539 540 541 542 543
def test_reshape():
    """Run verify_reshape for rank-reducing and rank-increasing reshapes."""
    cases = [
        ((1, 2, 3, 4), (2, 3, 4)),
        ((4, 2, 3, 4), (2, 4, 12)),
        ((4, 2, 3, 4), (2, 48)),
        ((16, ), (2, 2, 2, 2)),
    ]
    for src_shape, dst_shape in cases:
        verify_reshape(src_shape, dst_shape)


544 545 546 547
def test_where():
    """Run verify_where on a single 4-D shape."""
    shape = (1, 2, 3, 4)
    verify_where(shape)


548 549 550 551
def test_squeeze():
    """Run verify_squeeze on several shapes, plus an inline-let special case."""
    cases = [
        ((1, 2, 3, 4), 0),
        ((1, 2, 1, 4), None),
        ((1, 1, 1, 4), (1, 2)),
        ((1, 1, 1, 1), None),
    ]
    for src_shape, axis in cases:
        verify_squeeze(src_shape, axis)

    # a special case to trigger inline let expression
    A = tvm.placeholder((2,), 'float32', 'A')
    E = topi.squeeze(A)
    # The index into E is computed from the data itself: 2 * A[0] - 1.
    C = tvm.compute((1,), lambda i: E[(2 * A[0] - 1).astype('int32')])
    for target in ['cuda', 'opencl']:
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            continue
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(C)
            func = tvm.build(sched, [A, C])
        a = tvm.nd.array(np.array((1, 2)).astype('float32'), ctx=ctx)
        c = tvm.nd.empty((1,), dtype='float32', ctx=ctx)
        func(a, c)
        # With A = [1, 2] the index is 2 * 1 - 1 = 1, so the result is A[1] = 2.
        assert c.asnumpy()[0] == 2

569

570
def test_concatenate():
    """Run verify_concatenate over 1-D to 4-D inputs and several axes."""
    cases = [
        ([(2,), (2,), (2,)], -1),
        ([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1),
        ([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1),
        ([(5, 6, 7, 3),
          (16, 6, 7, 3),
          (12, 6, 7, 3),
          (8, 6, 7, 3),
          (2, 6, 7, 3)], 0),
        ([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1),
    ]
    for shapes, axis in cases:
        verify_concatenate(shapes, axis)
580 581


582 583 584 585 586 587 588 589
def test_stack():
    """Run verify_stack over 1-D, 3-D and 4-D inputs and several axes."""
    vec3 = [(2,), (2,), (2,)]
    cases = [
        (vec3, -1),
        (vec3, 1),
        (vec3, 0),
        ([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1),
        ([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1),
    ]
    for shapes, axis in cases:
        verify_stack(shapes, axis)


590 591 592 593 594
def test_split():
    """Run verify_split with a section count and with explicit split indices."""
    cases = [
        ((2, 12, 3), 3, 1),
        ((2, 12, 3), [2, 4], 1),
        ((10, 12, 24), [5, 7, 9], -1),
    ]
    for src_shape, indices_or_sections, axis in cases:
        verify_split(src_shape, indices_or_sections, axis)

595 596 597 598 599 600 601
def test_flip():
    """Run verify_flip on a (3, 4, 3) input for every positive and negative axis."""
    shape = (3, 4, 3)
    for axis in (1, 0, 2, -1, -3, -2):
        verify_flip(shape, axis)
602 603 604 605 606 607 608

def test_expand_like():
    """Run verify_expand_like with one and with two inserted axes."""
    cases = [
        ((3,), (2, 3), [0]),
        ((2,), (2, 3), [1]),
        ((3, 4), (3, 5, 4), [1]),
        ((5, 7), (5, 6, 7, 8), [1, 3]),
    ]
    for in_shape, out_shape, axis in cases:
        verify_expand_like(in_shape, out_shape, axis)

609 610 611 612 613 614 615 616 617
def test_take():
    """Run verify_take over flat/axis indexing and the clip, wrap and fast modes."""
    # Calls that rely on the default mode, with axis positional or omitted.
    positional_cases = [
        ((4,), [1]),
        ((4,), [[0,1,2,3]]),
        ((3,3,3), [[11,25]]),
        ((4,), [[0,1],[2,3]]),
        ((4,), [1], 0),
        ((2,2), [[[1,0],[0,1]]], 0),
        ((2,2), [[[1,0],[0,1]]], 1),
        ((4,3,5,6), [[2,1,0,0]], -2),
        ((3,4), [-5, 20]),
    ]
    for args in positional_cases:
        verify_take(*args)

    # Calls that pass axis and/or mode by keyword; the first six use
    # out-of-range or negative indices.
    keyword_cases = [
        ((3,4), [-5, 20], {"mode": "wrap"}),
        ((3,4), [-1, 2], {"axis": 0}),
        ((3,4), [-1, 2], {"axis": 0, "mode": "wrap"}),
        ((3,4), [-1, 2], {"axis": 1}),
        ((3,4), [-1, 2], {"axis": 1, "mode": "wrap"}),
        ((3,3,3), [[11,25]], {"mode": "fast"}),
        ((3,4), [0, 2], {"axis": 0, "mode": "fast"}),
        ((3,4), [0, 2], {"axis": 1, "mode": "fast"}),
    ]
    for src_shape, indices_src, kwargs in keyword_cases:
        verify_take(src_shape, indices_src, **kwargs)
627

628 629 630 631 632 633 634 635 636 637 638 639 640 641 642
def test_gather_nd():
    """Run verify_gather_nd for int32 and float32 index tensors."""
    cases = [
        ((4,), [[1.8]]),
        ((4,), [[1, 3, 2]]),
        ((2, 3), [[1]]),
        ((2, 3), [[1], [0]]),
        ((2, 3), [[1, 0], [0, 2]]),
        ((2, 3, 4), [[1, 0], [0, 2]]),
        ((2, 3, 4), [[1, 0], [0, 2], [3, 1]]),
        ((2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]],
                     [[3, 1], [0, 2]]]),
        ((2, 3, 4, 5), [[1, 0], [0, 2]]),
        ((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]]),
    ]
    for indices_dtype in ['int32', 'float32']:
        for src_shape, indices_src in cases:
            verify_gather_nd(src_shape, indices_src, indices_dtype)

643 644 645 646 647 648 649 650 651 652 653
def test_arange():
    """Run verify_arange with omitted, fractional and negative start/step values."""
    cases = [
        (None, 20, None),
        (None, 20, 2),
        (1, 20, None),
        (1, 20, 2),
        (1, 20, 1.5),
        (1, 20.5, None),
        (1, 20, 3),
        (20, 1, -1),
        (20, 1, -1.5),
    ]
    for start, stop, step in cases:
        verify_arange(start, stop, step)

654 655 656 657 658 659 660 661 662 663
def test_repeat():
    """Run verify_repeat on 1-D to 4-D inputs, including a negative axis."""
    cases = [
        ((2,), 1, 0),
        ((3, 2), 2, 0),
        ((3, 2, 4), 3, 1),
        ((1, 3, 2, 4), 4, -1),
    ]
    for in_shape, repeats, axis in cases:
        verify_repeat(in_shape, repeats, axis)

def test_tile():
    """Run verify_tile with reps of equal, shorter and longer rank than the input."""
    cases = [
        ((3, 2), (2, 3)),
        ((3, 2, 5), (2,)),
        ((3, ), (2, 3, 3)),
    ]
    for in_shape, reps in cases:
        verify_tile(in_shape, reps)
664

665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692
def test_layout_transform():
    """Check topi.layout_transform from "NCHW" to "NCHW16c" on every backend."""
    in_shape = (1, 32, 8, 8)
    A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.layout_transform(A, "NCHW", "NCHW16c")

    src_np = np.random.uniform(size=in_shape).astype(A.dtype)
    # numpy reference: move channels last, split 32 channels into 2 x 16,
    # then move the outer channel block back in front of the spatial axes.
    channels_last = np.transpose(src_np, axes=(0, 2, 3, 1))
    blocked = np.reshape(channels_last, newshape=(1, 8, 8, 2, 16))
    expected = np.transpose(blocked, axes=(0, 3, 1, 2, 4))

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=B.dtype)
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(B)
        kernel = tvm.build(sched, [A, B], target, name="layout_transform")
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)


693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719
def test_shape():
    """Check that topi.shape yields the input's shape as an int32 tensor."""
    in_shape = (8, 7, 13)
    dtype = "int32"
    A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.shape(A, dtype)

    src_np = np.random.uniform(size=in_shape).astype(A.dtype)
    # The expected result is simply the shape tuple as an array.
    expected = np.asarray(in_shape).astype(dtype)

    def _run(target):
        ctx = tvm.context(target, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % target)
            return
        src_nd = tvm.nd.array(src_np, ctx)
        dst_nd = tvm.nd.empty(expected.shape, ctx=ctx, dtype=dtype)
        print("Running on target: %s" % target)
        with tvm.target.create(target):
            sched = topi.generic.schedule_injective(B)
        kernel = tvm.build(sched, [A, B], target, name="shape")
        kernel(src_nd, dst_nd)
        tvm.testing.assert_allclose(dst_nd.asnumpy(), expected)

    for backend in get_all_backend():
        _run(backend)


720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749
def test_sequence_mask():
    """Compare topi.sequence_mask with the topi.testing reference.

    Sweeps a 2-D and a 4-D input shape, sequence axis 0 and 1, and mask
    values 0.0 and 1.0, running each combination on every backend returned
    by get_all_backend().
    """
    def run_case(in_shape, axis, mask_value):
        """Build, run and verify one (shape, axis, mask_value) combination."""
        # `axis` is 0 or 1; the other of the two leading dims gives the batch size.
        max_length = in_shape[axis]
        batch_size = in_shape[1 - axis]
        A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
        B = tvm.placeholder(shape=(batch_size,), dtype="int32", name="B")
        C = topi.sequence_mask(A, B, axis=axis, mask_value=mask_value)

        data_np = np.random.normal(0, 1, in_shape).astype(np.float32)
        # Valid lengths are drawn from [1, max_length); randint's upper bound is exclusive.
        lengths_np = np.random.randint(1, max_length, (batch_size,)).astype(np.int32)
        expected_np = topi.testing.sequence_mask(data_np, lengths_np, mask_value, axis)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            data_nd = tvm.nd.array(data_np, ctx)
            lengths_nd = tvm.nd.array(lengths_np, ctx)
            result_nd = tvm.nd.empty(in_shape, ctx=ctx, dtype="float32")
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                sched = topi.generic.schedule_injective(C)
            func = tvm.build(sched, [A, B, C], device, name="SequenceMask")
            func(data_nd, lengths_nd, result_nd)
            tvm.testing.assert_allclose(result_nd.asnumpy(), expected_np)

        for backend in get_all_backend():
            check_device(backend)

    for in_shape in ((5, 10), (3, 4, 5, 4)):
        for axis in (0, 1):
            for mask_value in (0.0, 1.0):
                run_case(in_shape, axis, mask_value)

750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776
def test_ndarray_size():
    """Check that topi.ndarray_size reports the element count of its input.

    A float32 placeholder of shape (5, 11, 7) is used; the expected value is
    np.size of a matching NumPy array, cast to int32.
    """
    in_shape = (5, 11, 7)
    out_dtype = "int32"
    A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.ndarray_size(A, out_dtype)

    data_np = np.random.uniform(size=in_shape).astype(A.dtype)
    # np.size returns a plain int, so the reference is a 0-d array.
    expected_np = np.asarray(np.size(data_np)).astype(out_dtype)

    def check_device(device):
        """Build and run the kernel on `device`, skipping disabled targets."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        data_nd = tvm.nd.array(data_np, ctx=ctx)
        # The device buffer is allocated with shape (1,), not 0-d like the reference.
        result_nd = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype)
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            sched = topi.generic.schedule_injective(B)
        func = tvm.build(sched, [A, B], device, name="ndarray_size")
        func(data_nd, result_nd)
        tvm.testing.assert_allclose(result_nd.asnumpy(), expected_np)

    for backend in get_all_backend():
        check_device(backend)


777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801
def test_where_fusion():
    """Integration test that where and zeros should be properly inlined.

    For each backend, an int8 conv2d with int32 output is compared against a
    constant zero tensor, the comparison selects between constant tensors of
    ones and twos via topi.where, and the selection is added back onto the
    conv2d output. The graph is scheduled with the conv2d NCHW schedule and
    built; the test passes if the build succeeds (nothing is executed).
    """
    def check_device(device):
        """Schedule and build the fused graph for `device`; skip if disabled."""
        # Check availability before entering the target scope, as the other
        # tests in this file do.
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
            w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
            conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')
            # NOTE(review): the constant tensors use (2, 3, 1, 3), presumably
            # the conv2d output shape for these inputs -- confirm.
            zeros = topi.full((2, 3, 1, 3), 'int32', tvm.const(0, dtype='int32'))
            gt = topi.greater_equal(conv1, zeros)
            one = topi.full((2, 3, 1, 3), 'int32', tvm.const(1, dtype='int32'))
            two = topi.full((2, 3, 1, 3), 'int32', tvm.const(2, dtype='int32'))
            where = topi.where(gt, one, two)
            add = topi.add(conv1, where)
            outs = [add]
            s = topi.generic.schedule_conv2d_nchw(outs)
            # Build for the device passed in rather than relying on the
            # enclosing loop variable happening to hold the same value.
            tvm.build(s, [data, w, add], target=device)

    for backend in get_all_backend():
        check_device(backend)

802 803 804 805 806 807 808
def test_one_hot():
    """Run verify_one_hot over a fixed table of int32 and float32 cases."""
    # Each row is forwarded positionally to verify_one_hot.
    # NOTE(review): presumably (indices_shape, depth, on_value, off_value,
    # axis, dtype) -- confirm against verify_one_hot's signature.
    cases = (
        ((3,), 3, 1, 0, -1, "int32"),
        ((3,), 3, 1.0, 0.0, -1, "float32"),
        ((2, 2), 5, 2, -2, 0, "int32"),
        ((2, 2), 5, 0.5, -0.5, 1, "float32"),
        ((3, 2, 4, 5), 6, 1, 0, 1, "int32"),
        ((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32"),
    )
    for case in cases:
        verify_one_hot(*case)
809

810
if __name__ == "__main__":
    # Run the tests below, in order, when this file is executed as a script.
    test_strided_slice()
    test_concatenate()
    test_stack()
    test_transpose()
    test_expand_dims()
    test_reshape()
    test_where()
    test_squeeze()
    test_split()
    test_flip()
    test_expand_like()
    test_take()
    test_gather_nd()
    test_arange()
    test_layout_transform()
    test_repeat()
    test_tile()
    test_shape()
    test_sequence_mask()
    test_ndarray_size()
    test_where_fusion()
    test_one_hot()