# test_forward.py: forward-pass tests for the NNVM MXNet frontend.
import numpy as np

import topi
import tvm
from tvm.contrib import graph_runtime
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from nnvm import frontend
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import model_zoo


def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000),
                               gluon_impl=False, name=None):
    """Check that TVM reproduces MXNet's output for the given model.

    The name deliberately does not start with ``test`` so that nose does not
    collect this helper as a test on its own.

    A random input of shape ``data_shape`` is run through MXNet (or Gluon).
    The same input is run through the TVM module built from
    ``nnvm.frontend.from_mxnet`` for every target returned by ``ctx_list()``.
    Output 0 of each run is compared with ``rtol=atol=1e-5``.

    Parameters
    ----------
    mx_symbol : mxnet symbol
        Network whose input variable is named ``'data'``. Ignored when
        ``gluon_impl`` is True.
    data_shape : tuple of int
        Shape of the random input.
    out_shape : tuple of int
        Shape of the TVM buffer that receives output 0.
    gluon_impl : bool
        If True, compare the Gluon model-zoo network ``name`` instead of
        ``mx_symbol``.
    name : str, optional
        Gluon model-zoo model name; used only when ``gluon_impl`` is True.
    """
    dtype = 'float32'

    def get_gluon_output(name, x, dtype='float32'):
        """Build Gluon model ``name`` (Xavier init) and return (output, SymbolBlock)."""
        net = vision.get_model(name)
        net.collect_params().initialize(mx.init.Xavier())
        net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
                                       inputs=mx.sym.var('data'),
                                       params=net.collect_params())
        out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
        return out, net_sym

    def get_mxnet_output(symbol, x, dtype='float32'):
        """Run ``symbol`` through an MXNet Module; return (output 0, args, auxs)."""
        from collections import namedtuple
        Batch = namedtuple('Batch', ['data'])
        mod = mx.mod.Module(symbol, label_names=None)
        mod.bind(data_shapes=[('data', x.shape)], for_training=False)
        mod.init_params()
        mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
        out = mod.get_outputs()[0].asnumpy()
        args, auxs = mod.get_params()
        return out, args, auxs

    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
        """Convert ``symbol`` via the NNVM frontend, build it for ``target`` and return output 0."""
        if gluon_impl:
            # The SymbolBlock was created with params=net.collect_params(),
            # so no separate args/auxs are passed to the frontend.
            new_sym, params = frontend.from_mxnet(symbol)
        else:
            new_sym, params = frontend.from_mxnet(symbol, args, auxs)

        shape_dict = {'data': x.shape}
        with nnvm.compiler.build_config(opt_level=3):
            graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input("data", tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()

    # Random input drawn uniformly from [0, 1).
    x = np.random.uniform(size=data_shape)
    if gluon_impl:
        gluon_out, gluon_sym = get_gluon_output(name, x, dtype)
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
            np.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
    else:
        mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
        # The input must be fed at run time, never baked in as a parameter.
        assert "data" not in args
        for target, ctx in ctx_list():
            tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
            np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_forward_mlp():
    """Compare the model_zoo MXNet MLP against TVM using the default shapes."""
    mlp = model_zoo.mx_mlp
    verify_mxnet_frontend_impl(mlp)


def test_forward_vgg():
    """Compare model_zoo VGG symbols (currently depth 11 only) against TVM."""
    for n in [11]:
        mx_sym = model_zoo.mx_vgg[n]
        verify_mxnet_frontend_impl(mx_sym)


def test_forward_resnet():
    """Compare model_zoo ResNet symbols (currently depth 18 only) against TVM."""
    for n in [18]:
        mx_sym = model_zoo.mx_resnet[n]
        verify_mxnet_frontend_impl(mx_sym)


def test_forward_elu():
    """ELU activation expressed as LeakyReLU(act_type='elu')."""
    inp = mx.sym.var('data')
    # Append -x along axis 1 so negative inputs are exercised as well.
    both_signs = mx.sym.concat(inp, -inp, dim=1)
    verify_mxnet_frontend_impl(mx.sym.LeakyReLU(both_signs, act_type='elu'),
                               (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_rrelu():
    """LeakyReLU(act_type='rrelu') with lower_bound=0.3 and upper_bound=0.7."""
    inp = mx.sym.var('data')
    both_signs = mx.sym.concat(inp, -inp, dim=1)  # add -x on axis 1 to cover negatives
    rrelu = mx.sym.LeakyReLU(both_signs, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
    verify_mxnet_frontend_impl(rrelu, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_prelu():
    """PReLU activation expressed as LeakyReLU(act_type='prelu')."""
    data = mx.sym.var('data')
    # Append -x along axis 1 so negative inputs are exercised as well.
    data = mx.sym.concat(data, -data, dim=1)
    mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))


def test_forward_softrelu():
    """SoftReLU activation via Activation(act_type='softrelu')."""
    x = mx.sym.var('data')
    x = mx.sym.concat(x, -x, dim=1)  # mirror the input so negative values appear
    verify_mxnet_frontend_impl(
        mx.sym.Activation(x, act_type='softrelu'),
        (1, 3, 100, 100),
        (1, 6, 100, 100),
    )

def test_forward_fc_flatten():
    """FullyConnected with flatten=True vs. explicit Flatten + flatten=False.

    ``flatten`` is the MXNet 0.11.1 FullyConnected option under test. If the
    installed MXNet rejects it while building the symbols, the test returns
    early. Conversion or numerical-comparison failures are NOT swallowed;
    previously a bare ``except: pass`` hid them, so this test could never fail.
    """
    data = mx.sym.var('data')
    try:
        fc_flatten = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
        fc_no_flatten = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100,
                                              flatten=False)
    except (mx.base.MXNetError, TypeError):
        # NOTE(review): assumes an unsupported `flatten` kwarg is reported as
        # one of these errors at symbol creation on old MXNet -- confirm.
        return
    verify_mxnet_frontend_impl(fc_flatten, (1, 3, 100, 100), (1, 100))
    verify_mxnet_frontend_impl(fc_no_flatten, (1, 3, 100, 100), (1, 100))


def test_forward_clip():
    """clip(a_min=0, a_max=1) on an input containing both signs."""
    var = mx.sym.var('data')
    mirrored = mx.sym.concat(var, -var, dim=1)  # include the negated input explicitly
    clipped = mx.sym.clip(mirrored, a_min=0, a_max=1)
    in_shape, out_shape = (1, 3, 100, 100), (1, 6, 100, 100)
    verify_mxnet_frontend_impl(clipped, in_shape, out_shape)

def test_forward_split():
    """split into 4 pieces along axis 1, keeping the axis (only output 0 is compared)."""
    split_sym = mx.sym.split(mx.sym.var('data'), axis=1, num_outputs=4, squeeze_axis=False)
    verify_mxnet_frontend_impl(split_sym, (1, 4, 2, 1), (1, 1, 2, 1))

def test_forward_split_squeeze():
    """split into 4 pieces along axis 1 and squeeze that axis (only output 0 is compared)."""
    split_sym = mx.sym.split(mx.sym.var('data'), axis=1, num_outputs=4, squeeze_axis=True)
    verify_mxnet_frontend_impl(split_sym, (1, 4, 2, 1), (1, 2, 1))

def test_forward_expand_dims():
    """expand_dims at axis 1: (2, 3, 4) -> (2, 1, 3, 4)."""
    expanded = mx.sym.expand_dims(mx.sym.var('data'), axis=1)
    verify_mxnet_frontend_impl(expanded, (2, 3, 4), (2, 1, 3, 4))

def test_forward_pooling():
    """3x3 average and max pooling with padding 1 on a (1, 20, 8, 8) input."""
    data = mx.sym.var('data')
    shape = (1, 20, 8, 8)
    for pool_type in ('avg', 'max'):
        pooled = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type=pool_type)
        verify_mxnet_frontend_impl(pooled, shape, shape)

if __name__ == '__main__':
    # Run every test directly when executed as a script (without nose).
    test_forward_mlp()
    test_forward_vgg()
    test_forward_resnet()
    test_forward_elu()
    test_forward_rrelu()
    test_forward_prelu()
    test_forward_softrelu()
    test_forward_fc_flatten()
    test_forward_clip()
    test_forward_split()
    test_forward_split_squeeze()
    test_forward_expand_dims()
    test_forward_pooling()