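"""Tests for NNVM top-level operators.

Each test builds a small symbolic graph and checks it against a NumPy
reference implementation (and, where provided, a hand-written gradient)
using nnvm.testing.check_computation.check_function.
"""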
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from nnvm.testing.check_computation import check_function

def test_check_function():
    # exercise the check_function testing utility itself

    x = sym.Variable("x")
    y = sym.Variable("y")

    # different styles of returning gradients from the backward function
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: [head_grads, 2*head_grads],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: (head_grads, 2*head_grads),
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: {'x': head_grads, 'y': 2*head_grads},
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: {'y': 2*head_grads},
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: [2*head_grads],
                   grad_input_vars=[y],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: 2*head_grads,
                   grad_input_vars=[y],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float32')
    check_function(x + 2*y, lambda x, y: x + 2*y,
                   lambda x, y, head_grads: 2*head_grads,
                   grad_input_vars=[y],
                   shape={'x': (1, 2), y: (1, 2)}, dtype='float64')

    # test just numerical gradients
    # different styles of shape and dtype passing
    check_function(x + 2*y, shape={'x': (1, 2), y: (1, 2)},
                   numerical_grads=True)
    check_function(x + 2*y, shape={'x': (1, 2), y: (1, 2)}, dtype='float32',
                   numerical_grads=True)
    check_function(x + 2*y, shape={'x': (1, 2), y: (1, 2)}, dtype={x: 'float32', 'y': 'float32'},
                   numerical_grads=True)
    check_function(x + 2*y, shape=(1, 2), dtype='float32',
                   numerical_grads=True)

    # specifying variable attributes on variable creation
    # (in this case type codes must be used)
    x = sym.Variable("x", dtype=0, shape=(1, 2))
    check_function(x + 2*y, shape={y: (1, 2)}, dtype={'y': 'float32'}, numerical_grads=True)
    y = sym.Variable("y", dtype=0, shape=(1, 2))

    # shape overriding
    def _fwd1(x, y):
        assert x.shape == (1, 1)
        assert y.shape == (1, 2)
        return x + 2*y
    check_function(x + 2*y, _fwd1, shape={x: (1, 1)})

    # in_range
    def _fwd2(x, y):
        assert x.shape == (100,)
        assert (x <= 0.9).all()
        assert (x >= 0.8).all()
        return x + 2*y
    check_function(x + 2*y, _fwd2, shape=(100,), in_range=(0.8, 0.9), numerical_grads=False)
    check_function(x + 2*y, _fwd2, shape=(100,), in_range={'x': (0.8, 0.9)}, numerical_grads=False)
    check_function(x + 2*y, backward=lambda x, y, head_grads: [1.0, 2.0],
                   in_range={'head_grads_0': (1.0, 1.0)})
    # explicit passing of values
    check_function(x + 2*y, backward=lambda x, y, head_grads: [1.0, 2.0],
                   values={'head_grads_0': np.full((1, 2), 1.0)})

    # check that the function reports errors
    def _check_function_must_fail(*args, **kwargs):
        error = AssertionError
        if 'error' in kwargs:
            error = kwargs['error']
            del kwargs['error']
        try:
            check_function(*args, quiet=True, **kwargs)
        except error:
            pass
        else:
            raise AssertionError("check_function didn't raise an exception")

    _check_function_must_fail(x + 2*y, error=ValueError)
    _check_function_must_fail(x + 2*y, lambda x, y: x + y)
    _check_function_must_fail(x + 2*y, backward=lambda x, y, head_grads: [1.0, 2.0])
    _check_function_must_fail(sym.block_grad(x + 2*y), numerical_grads=True)
    _check_function_must_fail(x*x, numerical_grads=True,
                              numerical_grads_params={'atol': 0.0, 'rtol': 0.0})
    _check_function_must_fail(sym.log(-x*x), numerical_grads=True, error=ValueError)

    # different styles of returning results from the forward function
    check_function(x + 2*y, lambda x, y: [x + 2*y], numerical_grads=False)
    _check_function_must_fail(x + 2*y, lambda x, y: [x + 2*y, x], numerical_grads=False,
                              error=ValueError)
    _check_function_must_fail(x + 2*y, lambda x, y: [], numerical_grads=False,
                              error=ValueError)

    # multiple outputs
    z = sym.Group([2*x + y, x + 2*y])
    check_function(z, lambda x, y: [2*x + y, x + 2*y])
    check_function(z, lambda x, y: (2*x + y, x + 2*y))
    check_function(z, backward=lambda x, y, head_grads: [2*head_grads[0] + head_grads[1],
                                                         head_grads[0] + 2*head_grads[1]])
    _check_function_must_fail(z, backward=lambda x, y, head_grads: [2*head_grads[0],
                                                                    2*head_grads[1]])
    check_function(z, backward=lambda x, y, head_grads: [head_grads[1], 2*head_grads[1]],
                   in_range={'head_grads_0': (0, 0)})
    check_function(z, numerical_grads=True)

    z = sym.Group([sym.block_grad(2*x + y), x + 2*y])
    check_function(z, lambda x, y: [2*x + y, x + 2*y], numerical_grads=False)
    _check_function_must_fail(z, lambda x, y: [2*x + y, x + 2*y])
    _check_function_must_fail(z, numerical_grads=True)

    z = sym.Group([2*x + y, sym.block_grad(x + 2*y)])
    _check_function_must_fail(z, numerical_grads=True)

    z = sym.Group([2*x + y, x + 2*y, x, y, sym.sum(x)])
    check_function(z, lambda x, y: [2*x + y, x + 2*y, x, y, np.sum(x)])

    # passing additional parameters to forward and backward
    def _fwd3(x, p):
        assert p == 'v'
        return x + 1
    def _bwd3(x, p, head_grads):
        assert p == 'v'
        return head_grads
    check_function(x + 1, _fwd3, _bwd3, additional_params={'p': 'v'})

    # implicitly created variables and shape/dtype inference for inputs
    x = sym.Variable("x", shape=(2, 3), dtype=0)
    b = sym.Variable("b")
    y = sym.dense(data=x, bias=b, units=4)
    # Don't check gradients on cuda because it doesn't yet support ewise after reduce
    check_function(y, exclude_targets={'cuda'}, numerical_grads=True)
    check_function(y, shape={'x': (3, 4)}, exclude_targets={'cuda'}, numerical_grads=True)
    check_function(y, dtype={'x': 'float64'}, exclude_targets={'cuda'}, numerical_grads=True)

    x = sym.Variable("x")
    b = sym.Variable("b")
    w = sym.Variable("w")
    y = sym.dense(data=x, bias=b, weight=w, units=4)
    def _fwd_dense(x, w, b):
        return np.dot(x, w.T) + b
    check_function(y, _fwd_dense, shape={'x': (1,2)}, dtype={'x': 'float32'}, numerical_grads=False)
    check_function(y, _fwd_dense, shape={'x': (1,2)}, dtype={'w': 'float64'}, numerical_grads=False)
    _check_function_must_fail(y, _fwd_dense, shape={'x': (1,2)},
                              dtype={'w': 'float64', 'b': 'float32'},
                              numerical_grads=False,
                              error=nnvm._base.NNVMError)
    # fails because no shape
    _check_function_must_fail(y, _fwd_dense, numerical_grads=False, error=ValueError)
    # ok because type is float32 by default
    check_function(y, _fwd_dense, shape={'x': (1,2)}, numerical_grads=False)

def test_relu():
    x = sym.Variable("x")
    y = sym.relu(sym.leaky_relu(x, alpha=0.3) - 0.2)

    def forward(x):
        x = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
        return (x > 0) * x

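    # Gradient of relu(leaky_relu(x) - 0.2) by the chain rule: the outer relu
    # passes the gradient only where its input is positive, and the inner
    # leaky_relu contributes a factor of 1 for x > 0 and 0.3 for x < 0.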
    def backward(head_grads, x):
        sub = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
        return [(sub > 0).astype("float") *
                ((x > 0).astype("float") + 0.3 * (x < 0).astype("float")) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)

def test_prelu_nchw():
    x = sym.Variable("x")
    a = sym.Variable("a")
    y = sym.prelu(data=x, alpha=a)

    def forward(x, a):
        return (x < 0) * (x * a.reshape(3, 1, 1)) + (x >= 0) * x

    shape = {'x': (1, 3, 32, 32), 'a': (3,)}
    check_function(y, forward, shape=shape)

def test_prelu_nhwc():
    x = sym.Variable("x")
    a = sym.Variable("a")
    y = sym.prelu(data=x, alpha=a, axis=3)

    def forward(x, a):
        return (x < 0) * (x * a.reshape(1, 1, 3)) + (x >= 0) * x

    shape = {'x': (1, 32, 32, 3), 'a': (3,)}
    check_function(y, forward, shape=shape)

def test_sym_scalar_pow():
    scalar = 3
    x = sym.Variable("x")
    y = x**scalar

    def forward(x):
        return x**scalar

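    # power rule: d(x**n)/dx = n * x**(n - 1), scaled by the head gradient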
    def backward(head_grads, x):
        return [scalar * x**(scalar - 1) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_scalar_sym_pow():
    scalar = 3
    x = sym.Variable("x")
    y = scalar**x

    def forward(x):
        return scalar**x

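    # d(a**x)/dx = ln(a) * a**x, scaled by the head gradient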
    def backward(head_grads, x):
        return [np.log(scalar) * scalar**x * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_exp():
    x = sym.Variable("x")
    y = sym.exp(x)

    def forward(x):
        return np.exp(x)

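    # exp is its own derivative: d(exp(x))/dx = exp(x)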
    def backward(head_grads, x):
        return [np.exp(x) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_log():
    x = sym.Variable("x")
    y = sym.log(x)

    def forward(x):
        return np.log(x)

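    # d(log(x))/dx = 1/x; the in_range below keeps inputs strictly positive
    # so both the forward value and the gradient stay finite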
    def backward(head_grads, x):
        return [1. / x * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, in_range=(0.002, 2.0), shape=shape)


def test_tanh():
    x = sym.Variable("x")
    y = sym.tanh(x)

    def forward(x):
        return np.sinh(x) / np.cosh(x)

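    # d(tanh(x))/dx = 1 - tanh(x)**2, reusing the forward result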
    def backward(head_grads, x):
        y_np = forward(x)
        return [(1 - y_np**2) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_sigmoid():
    x = sym.Variable("x")
    y = sym.sigmoid(x)

    def forward(x):
        return 1.0 / (1.0 + np.exp(-x))

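    # d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)), reusing the forward result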
    def backward(head_grads, x):
        y_np = forward(x)
        return [y_np * (1 - y_np) * head_grads]

    shape = {'x': (1, 3, 32, 32)}
    check_function(y, forward, backward, shape=shape)


def test_softmax():
    x = sym.Variable("x")
    y = sym.softmax(x)

    def forward(x):
        return topi.testing.softmax_python(x)

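    # softmax Jacobian-vector product along axis 1:
    # grad_i = y_i * (g_i - sum_j y_j * g_j), with y = softmax(x) and g the head gradient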
    def backward(head_grads, x):
        y = topi.testing.softmax_python(x)
        grad = y * (head_grads - np.sum(y * head_grads, axis=1, keepdims=True))
        return [grad]

    check_function(y, forward, backward,
                   shape={'x': (10, 1000)}, numerical_grads=False)
    check_function(y, forward, backward,
                   shape={'x': (2, 10)})


def test_log_softmax():
    x = sym.Variable("x")
    y = sym.log_softmax(x)

    def forward(x):
        return topi.testing.log_softmax_python(x)

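    # for y = log_softmax(x): grad = g - softmax(x) * sum_j g_j,
    # where softmax(x) = exp(y) and g is the head gradient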
    def backward(head_grads, x):
        y = topi.testing.log_softmax_python(x)
        grad = head_grads - np.exp(y) * np.sum(head_grads, axis=1, keepdims=True)
        return [grad]

    check_function(y, forward, backward,
                   shape={'x': (10, 1000)}, numerical_grads=False)
    check_function(y, forward, backward,
                   shape={'x': (2, 10)})


def test_dense():
    x = sym.Variable("x", shape=(10, 100))
    w = sym.Variable("dense_weight", shape=(3, 100))
    b = sym.Variable("dense_bias", shape=(3,))
    y = sym.dense(x, w, b, use_bias=True, units=3, name="dense")
    y = sym.flatten(y)

    def forward(x, dense_weight, dense_bias):
        return np.dot(x, dense_weight.T) + dense_bias
    shape = {
        'x': (10, 100),
        'w': (3, 100),
        'b': (3,)
    }
    # Don't check gradients on cuda because it doesn't yet support ewise after reduce
    check_function(y, forward, shape=shape,
                   exclude_targets={'cuda'}, numerical_grads=True)
    check_function(y, forward, shape=shape,
                   only_targets={'cuda'}, numerical_grads=False)


def test_batchnorm():
    x = sym.Variable("x")
    beta = sym.Variable("beta")
    gamma = sym.Variable("gamma")
    moving_var = sym.Variable("moving_var")
    moving_mean = sym.Variable("moving_mean")
    eps = 1e-5
    y = sym.batch_norm(
        x, gamma, beta, moving_mean, moving_var, epsilon=eps)

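    # inference-mode batch norm: normalize with the moving statistics,
    # then apply the scale (gamma) and shift (beta)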
    def forward(x, gamma, beta, moving_mean, moving_var):
        return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta

    shape = {
        'x': (10, 20),
        'gamma': (20,),
        'beta': (20,),
        'moving_mean': (20,),
        'moving_var': (20,)
    }

    check_function(y, forward, in_range=(0.001, 1.0), shape=shape)


def verify_concatenate(ishape, axis):
    x = [sym.Variable("x%d" % i, shape=ishape[i]) for i in range(len(ishape))]
    y = sym.concatenate(*x, axis=axis) + 1

    def forward(**kwargs):
        return np.concatenate(list(kwargs.values()), axis=axis) + 1

    check_function(y, forward)


def test_concatenate():
    verify_concatenate([(2, 3, 4), (1, 3, 4)], axis=0)
    verify_concatenate([(2, 4), (2, 7)], axis=1)


def verify_split(ishape, indices_or_sections, axis):
    x = sym.Variable("x", shape=ishape)
    y = sym.split(x, indices_or_sections=indices_or_sections, axis=axis)

    def forward(x):
        return np.split(x, indices_or_sections, axis=axis)

    check_function(y, forward)


def test_split():
    verify_split((2, 3), 2, axis=0)
    verify_split((5, 3), [3], axis=0)
    verify_split((5, 9, 3), [3, 4], axis=1)

def verify_strided_slice(ishape, begin, end, strideinp=None):
    stride = strideinp if strideinp else [1, 1, 1]
    x = sym.Variable("x", shape=ishape)
    if strideinp:
        y = sym.strided_slice(x, begin=begin, end=end, stride=stride) + 1
    else:
        y = sym.strided_slice(x, begin=begin, end=end) + 1

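    # pad begin/end to the full rank of the input so the Python slicing in
    # test_forward below matches strided_slice's defaults for the omitted axes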
    for i in range(len(begin), 3):
        begin.append(0)
    for i in range(len(end), 3):
        end.append(ishape[i])

    def test_forward(x):
        return x[begin[0]:end[0]:stride[0],
                 begin[1]:end[1]:stride[1],
                 begin[2]:end[2]:stride[2]] + 1

    check_function(y, test_forward)

def test_strided_slice():
    verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
    verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
    verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
    verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4])
    verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3])

def verify_take(src_shape, indices_src, axis=None):
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    a = sym.Variable("a", shape=src_shape)
    indices = sym.Variable("indices", shape=indices_src.shape)
    y = sym.take(a, indices, axis=axis)

    def forward(a, indices):
        return np.take(a, indices=indices, axis=axis)

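    # supply explicit input values: `a` is a deterministic arange and the
    # integer `indices` must stay within bounds, so random inputs would not do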
    a_src = np.arange(np.prod(src_shape), dtype=src_dtype).reshape(src_shape)

    check_function(y, forward,
                   dtype={'a': src_dtype, 'indices': indices_dtype},
                   values={'a': a_src, 'indices': indices_src})

def test_take():
    verify_take((4,), [1])
    verify_take((4,), [[0,1,2,3]])
    verify_take((3,3,3), [[11,25]])
    verify_take((4,), [[0,1],[2,3]])
    verify_take((4,), [1], 0)
    verify_take((2,2), [[[1,0],[0,1]]], 0)
    verify_take((2,2), [[[1,0],[0,1]]], 1)
    verify_take((4,3,5,6), [[2,1,0,0]], -2)


def verify_squeeze(shape, axis):
    x = sym.Variable("x")
    if axis is not None:
        y = sym.squeeze(x, axis=axis)
    else:
        y = sym.squeeze(x)
    y = y + 1

    def forward(x):
        return np.squeeze(x, axis=axis) + 1

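    # squeeze only drops size-1 axes, so its gradient is just the head
    # gradient reshaped back to the input shape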
    def backward(head_grads, x):
        return [np.reshape(head_grads, x.shape)]

    check_function(y, forward, backward, shape=shape)


def test_squeeze():
    verify_squeeze((1, 3, 2, 5), None)
    verify_squeeze((1, 3, 1), axis=0)
    verify_squeeze((1, 3, 2, 5, 1), axis=-1)


def test_pad():
    x = sym.Variable("x")
    y = sym.pad(x, pad_width=((0, 0), (0, 0), (0, 1), (2, 3)), pad_value=1.)

    def forward(x):
        return np.pad(x,
                      pad_width=((0, 0), (0, 0), (0, 1), (2, 3)),
                      mode='constant', constant_values=1.)

    shape = {'x': (1, 3, 28, 28)}
    check_function(y, forward, shape=shape)

def verify_lrn(ishape, size, axis, bias, alpha, beta):
    x = sym.Variable("x", shape=ishape)
    y = sym.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)

    def forward1(x):
        return topi.testing.lrn_python(x, size, axis, bias, alpha, beta)

    check_function(y, forward1)

    def forward2(x):
        y = forward1(x)
        return (y > 0) * y

    # check the LRN op followed by an elementwise relu
    check_function(sym.relu(y), forward2, in_range={'x': (-10.0, 10.0)})

def verify_l2_normalize(ishape, eps, axis):
    x = sym.Variable("x", shape=ishape)
    y = sym.l2_normalize(x, eps=eps, axis=axis)

    def forward1(x):
        return topi.testing.l2_normalize_python(x, eps, axis)

    check_function(y, forward1)

    def forward2(x):
        y = forward1(x)
        return (y > 0) * y

    # check the L2 normalization op followed by an elementwise relu
    check_function(sym.relu(y), forward2, in_range={'x': (-10.0, 10.0)})

def test_lrn():
    verify_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

def test_l2_normalize():
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))

def verify_gather_nd(src_shape, indices_src):
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    a = sym.Variable("a", shape=src_shape)
    indices = sym.Variable("indices", shape=indices_src.shape)
    y = sym.gather_nd(a, indices)

    def forward(a, indices):
        return topi.testing.gather_nd_python(a, indices)

    a_src = np.arange(np.prod(src_shape), dtype=src_dtype).reshape(src_shape)

    check_function(y, forward,
                   dtype={'a': src_dtype, 'indices': indices_dtype},
                   values={'a': a_src, 'indices': indices_src})

def test_gather_nd():
    verify_gather_nd((4,), [[1]])
    verify_gather_nd((4,), [[1, 3, 2]])
    verify_gather_nd((2, 3), [[1]])
    verify_gather_nd((2, 3), [[1], [0]])
    verify_gather_nd((2, 3), [[1, 0], [0, 2]])
    verify_gather_nd((2, 3, 4), [[1, 0], [0, 2]])
    verify_gather_nd((2, 3, 4), [[1, 0], [0, 2], [3, 1]])
    verify_gather_nd((2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]],
                                 [[3, 1], [0, 2]]])
    verify_gather_nd((2, 3, 4, 5), [[1, 0], [0, 2]])
    verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]])

if __name__ == "__main__":
    test_check_function()
    test_split()
    test_concatenate()
    test_log_softmax()
    test_batchnorm()
    test_dense()
    test_relu()
    test_prelu_nchw()
    test_prelu_nhwc()
    test_sym_scalar_pow()
    test_scalar_sym_pow()
    test_exp()
    test_log()
    test_tanh()
    test_sigmoid()
    test_softmax()
    test_squeeze()
    test_pad()
    test_take()
    test_lrn()
    test_l2_normalize()
    test_strided_slice()
    test_gather_nd()