Commit 8d960c4b by Joshua Z. Zhang, committed by Tianqi Chen

[FRONTEND] Fix mxnet multi outputs (#271)

* fix mxnet multi outputs

* add test case for multi outputs

* fix

* fix

* fix

* fix index

* use json hack

* fix test cases

* fix test cases

* fix test cases

* fix test cases

* fix test cases

* fix test cases
parent 8d39aaba
@@ -222,6 +222,8 @@ _convert_map = {
     'Pooling'       : _pooling,
     'Pooling_v1'    : _pooling,
     'Reshape'       : _reshape,
+    'SliceChannel'  : _split,
+    'split'         : _split,
     'Softmax'       : _rename('softmax'),
     'SoftmaxOutput' : _softmax_output,
     'concat'        : _concat,
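
For context: 'SliceChannel' is MXNet's legacy spelling of the split operator, which is why both names now route to the same _split converter. A hedged sketch in plain MXNet (not part of the commit) showing the two spellings build the same multi-output node:

import mxnet as mx

x = mx.sym.Variable('x')
a = mx.sym.SliceChannel(x, num_outputs=2, axis=1)  # legacy spelling
b = mx.sym.split(x, num_outputs=2, axis=1)         # current spelling, same op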
@@ -269,10 +271,6 @@ def _convert_symbol(op_name, inputs, attrs,
         _raise_not_supported('Operator: ' + op_name)
     return sym
 
-def _is_mxnet_group_symbol(symbol):
-    """Internal check for mxnet group symbol."""
-    return len(symbol.list_outputs()) > 1
-
 def _as_list(arr):
     """Force being a list, ignore if already is."""
     if isinstance(arr, list):
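
The removed helper is inlined in the next hunk as len(symbol.list_outputs()) > 1. A hedged illustration of what that check detects (not part of the commit): grouped symbols, like multi-output ops, report more than one output.

import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
grouped = mx.sym.Group([a, b])      # a group symbol
print(len(grouped.list_outputs()))  # 2, so the converter recurses per output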
@@ -296,26 +294,27 @@ def _from_mxnet_impl(symbol, graph):
     nnvm.sym.Symbol
         Converted symbol
     """
-    if _is_mxnet_group_symbol(symbol):
+    if len(symbol.list_outputs()) > 1:
         return [_from_mxnet_impl(s, graph) for s in symbol]
 
     name = symbol.attr('name')
+    output_index = json.loads(symbol.tojson())['heads'][0][1]
     node = graph.get(name, None)
     if node:
-        return node
+        return node[output_index]
     attr = symbol.list_attr()
+    # op_name = symbol.attr('op_name')
     childs = symbol.get_children()
     if childs:
         op_name = symbol.attr('op_name')
-        childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
+        childs = [_from_mxnet_impl(c, graph) for c in childs]
+        childs = [x for y in childs for x in _as_list(y)] # expand group symbol
         node = _convert_symbol(op_name, childs, attr)
     else:
+        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
         node = _sym.Variable(name=name, **attr)
     graph[name] = node
-    return node
+    return node[output_index]
 
 def from_mxnet(symbol, arg_params=None, aux_params=None):
     """Convert from MXNet's model into compatible NNVM format.
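
A hedged sketch of the 'heads' lookup above (the "json hack" from the commit log), not part of the commit: serializing a single-output view of a multi-output MXNet node exposes which output the view refers to.

import json
import mxnet as mx

x = mx.sym.Variable('x')
z = mx.sym.split(x, num_outputs=3, axis=1)
# 'heads' lists the symbol's outputs; entry [0][1] is the output index
# this view selects on its producing node.
print(json.loads(z[1].tojson())['heads'])  # e.g. [[1, 1, 0]]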
@@ -23,10 +23,18 @@ import mxnet as mx
 def get_symbol(num_classes=10, **kwargs):
     data = mx.symbol.Variable('data')
     data = mx.sym.Flatten(data=data)
-    fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-    fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-    fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
-    mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
+    try:
+        fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False)
+        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+        fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False)
+        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+        fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False)
+        mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
+    except:
+        fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
+        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+        fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
+        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+        fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
+        mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
     return mlp
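
The try/except in this and the following model-zoo hunks is presumably a version guard: newer MXNet accepts flatten=False on FullyConnected (the input is already flattened, and this keeps the converted graph comparable with the hand-written NNVM model in compare_graph), while older MXNet rejects the unknown keyword. A minimal sketch of the same pattern; the helper name fc_compat is illustrative, not part of the commit:

import mxnet as mx

def fc_compat(data, name, num_hidden):
    """Build FullyConnected with flatten=False when the installed
    MXNet supports it, falling back to the plain call otherwise."""
    try:
        return mx.sym.FullyConnected(data=data, name=name,
                                     num_hidden=num_hidden, flatten=False)
    except Exception:
        return mx.sym.FullyConnected(data=data, name=name,
                                     num_hidden=num_hidden)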
@@ -137,7 +137,10 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck
     # Although kernel is not used here when global_pool=True, we should put one
     pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
     flat = mx.sym.Flatten(data=pool1)
-    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
+    try:
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False)
+    except:
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
     if dtype == 'float16':
         fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
     return mx.sym.softmax(data=fc1, name='softmax')
@@ -36,13 +36,22 @@ def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
 def get_classifier(input_data, num_classes, **kwargs):
     flatten = mx.sym.Flatten(data=input_data, name="flatten")
-    fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
-    relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
-    drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
-    fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
-    relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
-    drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
-    fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
+    try:
+        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", flatten=False)
+        relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
+        drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
+        fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", flatten=False)
+        relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
+        drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
+        fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8", flatten=False)
+    except:
+        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
+        relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
+        drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
+        fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
+        relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
+        drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
+        fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
     return fc8
 
 def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
@@ -32,7 +32,19 @@ def test_resnet():
         nnvm_sym = model_zoo.nnvm_resnet[n]
         compare_graph(from_mx_sym, nnvm_sym)
 
+def test_multi_outputs():
+    def compose(F, **kwargs):
+        x = F.sym.Variable('x')
+        y = F.sym.Variable('y')
+        z = F.sym.split(x, **kwargs)
+        return F.sym.broadcast_sub(F.sym.broadcast_add(z[0], z[2]), y)
+    mx_sym = compose(mx, num_outputs=3, axis=1)
+    from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
+    nnvm_sym = compose(nnvm, indices_or_sections=3, axis=1)
+    compare_graph(from_mx_sym, nnvm_sym)
+
 if __name__ == '__main__':
     test_mlp()
     test_vgg()
     test_resnet()
+    test_multi_outputs()
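
Finally, a minimal end-to-end sketch of the case the new test covers (assumes mxnet and nnvm are importable): a graph that consumes individual outputs of a multi-output node now converts with the correct output indices.

import mxnet as mx
import nnvm

x = mx.sym.Variable('x')
parts = mx.sym.split(x, num_outputs=3, axis=1)  # multi-output node
out = mx.sym.broadcast_add(parts[0], parts[2])  # use two distinct outputs
nnvm_sym, params = nnvm.frontend.from_mxnet(out)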