# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import operator

import tvm
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm import relay
import mxnet as mx

from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import model_zoo


def verify_mxnet_frontend_impl(mx_symbol,
                               data_shape=(1, 3, 224, 224),
                               out_shape=(1, 1000),
                               gluon_impl=False,
                               name=None,
                               dtype='float32'):
    """Check that TVM's MXNet frontend reproduces MXNet's output 0.

    The name intentionally does not start with ``test`` so that nose does
    not collect it as a test case.

    Parameters
    ----------
    mx_symbol : mxnet.sym.Symbol
        Symbol whose input is named ``data``; ignored when ``gluon_impl``
        is True.
    data_shape : tuple of int
        Shape of the uniform-random ``data`` input.
    out_shape : tuple of int
        Shape of the buffer that receives TVM output 0.
    gluon_impl : bool
        When True, compare the gluon model-zoo network ``name`` instead of
        ``mx_symbol``.
    name : str, optional
        Gluon model-zoo model name; only used when ``gluon_impl`` is True.
    dtype : str
        Dtype the input is cast to and the TVM output buffer is allocated
        with.

    For every ``(target, ctx)`` in ``ctx_list()``, the outputs are compared
    with ``rtol=atol=1e-5``.
    """
    if gluon_impl:
        def get_gluon_output(name, x):
            # Xavier-initialize a model-zoo net, then wrap it as a SymbolBlock
            # so the same object can be fed to both MXNet and the frontend.
            net = vision.get_model(name)
            net.collect_params().initialize(mx.init.Xavier())
            net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
                                           inputs=mx.sym.var('data'),
                                           params=net.collect_params())
            out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
            return out, net_sym
    else:
        def get_mxnet_output(symbol, x, dtype='float32'):
            # Bind the symbol in inference mode with default-initialized
            # params and return output 0 plus the params used.
            from collections import namedtuple
            Batch = namedtuple('Batch', ['data'])
            mod = mx.mod.Module(symbol, label_names=None)
            mod.bind(data_shapes=[('data', x.shape)], for_training=False)
            mod.init_params()
            mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
            out = mod.get_outputs()[0].asnumpy()
            args, auxs = mod.get_params()
            return out, args, auxs

    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
        # Convert, build at opt_level=3 and run on the graph runtime.
        shape_dict = {"data": x.shape}
        if gluon_impl:
            mod, params = relay.frontend.from_mxnet(symbol, shape_dict)
        else:
            mod, params = relay.frontend.from_mxnet(symbol,
                                                    shape_dict,
                                                    arg_params=args,
                                                    aux_params=auxs)
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(mod, target, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input("data", tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()

    # random input
    x = np.random.uniform(size=data_shape)
    if gluon_impl:
        gluon_out, gluon_sym = get_gluon_output(name, x)
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
            tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
    else:
        mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
        # The input must not have been registered as a trainable parameter.
        assert "data" not in args
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
            tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)

def test_forward_mlp():
    """Model-zoo MLP on a single-channel 28x28 input with 10 outputs."""
    net = model_zoo.mx_mlp()
    verify_mxnet_frontend_impl(net, data_shape=(1, 1, 28, 28), out_shape=(1, 10))

def test_forward_vgg():
    """Model-zoo VGG for every listed variant (currently only 11)."""
    variants = (11,)
    for variant in variants:
        net = model_zoo.mx_vgg(variant)
        verify_mxnet_frontend_impl(net)

def test_forward_resnet():
    """Model-zoo ResNet for every listed variant (currently only 18)."""
    for n in [18]:
        # Pass the loop variable so that any variant added to the list is
        # actually exercised instead of silently re-testing ResNet-18.
        mx_sym = model_zoo.mx_resnet(n)
        verify_mxnet_frontend_impl(mx_sym)

def test_forward_elu():
    """LeakyReLU(act_type='elu') on the input concatenated with its negation."""
    inp = mx.sym.var('data')
    # Appending -inp along the channel axis guarantees negative values.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    activated = mx.sym.LeakyReLU(both_signs, act_type='elu')
    verify_mxnet_frontend_impl(activated,
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))

def test_forward_rrelu():
    """LeakyReLU(act_type='rrelu') with bounds [0.3, 0.7] on +/- input."""
    inp = mx.sym.var('data')
    # Appending -inp along the channel axis guarantees negative values.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    activated = mx.sym.LeakyReLU(both_signs, act_type='rrelu',
                                 lower_bound=0.3, upper_bound=0.7)
    verify_mxnet_frontend_impl(activated,
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))

def test_forward_prelu():
    """LeakyReLU(act_type='prelu') on the input concatenated with its negation."""
    inp = mx.sym.var('data')
    # Appending -inp along the channel axis guarantees negative values.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    activated = mx.sym.LeakyReLU(both_signs, act_type='prelu')
    verify_mxnet_frontend_impl(activated,
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))

def test_forward_softrelu():
    """Activation(act_type='softrelu') on the input concatenated with its negation."""
    inp = mx.sym.var('data')
    # Appending -inp along the channel axis guarantees negative values.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    activated = mx.sym.Activation(both_signs, act_type='softrelu')
    verify_mxnet_frontend_impl(activated,
                               data_shape=(1, 3, 100, 100),
                               out_shape=(1, 6, 100, 100))

def test_forward_fc_flatten():
    """FullyConnected(flatten=True) vs. explicit Flatten + flatten=False.

    The ``flatten`` attribute exists only from MXNet 0.11.1 on. On older
    MXNet, building the symbols fails and the test returns early. Only
    symbol construction is guarded, so a frontend mismatch raised by
    ``verify_mxnet_frontend_impl`` fails the test as it should.
    """
    data = mx.sym.var('data')
    try:
        sym_flatten = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
        sym_no_flatten = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100,
                                               flatten=False)
    except (mx.base.MXNetError, TypeError):
        # MXNet too old to accept ``flatten``: nothing to compare.
        return
    verify_mxnet_frontend_impl(sym_flatten, (1, 3, 100, 100), (1, 100))
    verify_mxnet_frontend_impl(sym_no_flatten, (1, 3, 100, 100), (1, 100))

def test_forward_clip():
    """clip(a_min=0, a_max=1) on the input concatenated with its negation."""
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
    mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_split():
    """split along axis 1 into 4 outputs, keeping the split axis."""
    parts = mx.sym.split(mx.sym.var('data'), axis=1, num_outputs=4,
                         squeeze_axis=False)
    # The helper compares output 0 only, hence the single-slice out_shape.
    verify_mxnet_frontend_impl(parts, data_shape=(1, 4, 2, 1), out_shape=(1, 1, 2, 1))

def test_forward_split_squeeze():
    """split along axis 1 into 4 outputs, squeezing the split axis away."""
    parts = mx.sym.split(mx.sym.var('data'), axis=1, num_outputs=4,
                         squeeze_axis=True)
    # The helper compares output 0 only; axis 1 is removed by squeeze_axis.
    verify_mxnet_frontend_impl(parts, data_shape=(1, 4, 2, 1), out_shape=(1, 2, 1))

def test_forward_expand_dims():
    """expand_dims inserting a new axis at position 1."""
    expanded = mx.sym.expand_dims(mx.sym.var('data'), axis=1)
    verify_mxnet_frontend_impl(expanded, data_shape=(2, 3, 4), out_shape=(2, 1, 3, 4))

def test_forward_pooling():
    """3x3 average and max Pooling with padding 1 (spatial size preserved)."""
    data = mx.sym.var('data')
    for pool_type in ('avg', 'max'):
        pooled = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type=pool_type)
        verify_mxnet_frontend_impl(pooled, (1, 20, 8, 8), (1, 20, 8, 8))

def test_forward_adaptive_pooling():
    """contrib.AdaptiveAvgPooling2D with a one-element and a 2-D output size."""
    data = mx.sym.var('data')
    cases = [((1,), (1, 20, 1, 1)),
             ((3, 3), (1, 20, 3, 3))]
    for output_size, expected_shape in cases:
        pooled = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=output_size)
        verify_mxnet_frontend_impl(pooled, (1, 20, 8, 8), expected_shape)

def test_forward_lrn():
    """Local response normalization with alpha=2, beta=2, knorm=1, nsize=5."""
    normalized = mx.sym.LRN(mx.sym.var('data'), alpha=2, beta=2, knorm=1, nsize=5)
    verify_mxnet_frontend_impl(normalized,
                               data_shape=(1, 10, 24, 24),
                               out_shape=(1, 10, 24, 24))

def test_forward_ones():
    """Element-wise add of a constant ones((2, 3, 4)) tensor to the input."""
    shape = (2, 3, 4)
    inp = mx.sym.var('data')
    summed = mx.sym.elemwise_add(inp, mx.sym.ones(shape=shape, dtype='float32'))
    verify_mxnet_frontend_impl(summed, shape, shape)

def test_forward_zeros():
    """Element-wise add of a constant zeros((2, 3, 4)) tensor to the input."""
    shape = (2, 3, 4)
    inp = mx.sym.var('data')
    summed = mx.sym.elemwise_add(inp, mx.sym.zeros(shape=shape, dtype='float32'))
    verify_mxnet_frontend_impl(summed, shape, shape)

def test_forward_ones_like():
    """ones_like producing a float32 tensor shaped like the input."""
    shape = (2, 3, 4)
    filled = mx.sym.ones_like(mx.sym.var('data'), dtype='float32')
    verify_mxnet_frontend_impl(filled, shape, shape)

def test_forward_zeros_like():
    """zeros_like producing a float32 tensor shaped like the input."""
    shape = (2, 3, 4)
    filled = mx.sym.zeros_like(mx.sym.var('data'), dtype='float32')
    verify_mxnet_frontend_impl(filled, shape, shape)

def test_forward_argmax():
    """argmax over axis 1 of a (5, 3) input."""
    reduced = mx.sym.argmax(mx.sym.var('data'), axis=1)
    verify_mxnet_frontend_impl(reduced, data_shape=(5, 3), out_shape=(5,))

def test_forward_argmin():
    """argmin over axis 0 of a (5, 4) input."""
    reduced = mx.sym.argmin(mx.sym.var('data'), axis=0)
    verify_mxnet_frontend_impl(reduced, data_shape=(5, 4), out_shape=(4,))

def test_forward_slice():
    """slice with positive bounds, then with negative begin/end and step."""
    data = mx.sym.var('data')
    forward_slice = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
    verify_mxnet_frontend_impl(forward_slice, (3, 4), (2, 3))
    stepped_slice = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
    verify_mxnet_frontend_impl(stepped_slice, (3, 4), (2, 2))

def test_forward_where():
    """where(cond, x, y) with three named inputs on graph and debug executors."""
    cond = mx.sym.var('cond')
    x = mx.sym.var('x')
    y = mx.sym.var('y')
    dshape = (2, 2)
    dtype = 'float32'
    mx_sym = mx.sym.where(cond, x, y)
    np_cond = np.array([[0, 1], [-1, 0]]).astype(dtype)
    np_x = np.random.uniform(size=dshape).astype(dtype)
    np_y = np.random.uniform(size=dshape).astype(dtype)
    mx_cond = mx.nd.array(np_cond)
    mx_x = mx.nd.array(np_x)
    mx_y = mx.nd.array(np_y)
    shapes = {'cond': dshape, 'x': dshape, 'y': dshape}
    # The MXNet Module is bound only to obtain (arg, aux) params for the
    # frontend; the reference result comes from mx.nd.where below.
    mx_mod = mx.mod.Module(mx_sym, label_names=None, data_names=['cond', 'x', 'y'])
    mx_mod.bind(data_shapes=shapes.items(), for_training=False)
    mx_mod.init_params()
    args, auxs = mx_mod.get_params()
    mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy()

    mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs)
    for target, ctx in ctx_list():
        for kind in ["graph", "debug"]:
            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
            op_res = intrp.evaluate()(np_cond, np_x, np_y)
            tvm.testing.assert_allclose(op_res.asnumpy(), mx_out)


def test_forward_arange():
    """arange with every combination of optional start/step, incl. floats and negatives."""
    def _mx_symbol(F, start, stop, step):
        # Call F.arange with only the arguments that were provided, so that
        # MXNet's own defaults are exercised for the omitted ones.
        if start is None and step is None:
            sym = F.arange(stop)
        elif start is None:
            sym = F.arange(stop, step=step)
        elif step is None:
            sym = F.arange(start, stop)
        else:
            sym = F.arange(start, stop, step)
        return sym

    def verify(start, stop, step):
        ref_res = _mx_symbol(mx.nd, start, stop, step).asnumpy()
        mx_sym = _mx_symbol(mx.sym, start, stop, step)
        # The graph has no inputs, hence the empty shape dict.
        mod, _ = relay.frontend.from_mxnet(mx_sym, {})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()()
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)

    verify(0, 20, None)
    verify(0, 20, 2)
    verify(1, 20, None)
    verify(1, 20, 2)
    verify(1, 20, 1.5)
    verify(1, 20.5, None)
    verify(1, 20, 3)
    verify(20, 1, -1)
    verify(20, 1, -1.5)

def _mx_symbol(F, op_name, inputs):
    op = getattr(F, op_name)
    return op(*inputs)

def test_forward_broadcast_ops():
    """Binary broadcast_* ops between a (3, 4, 5) and a (4, 5) input."""
    for op in ["broadcast_add", "broadcast_sub", "broadcast_mul",
               "broadcast_div", "broadcast_mod", "broadcast_maximum",
               "broadcast_minimum", "broadcast_equal", "broadcast_not_equal",
               "broadcast_greater", "broadcast_greater_equal",
               "broadcast_lesser", "broadcast_lesser_equal"]:
        a_shape = (3, 4, 5)
        b_shape = (4, 5)
        if op == "broadcast_mod":
            # Integer operands in [1, 100) keep the divisor non-zero.
            dtype = 'int32'
            a_np = np.random.randint(1, 100, size=a_shape).astype(dtype)
            b_np = np.random.randint(1, 100, size=b_shape).astype(dtype)
        else:
            dtype = 'float32'
            a_np = np.random.uniform(size=a_shape).astype(dtype)
            b_np = np.random.uniform(size=b_shape).astype(dtype)
        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
        shapes = {'a': a_shape, 'b': b_shape}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(a_np, b_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

def test_forward_elemwise_ops():
    """Same-shape binary element-wise ops on two (3, 4, 5) float32 inputs."""
    for op in ["elemwise_add", "elemwise_sub", "elemwise_mul",
               "elemwise_div", "maximum", "minimum"]:
        shape = (3, 4, 5)
        dtype = 'float32'
        a_np = np.random.uniform(size=shape).astype(dtype)
        b_np = np.random.uniform(size=shape).astype(dtype)
        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
        shapes = {'a': shape, 'b': shape}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(a_np, b_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

def test_forward_scalar_ops():
    """Binary ops between a (3, 4, 5) float32 tensor and the scalar 2.3.

    Covers the Python operator overloads on symbols/NDArrays as well as
    ``maximum``/``minimum`` called with a scalar second argument.
    """
    def check(mx_sym, ref_res, a_np, shapes, dtype):
        # Convert mx_sym and compare against ref_res on every executor.
        mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(a_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    for op in [operator.add, operator.sub, operator.mul, operator.truediv,
               operator.pow, operator.lt, operator.le, operator.eq,
               operator.ne, operator.gt, operator.ge]:
        dtype = 'float32'
        a_shape = (3, 4, 5)
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        b_scalar = 2.3
        mx_sym = op(mx.sym.var('a'), b_scalar)
        ref_res = op(mx.nd.array(a_np), b_scalar)
        shapes = {'a': a_shape}
        check(mx_sym, ref_res, a_np, shapes, dtype)
    for op in ["maximum", "minimum"]:
        dtype = 'float32'
        a_shape = (3, 4, 5)
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        b_scalar = 2.3
        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar])
        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar])
        shapes = {'a': a_shape}
        check(mx_sym, ref_res, a_np, shapes, dtype)

def test_forward_slice_axis():
    """slice_axis with positive, None and negative bounds/axes."""
    def verify(shape, axis, begin, end):
        data_np = np.random.uniform(size=shape).astype("float32")
        ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end)
        mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(data_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((3, 4), 0, 1, 2)
    verify((3, 4), 0, 1, None)
    verify((3, 4), 1, 0, 2)
    verify((3, 4), 1, -3, -1)
    verify((3, 4), -1, -3, -1)

def test_forward_slice_like():
    """slice_like of a (3, 4) tensor to a (2, 3) one, with and without axes."""
    def verify(x_shape, y_shape, axes):
        x_np = np.random.uniform(size=x_shape).astype("float32")
        y_np = np.random.uniform(size=y_shape).astype("float32")
        if axes is None:
            ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np))
            mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"))
        else:
            ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np), axes=axes)
            mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"), axes=axes)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np, y_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((3, 4), (2, 3), None)
    verify((3, 4), (2, 3), (0, 1))
    # (0) and (-1) are plain ints, not tuples: a single axis passed as a scalar.
    verify((3, 4), (2, 3), (0))
    verify((3, 4), (2, 3), (-1))

def test_forward_l2_normalize():
    """L2Normalization in "channel" mode; output shape equals input shape."""
    shape = (2, 3, 4, 5)
    normalized = mx.sym.L2Normalization(mx.sym.var('data'), mode="channel")
    verify_mxnet_frontend_impl(normalized, shape, shape)

def test_forward_shape_array():
    """shape_array for inputs of rank 1, 3 and 4 (debug executor only)."""
    def verify(shape):
        x_np = np.random.uniform(size=shape).astype("float32")
        ref_res = mx.nd.shape_array(mx.nd.array(x_np))
        mx_sym = mx.sym.shape_array(mx.sym.var("x"))
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1,))
    verify((3, 4, 5))
    verify((3, 4, 5, 6))

def test_forward_squeeze():
    """squeeze with no axis, a single axis and a tuple of axes."""
    def verify(shape, axis):
        x_np = np.random.uniform(size=shape).astype("float32")
        if axis is None:
            ref_res = mx.nd.squeeze(mx.nd.array(x_np))
            mx_sym = mx.sym.squeeze(mx.sym.var("x"))
        else:
            ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis)
            mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1, 3, 1), None)
    verify((1, 3, 1), 0)
    verify((1, 3, 1), 2)
    verify((1, 3, 1), (0, 2))

def test_forward_broadcast_axis():
    """broadcast_axis along a single axis and along a tuple of axes."""
    def verify(shape, axis, size):
        x_np = np.random.uniform(size=shape).astype("float32")
        ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size)
        mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1, 2, 1), 2, 3)
    verify((1, 2, 1), (0, 2), (2, 3))

def test_forward_full():
    """full(shape, val, dtype) for integer/float fill values and dtypes."""
    def verify(val, shape, dtype):
        ref_res = mx.nd.full(shape, val, dtype=dtype)
        mx_sym = mx.sym.full(shape, val, dtype=dtype)
        # The graph has no inputs, hence the empty shape dict.
        mod, _ = relay.frontend.from_mxnet(mx_sym, {})
        for target, ctx in ctx_list():
            # Skip testing graph runtime because this op will be optimized out
            # by constant folding.
            for kind in ["debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()()
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify(2, (3, 4), "float32")
    verify(2, (3, 4), "int32")
    verify(3.5, (1, 3, 4), "float32")

def test_forward_embedding():
    """Embedding lookup with 2-D and 3-D index tensors into a (4, 5) table."""
    def verify(data_shape, weight_shape):
        in_dim, out_dim = weight_shape
        # Indices are drawn from [0, in_dim) but stored as float32.
        x_np = np.random.randint(0, in_dim, size=data_shape).astype("float32")
        w_np = np.random.uniform(size=weight_shape).astype("float32")
        ref_res = mx.nd.Embedding(mx.nd.array(x_np), mx.nd.array(w_np),
                                  input_dim=in_dim, output_dim=out_dim)
        mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"),
                                  input_dim=in_dim, output_dim=out_dim)
        mod, _ = relay.frontend.from_mxnet(
            mx_sym, {"x": data_shape, "w": weight_shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x=x_np, w=w_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((2, 2), (4, 5))
    verify((2, 3, 4), (4, 5))

def test_forward_smooth_l1():
    """smooth_l1 with the default scalar and with scalar=1.0 given explicitly."""
    inp = mx.sym.var('data')
    default_scalar = mx.sym.smooth_l1(inp)
    verify_mxnet_frontend_impl(default_scalar, data_shape=(3, 4), out_shape=(3, 4))
    explicit_scalar = mx.sym.smooth_l1(inp, scalar=1.0)
    verify_mxnet_frontend_impl(explicit_scalar, data_shape=(3, 4), out_shape=(3, 4))

def test_forward_take():
    """take along various axes, with out-of-range indices in clip and wrap modes."""
    def verify(shape, indices_src, axis, mode="clip"):
        x_np = np.random.uniform(size=shape).astype("float32")
        indices_np = np.array(indices_src, dtype="float32")
        ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode)
        mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np, indices_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((2, 2), [[[1, 0], [0, 1]]], 0)
    verify((2, 2), [[[1, 0], [0, 1]]], 1)
    verify((4, 3, 5, 6), [[2, 1, 0, 0]], -2)
    verify((3, 4), [-1, 5], 0)
    verify((3, 4), [-1, 5], 0, mode="wrap")
    verify((3, 4), [-1, 5], 1)
    verify((3, 4), [-1, 5], 1, mode="wrap")

def test_forward_gather_nd():
    """gather_nd with float32 data and int32 index tensors."""
    def verify(xshape, yshape, y_data):
        x_data = np.random.uniform(size=xshape).astype("float32")
        ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
        mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
        mod, _ = relay.frontend.from_mxnet(
            mx_sym,
            {"x_data": xshape, "y_data": yshape},
            dtype={"x_data": "float32", "y_data": "int32"})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
    verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])

def test_forward_bilinear_resize():
    """contrib.BilinearResize2D to a fixed 5x10 output."""
    # TODO: add scale_height / scale_width cases once MXNet is updated to 1.5.
    resized = mx.sym.contrib.BilinearResize2D(mx.sym.var('data'), height=5, width=10)
    verify_mxnet_frontend_impl(resized, data_shape=(1, 2, 3, 4), out_shape=(1, 2, 5, 10))

def test_forward_rnn_layer():
    """Hybridized gluon RNN/GRU/LSTM layers, with and without explicit initial states."""
    def verify(mode, seq_len, input_size, hidden_size, num_layers,
               batch=1, init_states=True, bidirectional=False):
        if mode == "rnn":
            layer = gluon.rnn.RNN(hidden_size, num_layers, bidirectional=bidirectional)
        elif mode == "gru":
            layer = gluon.rnn.GRU(hidden_size, num_layers, bidirectional=bidirectional)
        else: # mode == "lstm"
            layer = gluon.rnn.LSTM(hidden_size, num_layers, bidirectional=bidirectional)
        # LSTM carries two state tensors; RNN and GRU carry one.
        num_states = 2 if mode == "lstm" else 1
        layer.initialize()
        layer.hybridize()

        dtype = "float32"
        directions = 2 if bidirectional else 1
        data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
        data_mx = mx.nd.array(data_np)

        if init_states:
            # With explicit states the cached graph's inputs are named
            # data0 (sequence), data1.. (states), as the dict keys below use.
            shape_dict = {'data0': data_np.shape}
            inputs = {'data0': data_np}
            state_shape = (num_layers*directions, batch, hidden_size)
            states_mx = []
            for i in range(num_states):
                s = np.random.uniform(size=state_shape).astype(dtype)
                states_mx.append(mx.nd.array(s))
                shape_dict['data%s' % (i+1)] = s.shape
                inputs['data%s' % (i+1)] = s
            mx_out, mx_states = layer(data_mx, states_mx)
            mx_res = [mx_out] + mx_states
        else:
            shape_dict = {'data': data_np.shape}
            inputs = {'data': data_np}
            mx_res = layer(data_mx)

        # NOTE(review): relies on gluon's private ``_cached_graph``, which is
        # expected to be populated by the forward call above after
        # hybridize() — confirm against the MXNet version used in CI.
        mx_sym = layer._cached_graph[1]
        mx_params = {name: param._reduce()
                     for name, param in layer.collect_params().items()}

        mod, params = relay.frontend.from_mxnet(
            mx_sym, shape=shape_dict, arg_params=mx_params)
        for target, ctx in ctx_list():
            # only test graph runtime because debug runtime is too slow
            for kind in ["graph"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(**inputs, **params)
                if init_states:
                    assert len(op_res) == len(mx_res)
                    for i, val in enumerate(op_res):
                        tvm.testing.assert_allclose(
                            val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
                else:
                    tvm.testing.assert_allclose(
                        op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)

    for mode in ["rnn", "gru", "lstm"]:
        verify(mode, 1, 64, 64, 1)
        verify(mode, 10, 64, 64, 2)
        verify(mode, 10, 64, 32, 2)
        verify(mode, 10, 64, 32, 2, batch=2)
        verify(mode, 10, 64, 64, 3, init_states=False)
        verify(mode, 10, 32, 64, 1, bidirectional=True)
        verify(mode, 10, 64, 64, 3, batch=2, bidirectional=True, init_states=False)

def test_forward_Crop():
    """Crop x to the shape of y, centered (no offset) or at an explicit offset."""
    def verify(xshape, yshape, offset=None):
        x_data = np.random.uniform(size=xshape).astype("float32")
        y_data = np.random.uniform(size=yshape).astype("float32")
        if offset is None:
            mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"))
            ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data))
        else:
            mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"), offset=offset)
            ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data), offset=offset)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                if offset is None or offset == (0, 0):
                    op_res = intrp.evaluate()(x_data, y_data)
                else:
                    # NOTE(review): with a non-zero offset the converted
                    # function is called with x only; presumably the frontend
                    # drops y once only its shape is needed — confirm.
                    op_res = intrp.evaluate()(x_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((1, 3, 40, 40), (1, 3, 20, 20))
    verify((1, 3, 40, 40), (1, 3, 20, 20), (0, 0))
    verify((1, 3, 40, 40), (1, 3, 20, 20), (10, 10))
    verify((5, 32, 40, 40), (5, 32, 25, 25))
    verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))

def test_forward_argsort():
    """argsort ascending/descending on several axes and output dtypes."""
    def verify(shape, axis, is_ascend, dtype="float32"):
        # ``dtype`` is the dtype of the returned indices; the input is float32.
        x_np = np.random.uniform(size=shape).astype("float32")
        ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype)
        mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((2, 3, 4), axis=0, is_ascend=False)
    verify((1, 4, 6), axis=1, is_ascend=True)
    verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32")

def test_forward_topk():
    """topk returning values, indices or both, ascending and descending."""
    def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
        x_np = np.random.uniform(size=shape).astype("float32")
        ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type,
                             is_ascend=is_ascend, dtype=dtype)
        mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type,
                             is_ascend=is_ascend, dtype=dtype)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(x_np)
                # A list reference result means multiple outputs
                # (ret_type="both"); compare them pairwise.
                if isinstance(ref_res, list):
                    assert len(op_res) == len(ref_res)
                    for i, t in enumerate(op_res):
                        tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy())
                else:
                    tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify((3, 4), k=1, axis=0, ret_type="both")
    verify((3, 4), k=1, axis=-1, ret_type="indices")
    verify((3, 5, 6), k=2, axis=2, ret_type="value")
    verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
    verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")

def test_forward_sequence_mask():
    """Check the Relay conversion of MXNet ``SequenceMask`` against MXNet itself.

    Covers the masked form (``use_sequence_length=True``, which adds a
    ``valid_length`` graph input) and the form without sequence lengths.
    """
    def verify(shape, use_sequence_length, value, axis, dtype, itype):
        data_np = np.random.uniform(size=shape).astype(dtype)
        # One length per entry of dimension ``1 - axis``, each drawn from
        # the half-open range [0, shape[axis]).
        valid_length_np = np.random.randint(0, shape[axis], size=shape[1-axis]).astype(itype)
        # Operator attributes shared by the NDArray reference and the symbol.
        op_attrs = dict(use_sequence_length=use_sequence_length, value=value, axis=axis)

        if use_sequence_length:
            ref_res = mx.nd.SequenceMask(
                mx.nd.array(data_np, dtype=dtype),
                sequence_length=mx.nd.array(valid_length_np, dtype=itype),
                **op_attrs)
            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'),
                                         sequence_length=mx.sym.var('valid_length'),
                                         **op_attrs)
            input_shapes = {"data": shape, 'valid_length': valid_length_np.shape}
            input_dtypes = {"data": dtype, "valid_length": itype}
            feed = (data_np, valid_length_np)
        else:
            ref_res = mx.nd.SequenceMask(mx.nd.array(data_np, dtype=dtype), **op_attrs)
            mx_sym = mx.sym.SequenceMask(mx.sym.var('data'), **op_attrs)
            input_shapes = {"data": shape}
            input_dtypes = {"data": dtype}
            feed = (data_np,)

        mod, _ = relay.frontend.from_mxnet(mx_sym, input_shapes, dtype=input_dtypes)
        for target, ctx in ctx_list():
            for kind in ['graph', 'debug']:
                if kind == 'graph' and use_sequence_length is False:
                    # Without sequence lengths the op is an identity; the
                    # 'graph' executor is skipped for that case.
                    continue
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(*feed)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    # (shape, use_sequence_length, value, axis, dtype, itype)
    for case in [((5, 10), True, 0.0, 0, 'float32', 'float32'),
                 ((5, 4, 3), True, 1.0, 1, 'float32', 'float32'),
                 ((5, 4, 3), False, 1.0, 1, 'float64', 'float64'),
                 ((5, 4, 3, 2), True, 1.0, 0, 'float32', 'float32')]:
        verify(*case)

if __name__ == '__main__':
    # Run every frontend test in sequence when executed as a script.
    test_forward_mlp()
    test_forward_vgg()
    test_forward_resnet()
    test_forward_elu()
    test_forward_rrelu()
    test_forward_prelu()
    test_forward_softrelu()
    test_forward_fc_flatten()
    test_forward_clip()
    test_forward_split()
    test_forward_split_squeeze()
    test_forward_expand_dims()
    test_forward_pooling()
    test_forward_adaptive_pooling()
    test_forward_lrn()
    test_forward_ones()
    test_forward_zeros()
    test_forward_ones_like()
    test_forward_zeros_like()
    test_forward_argmax()
    test_forward_argmin()
    test_forward_where()
    test_forward_arange()
    test_forward_broadcast_ops()
    test_forward_elemwise_ops()
    test_forward_scalar_ops()
    test_forward_slice_like()
    test_forward_slice_axis()
    test_forward_l2_normalize()
    test_forward_shape_array()
    test_forward_squeeze()
    test_forward_broadcast_axis()
    test_forward_full()
    test_forward_embedding()
    test_forward_smooth_l1()
    test_forward_take()
    test_forward_gather_nd()
    test_forward_bilinear_resize()
    test_forward_rnn_layer()
    test_forward_Crop()
    test_forward_argsort()
    test_forward_topk()
    test_forward_sequence_mask()