# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from nnvm.testing.check_computation import check_function

def test_check_function():
    # test the testing function
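    # check_function builds the symbol for every enabled target, runs it, and
    # compares the result (and, when requested, the gradients) against the
    # NumPy reference callables passed as forward/backward.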

    x = sym.Variable("x")
    y = sym.Variable("y")

    # different styles of returning gradients from the backward function
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: [head_grads, 2*head_grads],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: (head_grads, 2*head_grads),
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: {'x': head_grads, 'y': 2*head_grads},
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: {'y': 2*head_grads},
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: [2*head_grads],
                   grad_input_vars=[y],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: 2*head_grads,
                   grad_input_vars=[y],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: 2*head_grads,
                   grad_input_vars=[y],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float64')

    # test just numerical gradients
    # different styles of shape and dtype passing
    check_function(x + 2*y, shape={'x': (1, 2), y: (1, 2)},
                   numerical_grads=True)
    check_function(x + 2*y, shape={'x': (1, 2), y: (1, 2)}, dtype='float32',
                   numerical_grads=True)
    check_function(x + 2*y, shape={'x': (1, 2), y: (1, 2)}, dtype={x: 'float32', 'y': 'float32'},
                   numerical_grads=True)
    check_function(x + 2*y, shape=(1, 2), dtype='float32',
                   numerical_grads=True)

    # specifying variable attributes on variable creation
    # (in this case type codes must be used)
    x = sym.Variable("x", dtype=0, shape=(1, 2))
    check_function(x + 2*y, shape={y: (1, 2)}, dtype={'y': 'float32'}, numerical_grads=True)
    y = sym.Variable("y", dtype=0, shape=(1, 2))

    # shape overriding
    def _fwd1(x, y):
        assert x.shape == (1, 1)
        assert y.shape == (1, 2)
        return x + 2*y
    check_function(x + 2*y, _fwd1, shape={x: (1, 1)})

    # in_range
    def _fwd2(x, y):
        assert x.shape == (100,)
        assert (x <= 0.9).all()
        assert (x >= 0.8).all()
        return x + 2*y
    check_function(x + 2*y, _fwd2, shape=(100,), in_range=(0.8, 0.9), numerical_grads=False)
    check_function(x + 2*y, _fwd2, shape=(100,), in_range={'x': (0.8, 0.9)}, numerical_grads=False)
    check_function(x + 2*y, backward=lambda x, y, head_grads: [1.0, 2.0],
                   in_range={'head_grads_0': (1.0, 1.0)})
    # explicit passing of values
    check_function(x + 2*y, backward=lambda x, y, head_grads: [1.0, 2.0],
                   values={'head_grads_0': np.full((1, 2), 1.0)})

    # check that the function reports errors
    def _check_function_must_fail(*args, **kwargs):
        error = AssertionError
        if 'error' in kwargs:
            error = kwargs['error']
            del kwargs['error']
        try:
            check_function(*args, quiet=True, **kwargs)
        except error:
            pass
        else:
            raise AssertionError("check_function didn't raise an exception")

    _check_function_must_fail(x + 2*y, error=ValueError)
    _check_function_must_fail(x + 2*y, lambda x, y: x + y)
    _check_function_must_fail(x + 2*y, backward=lambda x, y, head_grads: [1.0, 2.0])
    _check_function_must_fail(sym.block_grad(x + 2*y), numerical_grads=True)
    _check_function_must_fail(x*x, numerical_grads=True,
                              numerical_grads_params={'atol': 0.0, 'rtol': 0.0})
    _check_function_must_fail(sym.log(-x*x), numerical_grads=True, error=ValueError)

    # different styles of returning results from the forward function
    check_function(x + 2*y, lambda x, y: [x + 2*y], numerical_grads=False)
    _check_function_must_fail(x + 2*y, lambda x, y: [x + 2*y, x], numerical_grads=False,
                              error=ValueError)
    _check_function_must_fail(x + 2*y, lambda x, y: [], numerical_grads=False,
                              error=ValueError)

    # multiple outputs
    z = sym.Group([2*x + y, x + 2*y])
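    # with several outputs, head_grads is a list with one gradient array per output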
    check_function(z, lambda x, y: [2*x + y, x + 2*y])
    check_function(z, lambda x, y: (2*x + y, x + 2*y))
    check_function(z, backward=lambda x, y, head_grads: [2*head_grads[0] + head_grads[1],
                                                         head_grads[0] + 2*head_grads[1]])
    _check_function_must_fail(z, backward=lambda x, y, head_grads: [2*head_grads[0],
                                                                    2*head_grads[1]])
    check_function(z, backward=lambda x, y, head_grads: [head_grads[1], 2*head_grads[1]],
                   in_range={'head_grads_0': (0, 0)})
    check_function(z, numerical_grads=True)

    z = sym.Group([sym.block_grad(2*x + y), x + 2*y])
    check_function(z, lambda x, y: [2*x + y, x + 2*y], numerical_grads=False)
    _check_function_must_fail(z, lambda x, y: [2*x + y, x + 2*y])
    _check_function_must_fail(z, numerical_grads=True)

    z = sym.Group([2*x + y, sym.block_grad(x + 2*y)])
    _check_function_must_fail(z, numerical_grads=True)

    z = sym.Group([2*x + y, x + 2*y, x, y, sym.sum(x)])
    check_function(z, lambda x, y: [2*x + y, x + 2*y, x, y, np.sum(x)])

    # passing additional parameters to forward and backward
    def _fwd3(x, p):
        assert p == 'v'
        return x + 1
    def _bwd3(x, p, head_grads):
        assert p == 'v'
        return head_grads
    check_function(x + 1, _fwd3, _bwd3, additional_params={'p': 'v'})

    # implicitly created variables and shape/dtype inference for inputs
    x = sym.Variable("x", shape=(2, 3), dtype=0)
    b = sym.Variable("b")
    y = sym.dense(data=x, bias=b, units=4)
    # Don't check gradients on cuda because it doesn't yet support ewise after reduce
    check_function(y, exclude_targets={'cuda'}, numerical_grads=True)
    check_function(y, shape={'x': (3, 4)}, exclude_targets={'cuda'}, numerical_grads=True)
    check_function(y, dtype={'x': 'float64'}, exclude_targets={'cuda'}, numerical_grads=True)

    x = sym.Variable("x")
    b = sym.Variable("b")
    w = sym.Variable("w")
    y = sym.dense(data=x, bias=b, weight=w, units=4)
    def _fwd_dense(x, w, b):
        return np.dot(x, w.T) + b
    check_function(y, _fwd_dense, shape={'x': (1,2)}, dtype={'x': 'float32'}, numerical_grads=False)
    check_function(y, _fwd_dense, shape={'x': (1,2)}, dtype={'w': 'float64'}, numerical_grads=False)
    _check_function_must_fail(y, _fwd_dense, shape={'x': (1,2)},
                              dtype={'w': 'float64', 'b': 'float32'},
                              numerical_grads=False,
                              error=nnvm._base.NNVMError)
    # fails because no shape
    _check_function_must_fail(y, _fwd_dense, numerical_grads=False, error=ValueError)
    # ok because type is float32 by default
    check_function(y, _fwd_dense, shape={'x': (1,2)}, numerical_grads=False)

def test_relu():
    x = sym.Variable("x")
    y = sym.relu(sym.leaky_relu(x, alpha=0.3) - 0.2)

    def forward(x):
        x = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
        return (x > 0) * x

    def backward(head_grads, x):
        sub = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
        return [(sub > 0).astype("float") *
                ((x > 0).astype("float") + 0.3 * (x < 0).astype("float")) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)

def test_prelu_nchw():
    x = sym.Variable("x")
    a = sym.Variable("a")
    y = sym.prelu(data=x, alpha=a)
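    # channel axis is 1 (NCHW): alpha holds one slope per channel, so the
    # reference reshapes it to (3, 1, 1) to broadcast over H and W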

    def forward(x, a):
        return (x < 0) * (x * a.reshape(3, 1, 1)) + (x >= 0) * x

    shape = {'x': (1, 3, 32, 32), 'a': (3,)}
    check_function(y, forward, shape=shape)

def test_prelu_nhwc():
    x = sym.Variable("x")
    a = sym.Variable("a")
    y = sym.prelu(data=x, alpha=a, axis=3)
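    # axis=3 (NHWC): alpha is applied along the trailing channel dimension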

    def forward(x, a):
        return (x < 0) * (x * a.reshape(1, 1, 3)) + (x >= 0) * x

    shape = {'x': (1, 32, 32, 3), 'a': (3,)}
    check_function(y, forward, shape=shape)

def test_sym_scalar_pow():
    scalar = 3
    x = sym.Variable("x")
    y = x**scalar

    def forward(x):
        return x**scalar

    def backward(head_grads, x):
        return [scalar * x**(scalar - 1) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_scalar_sym_pow():
    scalar = 3
    x = sym.Variable("x")
    y = scalar**x

    def forward(x):
        return scalar**x

    def backward(head_grads, x):
        return [np.log(scalar) * scalar**x * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_exp():
    x = sym.Variable("x")
    y = sym.exp(x)

    def forward(x):
        return np.exp(x)

    def backward(head_grads, x):
        return [np.exp(x) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_log():
    x = sym.Variable("x")
    y = sym.log(x)

    def forward(x):
        return np.log(x)

    def backward(head_grads, x):
        return [1. / x * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, in_range=(0.002, 2.0), shape=shape)


def test_tanh():
    x = sym.Variable("x")
    y = sym.tanh(x)

    def forward(x):
        return np.sinh(x) / np.cosh(x)

    def backward(head_grads, x):
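        # d tanh(x)/dx = 1 - tanh(x)**2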
        y_np = forward(x)
        return [(1 - y_np**2) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_sigmoid():
    x = sym.Variable("x")
    y = sym.sigmoid(x)

    def forward(x):
        return 1.0 / (1.0 + np.exp(-x))

    def backward(head_grads, x):
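        # d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x))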
        y_np = forward(x)
        return [y_np * (1 - y_np) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_softmax():
    x = sym.Variable("x")
    y = sym.softmax(x)

    def forward(x):
        return topi.testing.softmax_python(x)

    def backward(head_grads, x):
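        # softmax gradient: dL/dx = y * (dL/dy - sum(y * dL/dy, axis=1))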
        y = topi.testing.softmax_python(x)
        grad = y * (head_grads - np.sum(y * head_grads, axis=1, keepdims=True))
        return [grad]

    check_function(y, forward, backward,
                   shape={'x': (10, 1000)}, numerical_grads=False)
    check_function(y, forward, backward,
                   shape={'x': (2, 10)})


def test_log_softmax():
    x = sym.Variable("x")
    y = sym.log_softmax(x)

    def forward(x):
        return topi.testing.log_softmax_python(x)

    def backward(head_grads, x):
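        # log_softmax gradient: dL/dx = dL/dy - softmax(x) * sum(dL/dy, axis=1)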
        y = topi.testing.log_softmax_python(x)
        grad = head_grads - np.exp(y) * np.sum(head_grads, axis=1, keepdims=True)
        return [grad]

    check_function(y, forward, backward,
                   shape={'x': (10, 1000)}, numerical_grads=False)
    check_function(y, forward, backward,
                   shape={'x': (2, 10)})


def test_dense():
    x = sym.Variable("x", shape=(10, 100))
    w = sym.Variable("dense_weight", shape=(3, 100))
    b = sym.Variable("dense_bias", shape=(3,))
    y = sym.dense(x, w, b, use_bias=True, units=3, name="dense")
    y = sym.flatten(y)

    def forward(x, dense_weight, dense_bias):
        return np.dot(x, dense_weight.T) + dense_bias
    shape = {
        'x': (10, 100),
        'w': (3, 100),
        'b': (3,)
    }
    # Don't check gradients on cuda because it doesn't yet support ewise after reduce
    check_function(y, forward, shape=shape,
                   exclude_targets={'cuda'}, numerical_grads=True)
    check_function(y, forward, shape=shape,
                   only_targets={'cuda'}, numerical_grads=False)


def test_batchnorm():
    x = sym.Variable("x")
    beta = sym.Variable("beta")
    gamma = sym.Variable("gamma")
    moving_var = sym.Variable("moving_var")
    moving_mean = sym.Variable("moving_mean")
    eps = 1e-5
    y = sym.batch_norm(
        x, gamma, beta, moving_mean, moving_var, epsilon=eps)

    def forward(x, gamma, beta, moving_mean, moving_var):
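        # inference-mode batch norm: normalize with the running statistics,
        # then scale by gamma and shift by beta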
        return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta

    shape = {
        'x': (10, 20),
        'gamma': (20,),
        'beta': (20,),
        'moving_mean': (20,),
        'moving_var': (20,)
    }

    check_function(y, forward, in_range=(0.001, 1.0), shape=shape)


def verify_concatenate(ishape, axis):
    x = [sym.Variable("x%d" % i, shape=ishape[i]) for i in range(len(ishape))]
    y = sym.concatenate(*x, axis=axis) + 1

    def forward(**kwargs):
        return np.concatenate(list(kwargs.values()), axis=axis) + 1

    check_function(y, forward)


def test_concatenate():
    verify_concatenate([(2, 3, 4), (1, 3, 4)], axis=0)
    verify_concatenate([(2, 4), (2, 7)], axis=1)


def verify_split(ishape, indices_or_sections, axis):
    x = sym.Variable("x", shape=ishape)
    y = sym.split(x, indices_or_sections=indices_or_sections, axis=axis)

    def forward(x):
        return np.split(x, indices_or_sections, axis=axis)

    check_function(y, forward)


def test_split():
    verify_split((2, 3), 2, axis=0)
    verify_split((5, 3), [3], axis=0)
    verify_split((5, 9, 3), [3, 4], axis=1)

def verify_strided_slice(ishape, begin, end, strideinp=None):
    stride = strideinp if strideinp else [1, 1, 1]
    x = sym.Variable("x", shape=ishape)
    if strideinp:
        y = sym.strided_slice(x, begin=begin, end=end, stride=stride) + 1
    else:
        y = sym.strided_slice(x, begin=begin, end=end) + 1

    for i in range(len(begin), 3):
        begin.append(0)
    for i in range(len(end), 3):
        end.append(ishape[i])
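    # begin and end are now padded to rank 3 so the reference below can slice
    # all three axes explicitly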

    def test_forward(x):
        return x[begin[0]:end[0]:stride[0],
                 begin[1]:end[1]:stride[1],
                 begin[2]:end[2]:stride[2]] + 1

    check_function(y, test_forward)

def test_strided_slice():
    verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
    verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
    verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
    verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4])
    verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3])

def verify_take(src_shape, indices_src, axis=None):
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    a = sym.Variable("a", shape=src_shape)
    indices = sym.Variable("indices", shape=indices_src.shape)
    y = sym.take(a, indices, axis=axis)

    def forward(a, indices):
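        # with axis=None, np.take indexes into the flattened array, which is
        # the behaviour sym.take is expected to reproduce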
        return np.take(a, indices=indices, axis=axis)

    a_src = np.arange(np.prod(src_shape), dtype=src_dtype).reshape(src_shape)

    check_function(y, forward,
                   dtype={'a': src_dtype, 'indices': indices_dtype},
                   values={'a': a_src, 'indices': indices_src})

def test_take():
    verify_take((4,), [1])
    verify_take((4,), [[0,1,2,3]])
    verify_take((3,3,3), [[11,25]])
    verify_take((4,), [[0,1],[2,3]])
    verify_take((4,), [1], 0)
    verify_take((2,2), [[[1,0],[0,1]]], 0)
    verify_take((2,2), [[[1,0],[0,1]]], 1)
    verify_take((4,3,5,6), [[2,1,0,0]], -2)


def verify_squeeze(shape, axis):
    x = sym.Variable("x")
    if axis is not None:
        y = sym.squeeze(x, axis=axis)
    else:
        y = sym.squeeze(x)
    y = y + 1

    def forward(x):
        return np.squeeze(x, axis=axis) + 1

    def backward(head_grads, x):
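        # squeeze only drops size-1 axes, so the gradient is head_grads
        # reshaped back to the input shape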
        return [np.reshape(head_grads, x.shape)]

    check_function(y, forward, backward, shape=shape)


def test_squeeze():
    verify_squeeze((1, 3, 2, 5), None)
    verify_squeeze((1, 3, 1), axis=0)
    verify_squeeze((1, 3, 2, 5, 1), axis=-1)


def test_pad():
    x = sym.Variable("x")
    y = sym.pad(x, pad_width=((0, 0), (0, 0), (0, 1), (2, 3)), pad_value=1.)

    def forward(x):
        return np.pad(x,
                      pad_width=((0, 0), (0, 0), (0, 1), (2, 3)),
                      mode='constant', constant_values=1.)

    shape = {'x': (1, 3, 28, 28)}
    check_function(y, forward, shape=shape)

def verify_lrn(ishape, size, axis, bias, alpha, beta):
    x = sym.Variable("x", shape=ishape)
    y = sym.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)

    def forward1(x):
        return topi.testing.lrn_python(x, size, axis, bias, alpha, beta)

    check_function(y, forward1)

    def forward2(x):
        y = forward1(x)
        return (y > 0)*y

    # Check the LRN op followed by an elementwise relu
    check_function(sym.relu(y), forward2, in_range={'x': (-10.0, 10.0)})

def verify_l2_normalize(ishape, eps, axis):
    x = sym.Variable("x", shape=ishape)
    y = sym.l2_normalize(x, eps=eps, axis=axis)

    def forward1(x):
        return topi.testing.l2_normalize_python(x, eps, axis)

    check_function(y, forward1)

    def forward2(x):
        y = forward1(x)
        return (y > 0)*y

    # Check the L2 normalization op followed by an elementwise relu
    check_function(sym.relu(y), forward2, in_range={'x': (-10.0, 10.0)})

def test_lrn():
    verify_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

def test_l2_normalize():
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))

def verify_gather_nd(src_shape, indices_src):
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    a = sym.Variable("a", shape=src_shape)
    indices = sym.Variable("indices", shape=indices_src.shape)
    y = sym.gather_nd(a, indices)

    def forward(a, indices):
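        # the leading axis of indices holds coordinates: position j selects
        # a[indices[0, j], indices[1, j], ...]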
        return topi.testing.gather_nd_python(a, indices)

    a_src = np.arange(np.prod(src_shape), dtype=src_dtype).reshape(src_shape)

    check_function(y, forward,
                   dtype={'a': src_dtype, 'indices': indices_dtype},
                   values={'a': a_src, 'indices': indices_src})

def test_gather_nd():
    verify_gather_nd((4,), [[1]])
    verify_gather_nd((4,), [[1, 3, 2]])
    verify_gather_nd((2, 3), [[1]])
    verify_gather_nd((2, 3), [[1], [0]])
    verify_gather_nd((2, 3), [[1, 0], [0, 2]])
    verify_gather_nd((2, 3, 4), [[1, 0], [0, 2]])
    verify_gather_nd((2, 3, 4), [[1, 0], [0, 2], [3, 1]])
    verify_gather_nd((2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]],
                                 [[3, 1], [0, 2]]])
    verify_gather_nd((2, 3, 4, 5), [[1, 0], [0, 2]])
    verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]])

if __name__ == "__main__":
    test_check_function()
    test_split()
    test_concatenate()
    test_log_softmax()
    test_batchnorm()
    test_dense()
    test_relu()
    test_prelu_nchw()
    test_prelu_nhwc()
    test_sym_scalar_pow()
    test_scalar_sym_pow()
    test_exp()
    test_log()
    test_tanh()
    test_sigmoid()
    test_softmax()
    test_squeeze()
    test_pad()
    test_take()
    test_lrn()
    test_l2_normalize()
    test_strided_slice()
    test_gather_nd()