Commit 97ca4031 by Haichen Shen Committed by Lianmin Zheng

[Relay][Frontend] Add MXNet test example for relay (#2316)

* Add MXNet test example for relay
* Fix a bug in BiasAddSimplifier
parent 2bf40f9e
......@@ -22,7 +22,7 @@ class BiasAddSimplifier : public ExprMutator {
CHECK_EQ(call->args.size(), 2);
const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
auto ttype = call->args[0]->type_as<TensorTypeNode>();
auto ttype = n->args[0]->type_as<TensorTypeNode>();
size_t n_dim = ttype->shape.size();
Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis});
Expr ret = Add(call->args[0], expanded_bias);
......
"""MXNet and Relay model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg, dqn, dcgan, squeezenet, inception_v3
import tvm.relay.testing
_num_class = 1000
_batch = 2
# mlp fc
mx_mlp = mlp.get_symbol(_num_class)
relay_mlp = tvm.relay.testing.mlp.get_workload(_batch, _num_class)[0]
# vgg fc
mx_vgg = {}
relay_vgg = {}
for num_layers in [11, 13, 16, 19]:
mx_vgg[num_layers] = vgg.get_symbol(_num_class, num_layers)
relay_vgg[num_layers] = tvm.relay.testing.vgg.get_workload(
_batch, _num_class, num_layers=num_layers)[0]
# resnet fc
mx_resnet = {}
relay_resnet = {}
for num_layers in [18, 34, 50, 101, 152, 200, 269]:
mx_resnet[num_layers] = resnet.get_symbol(_num_class, num_layers, '3,224,224')
relay_resnet[num_layers] = tvm.relay.testing.resnet.get_workload(
_batch, _num_class, num_layers=num_layers)[0]
# squeezenet
mx_squeezenet = {}
relay_squeezenet = {}
for version in ['1.0', '1.1']:
mx_squeezenet[version] = squeezenet.get_symbol(version=version)
relay_squeezenet[version] = tvm.relay.testing.squeezenet.get_workload(_batch, version=version)[0]
# inception
mx_inception_v3 = inception_v3.get_symbol()
relay_inception_v3 = tvm.relay.testing.inception_v3.get_workload(_batch)[0]
# dqn
mx_dqn = dqn.get_symbol()
relay_dqn = tvm.relay.testing.dqn.get_workload(_batch)[0]
# dcgan generator
mx_dcgan = dcgan.get_symbol()
relay_dcgan = tvm.relay.testing.dcgan.get_workload(_batch)[0]
# pylint: disable=unused-argument
"""
The MXNet symbol of DCGAN generator
Adopted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
import mxnet as mx
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])
pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
net = mx.sym.Deconvolution(data,
kernel=kshape,
stride=stride,
pad=(pad_y, pad_x),
adj=(adj_y, adj_x),
num_filter=oshape[0],
no_bias=True,
name=name)
return net
def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix)
net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
return net
def get_symbol(oshape=(3, 64, 64), ngf=128, code=None):
"""get symbol of dcgan generator"""
assert oshape[-1] == 64, "Only support 64x64 image"
assert oshape[-2] == 64, "Only support 64x64 image"
code = mx.sym.Variable("data") if code is None else code
net = mx.sym.FullyConnected(code, name="g1", num_hidden=ngf*8*4*4, no_bias=True, flatten=False)
net = mx.sym.Activation(net, act_type='relu')
# 4 x 4
net = mx.sym.reshape(net, shape=(-1, ngf * 8, 4, 4))
# 8 x 8
net = deconv2d_bn_relu(
net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
net = deconv2d_bn_relu(
net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
net = deconv2d_bn_relu(
net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4")
# 64x64
net = deconv2d(
net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
net = mx.sym.Activation(net, act_type='tanh')
return net
"""
The mxnet symbol of Nature DQN
Reference:
Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529.
"""
import mxnet as mx
def get_symbol(num_action=18):
data = mx.sym.Variable(name='data')
net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4),
num_filter=32, name='conv1')
net = mx.sym.Activation(net, act_type='relu', name='relu1')
net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2),
num_filter=64, name='conv2')
net = mx.sym.Activation(net, act_type='relu', name='relu2')
net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1),
num_filter=64, name='conv3')
net = mx.sym.Activation(net, act_type='relu', name='relu3')
net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
net = mx.sym.Activation(net, act_type='relu', name='relu4')
net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)
return net
"""
Inception V3, suitable for images with around 299 x 299
Reference:
Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." arXiv preprint arXiv:1512.00567 (2015).
Adopted from https://github.com/apache/incubator-mxnet/blob/
master/example/image-classification/symbols/inception-v3.py
"""
import mxnet as mx
import numpy as np
def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name='%s%s_batchnorm' % (name, suffix))
act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
return act
def Inception7A(data,
num_1x1,
num_3x3_red, num_3x3_1, num_3x3_2,
num_5x5_red, num_5x5,
pool, proj,
name):
tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))
tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv')
tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv')
concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
return concat
# First Downsample
def Inception7B(data,
num_3x3,
num_d3x3_red, num_d3x3_1, num_d3x3_2,
pool,
name):
tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name))
tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7C(data,
num_1x1,
num_d7_red, num_d7_1, num_d7_2,
num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
pool, proj,
name):
tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv')
tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1')
tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2')
tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
# concat
concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7D(data,
num_3x3_red, num_3x3,
num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
pool,
name):
tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv')
tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
# concat
concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7E(data,
num_1x1,
num_d3_red, num_d3_1, num_d3_2,
num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
pool, proj,
name):
tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv')
tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv')
tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1')
tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv')
tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
# concat
concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
return concat
def get_symbol(num_classes=1000, **kwargs):
data = mx.sym.Variable(name="data")
# stage 1
conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1")
conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool")
# stage 2
conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3")
conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4")
pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")
# # stage 3
in3a = Inception7A(pool1, 64,
64, 96, 96,
48, 64,
"avg", 32, "mixed")
in3b = Inception7A(in3a, 64,
64, 96, 96,
48, 64,
"avg", 64, "mixed_1")
in3c = Inception7A(in3b, 64,
64, 96, 96,
48, 64,
"avg", 64, "mixed_2")
in3d = Inception7B(in3c, 384,
64, 96, 96,
"max", "mixed_3")
# stage 4
in4a = Inception7C(in3d, 192,
128, 128, 192,
128, 128, 128, 128, 192,
"avg", 192, "mixed_4")
in4b = Inception7C(in4a, 192,
160, 160, 192,
160, 160, 160, 160, 192,
"avg", 192, "mixed_5")
in4c = Inception7C(in4b, 192,
160, 160, 192,
160, 160, 160, 160, 192,
"avg", 192, "mixed_6")
in4d = Inception7C(in4c, 192,
192, 192, 192,
192, 192, 192, 192, 192,
"avg", 192, "mixed_7")
in4e = Inception7D(in4d, 192, 320,
192, 192, 192, 192,
"max", "mixed_8")
# stage 5
in5a = Inception7E(in4e, 320,
384, 384, 384,
448, 384, 384, 384,
"avg", 192, "mixed_9")
in5b = Inception7E(in5a, 320,
384, 384, 384,
448, 384, 384, 384,
"max", 192, "mixed_10")
# pool
pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool")
flatten = mx.sym.Flatten(data=pool, name="flatten")
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1', flatten=False)
softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax')
return softmax
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
import mxnet as mx
def get_symbol(num_classes=10, **kwargs):
data = mx.symbol.Variable('data')
data = mx.sym.Flatten(data=data)
try:
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False)
mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
except:
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
return mlp
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu
Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import mxnet as mx
import numpy as np
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True means channel number between input and output is the same, otherwise means differ
name : str
Base name of the operators
workspace : int
Workspace used in convolution operator
"""
if bottle_neck:
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=stride, pad=(0,0),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
workspace=workspace, name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return conv3 + shortcut
else:
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return conv2 + shortcut
def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
"""Return ResNet symbol of
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stage
filter_list : list
Channel size of each stage
num_classes : int
Ouput size of symbol
dataset : str
Dataset type, only cifar10 and imagenet supports
workspace : int
Workspace used in convolution operator
dtype : str
Precision (float32 or float16)
"""
num_unit = len(units)
assert(num_unit == num_stages)
data = mx.sym.Variable(name='data')
if dtype == 'float32':
# data = mx.sym.identity(data=data, name='id')
data = data
else:
if dtype == 'float16':
data = mx.sym.Cast(data=data, dtype=np.float16)
data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
(nchannel, height, width) = image_shape
if height <= 32: # such as cifar10
body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1),
no_bias=True, name="conv0", workspace=workspace)
else: # often expected to be 224 such as imagenet
body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
no_bias=True, name="conv0", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
for i in range(num_stages):
body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace,
memonger=memonger)
for j in range(units[i]-1):
body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
bottle_neck=bottle_neck, workspace=workspace, memonger=memonger)
bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
# Although kernel is not used here when global_pool=True, we should put one
pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
flat = mx.sym.Flatten(data=pool1)
try:
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False)
except:
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
if dtype == 'float16':
fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
return mx.sym.softmax(data=fc1, name='softmax')
def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
"""
image_shape = [int(l) for l in image_shape.split(',')]
(nchannel, height, width) = image_shape
if height <= 28:
num_stages = 3
if (num_layers-2) % 9 == 0 and num_layers >= 164:
per_unit = [(num_layers-2)//9]
filter_list = [16, 64, 128, 256]
bottle_neck = True
elif (num_layers-2) % 6 == 0 and num_layers < 164:
per_unit = [(num_layers-2)//6]
filter_list = [16, 16, 32, 64]
bottle_neck = False
else:
raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
units = per_unit * num_stages
else:
if num_layers >= 50:
filter_list = [64, 256, 512, 1024, 2048]
bottle_neck = True
else:
filter_list = [64, 64, 128, 256, 512]
bottle_neck = False
num_stages = 4
if num_layers == 18:
units = [2, 2, 2, 2]
elif num_layers == 34:
units = [3, 4, 6, 3]
elif num_layers == 50:
units = [3, 4, 6, 3]
elif num_layers == 101:
units = [3, 4, 23, 3]
elif num_layers == 152:
units = [3, 8, 36, 3]
elif num_layers == 200:
units = [3, 24, 36, 3]
elif num_layers == 269:
units = [3, 30, 48, 8]
else:
raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
return resnet(units = units,
num_stages = num_stages,
filter_list = filter_list,
num_classes = num_classes,
image_shape = image_shape,
bottle_neck = bottle_neck,
workspace = conv_workspace,
dtype = dtype)
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
import mxnet as mx
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = mx.sym.concat(left, right, dim=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = mx.sym.Convolution(net, num_filter=channels, kernel=(kernel_size, kernel_size),
pad=(padding, padding))
net = mx.sym.Activation(net, act_type='relu')
return net
# Net
def get_symbol(num_classes=1000, version='1.0', **kwargs):
"""Get symbol of SqueezeNet
Parameters
----------
num_classes: int
The number of classification results
version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:"
"1.0 or 1.1 expected".format(version=version))
net = mx.sym.Variable("data")
if version == '1.0':
net = mx.sym.Convolution(net, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 32, 128, 128)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 64, 256, 256)
else:
net = mx.sym.Convolution(net, num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = mx.sym.Dropout(net, p=0.5)
net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg')
net = mx.sym.flatten(net)
return mx.sym.softmax(net)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import mxnet as mx
import numpy as np
def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
for i, num in enumerate(layers):
for j in range(num):
internel_layer = mx.sym.Convolution(data = internel_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" %(i + 1, j + 1))
if batch_norm:
internel_layer = mx.symbol.BatchNorm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = mx.sym.Activation(data=internel_layer, act_type="relu", name="relu%s_%s" %(i + 1, j + 1))
internel_layer = mx.sym.Pooling(data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1))
return internel_layer
def get_classifier(input_data, num_classes, **kwargs):
flatten = mx.sym.Flatten(data=input_data, name="flatten")
try:
fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", flatten=False)
relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", flatten=False)
relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8", flatten=False)
except:
fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
"""
Parameters
----------
num_classes : int, default 1000
Number of classification classes.
num_layers : int
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
if num_layers not in vgg_spec:
raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = mx.sym.Variable(name="data")
if dtype == 'float16':
data = mx.sym.Cast(data=data, dtype=np.float16)
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
if dtype == 'float16':
classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
symbol = mx.sym.softmax(data=classifier, name='softmax')
return symbol
import numpy as np
import topi
import tvm
from tvm.contrib import graph_runtime
from tvm import relay
from tvm.relay.testing.config import ctx_list
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
import model_zoo
def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000),
gluon_impl=False, name=None, dtype='float32'):
"""Use name different from test to avoid let nose pick it up"""
if gluon_impl:
def get_gluon_output(name, x):
net = vision.get_model(name)
net.collect_params().initialize(mx.init.Xavier())
net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
inputs=mx.sym.var('data'),
params=net.collect_params())
out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
return out, net_sym
else:
def get_mxnet_output(symbol, x, dtype='float32'):
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])
mod = mx.mod.Module(symbol, label_names=None)
mod.bind(data_shapes=[('data', x.shape)], for_training=False)
mod.init_params()
mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
out = mod.get_outputs()[0].asnumpy()
args, auxs = mod.get_params()
return out, args, auxs
def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
dshape = x.shape
shape_dict = {'data': dshape}
if gluon_impl:
new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict)
else:
new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict, arg_params=args, aux_params=auxs)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(new_sym, target, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("data", tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
# random input
x = np.random.uniform(size=data_shape)
if gluon_impl:
gluon_out, gluon_sym = get_gluon_output(name, x)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
tvm.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
else:
mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
assert "data" not in args
for target, ctx in ctx_list():
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_mlp():
mlp = model_zoo.mx_mlp
verify_mxnet_frontend_impl(mlp)
def test_forward_vgg():
for n in [11]:
mx_sym = model_zoo.mx_vgg[n]
verify_mxnet_frontend_impl(mx_sym)
def test_forward_resnet():
for n in [18]:
mx_sym = model_zoo.mx_resnet[n]
verify_mxnet_frontend_impl(mx_sym)
def test_forward_elu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_rrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_prelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.Activation(data, act_type='softrelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_fc_flatten():
# test flatten=True option in mxnet 0.11.1
data = mx.sym.var('data')
try:
mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
except:
pass
def test_forward_clip():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicity
mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_split():
data = mx.sym.var('data')
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=False)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 1, 2, 1))
def test_forward_split_squeeze():
data = mx.sym.var('data')
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
def test_forward_expand_dims():
data = mx.sym.var('data')
mx_sym = mx.sym.expand_dims(data, axis=1)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
def test_forward_pooling():
data = mx.sym.var('data')
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
def test_forward_lrn():
data = mx.sym.var('data')
mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
def test_forward_ones():
data = mx.sym.var('data')
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, ones)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros():
data = mx.sym.var('data')
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, zeros)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_ones_like():
data = mx.sym.var('data')
mx_sym = mx.sym.ones_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros_like():
data = mx.sym.var('data')
mx_sym = mx.sym.zeros_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_argmax():
data = mx.sym.var('data')
mx_sym = mx.sym.argmax(data, axis=1)
verify_mxnet_frontend_impl(mx_sym, (5, 3), (5,))
def test_forward_argmin():
data = mx.sym.var('data')
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
test_forward_resnet()
test_forward_elu()
test_forward_rrelu()
test_forward_prelu()
test_forward_softrelu()
test_forward_fc_flatten()
test_forward_clip()
test_forward_split()
test_forward_split_squeeze()
test_forward_expand_dims()
test_forward_pooling()
test_forward_lrn()
test_forward_ones()
test_forward_zeros()
test_forward_ones_like()
test_forward_zeros_like()
test_forward_argmax()
test_forward_argmin()
import mxnet as mx
import tvm
from tvm import relay
import model_zoo
from model_zoo import _batch
def test_mlp():
mx_sym = model_zoo.mx_mlp
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 1, 28, 28)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = model_zoo.relay_mlp
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
def test_vgg():
for n in [11, 13, 16, 19]:
mx_sym = model_zoo.mx_vgg[n]
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 224, 224)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = model_zoo.relay_vgg[n]
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
def test_resnet():
for n in [18, 34, 50, 101, 152, 200, 269]:
mx_sym = model_zoo.mx_resnet[n]
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 224, 224)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = model_zoo.relay_resnet[n]
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
def test_squeezenet():
for version in ['1.0', '1.1']:
mx_sym = model_zoo.mx_squeezenet[version]
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 224, 224)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = model_zoo.relay_squeezenet[version]
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
def test_inception_v3():
mx_sym = model_zoo.mx_inception_v3
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 3, 299, 299)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = model_zoo.relay_inception_v3
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
def test_dqn():
mx_sym = model_zoo.mx_dqn
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 4, 84, 84)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = model_zoo.relay_dqn
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
def test_dcgan():
mx_sym = model_zoo.mx_dcgan
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'data': (_batch, 100)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = model_zoo.relay_dcgan
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
def test_multi_outputs():
def compose_mxnet(**kwargs):
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = mx.sym.split(x, **kwargs)
return mx.sym.broadcast_sub(mx.sym.broadcast_add(z[0], z[2]), y)
def compose_relay(**kwargs):
x = relay.var("x", shape=(_batch, 3, 224, 224))
y = relay.var("y", shape=(1,))
z = relay.split(x, **kwargs)
ret = z[0] + z[2] - y
args = relay.ir_pass.free_vars(ret)
return relay.Function(args, ret)
mx_sym = compose_mxnet(num_outputs=3, axis=1)
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, {'x': (_batch, 3, 224, 224), 'y': (1,)})
from_mx_sym = relay.ir_pass.infer_type(from_mx_sym)
relay_sym = compose_relay(indices_or_sections=3, axis=1)
relay_sym = relay.ir_pass.infer_type(relay_sym)
assert relay.ir_pass.alpha_equal(from_mx_sym, relay_sym)
if __name__ == '__main__':
test_mlp()
test_vgg()
test_resnet()
test_squeezenet()
test_inception_v3()
test_dqn()
test_dcgan()
test_multi_outputs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment