Commit 8d960c4b by Joshua Z. Zhang, committed by Tianqi Chen

[FRONTEND] Fix mxnet multi outputs (#271)

* fix mxnet multi outputs

* add test case for multi outputs

* fix

* fix

* fix

* fix index

* use json hack

* fix test cases

* fix test cases

* fix test cases

* fix test cases

* fix test cases

* fix test cases
parent 8d39aaba
@@ -222,6 +222,8 @@ _convert_map = {
     'Pooling'       : _pooling,
     'Pooling_v1'    : _pooling,
     'Reshape'       : _reshape,
+    'SliceChannel'  : _split,
+    'split'         : _split,
     'Softmax'       : _rename('softmax'),
     'SoftmaxOutput' : _softmax_output,
     'concat'        : _concat,
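
The two new table entries route MXNet's legacy SliceChannel operator and its modern alias split to one converter. For orientation, a minimal sketch of what a _split converter along these lines could look like; the attribute names num_outputs and axis come from MXNet's split op, but the body below is an illustration, not the commit's actual code:

    def _split(inputs, attrs):
        """Map MXNet split/SliceChannel onto nnvm split (sketch only)."""
        new_attrs = {
            # MXNet counts outputs; nnvm expects the number of sections.
            'indices_or_sections': int(attrs['num_outputs']),
            'axis': int(attrs.get('axis', 1)),
        }
        return _sym.split(*inputs, **new_attrs)
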
@@ -269,10 +271,6 @@ def _convert_symbol(op_name, inputs, attrs,
         _raise_not_supported('Operator: ' + op_name)
     return sym
 
-def _is_mxnet_group_symbol(symbol):
-    """Internal check for mxnet group symbol."""
-    return len(symbol.list_outputs()) > 1
-
 def _as_list(arr):
     """Force being a list, ignore if already is."""
     if isinstance(arr, list):
@@ -296,26 +294,27 @@ def _from_mxnet_impl(symbol, graph):
     nnvm.sym.Symbol
         Converted symbol
     """
-    if _is_mxnet_group_symbol(symbol):
+    if len(symbol.list_outputs()) > 1:
         return [_from_mxnet_impl(s, graph) for s in symbol]
 
     name = symbol.attr('name')
+    output_index = json.loads(symbol.tojson())['heads'][0][1]
     node = graph.get(name, None)
     if node:
-        return node
+        return node[output_index]
     attr = symbol.list_attr()
     # op_name = symbol.attr('op_name')
     childs = symbol.get_children()
     if childs:
         op_name = symbol.attr('op_name')
-        childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
+        childs = [_from_mxnet_impl(c, graph) for c in childs]
         childs = [x for y in childs for x in _as_list(y)]  # expand group symbol
         node = _convert_symbol(op_name, childs, attr)
     else:
         op_name = json.loads(symbol.tojson())['nodes'][0]['op']
         node = _sym.Variable(name=name, **attr)
     graph[name] = node
-    return node
+    return node[output_index]
 
 def from_mxnet(symbol, arg_params=None, aux_params=None):
     """Convert from MXNet's model into compatible NNVM format.
......
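
The output_index line is the "json hack" from the commit message: an MXNet symbol handle that points at one output of a multi-output node serializes with a heads entry whose second field is that output's index, so the cached nnvm node can be indexed with it. A small probe of this behavior (the exact shape of a head entry varies across MXNet versions, but the second field is the output index):

    import json
    import mxnet as mx

    x = mx.sym.Variable('x')
    z = mx.sym.split(x, num_outputs=3, axis=1)
    # z[1] is a handle to the second output of the split node; its
    # serialized head records the producing node and the output index.
    print(json.loads(z[1].tojson())['heads'][0][1])   # -> 1
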
@@ -23,10 +23,18 @@ import mxnet as mx
 def get_symbol(num_classes=10, **kwargs):
     data = mx.symbol.Variable('data')
     data = mx.sym.Flatten(data=data)
-    fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-    fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-    fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
-    mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
+    try:
+        fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128, flatten=False)
+        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+        fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, flatten=False)
+        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+        fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes, flatten=False)
+        mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
+    except:
+        fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
+        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+        fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
+        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+        fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
+        mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
     return mlp
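
The try/except exists because the flatten keyword on FullyConnected is only accepted by newer MXNet releases; on older ones the call raises and the plain definition is used instead. The same fallback pattern recurs in the resnet and vgg hunks below. A possible consolidation, shown only as a sketch (this _fc helper is not part of the commit):

    def _fc(data, num_hidden, name):
        """FullyConnected that prefers flatten=False when the installed
        MXNet supports it, falling back to the legacy signature."""
        try:
            return mx.symbol.FullyConnected(data=data, num_hidden=num_hidden,
                                            name=name, flatten=False)
        except Exception:
            return mx.symbol.FullyConnected(data=data, num_hidden=num_hidden,
                                            name=name)
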
@@ -137,7 +137,10 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck
     # Although kernel is not used here when global_pool=True, we should put one
     pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
     flat = mx.sym.Flatten(data=pool1)
-    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
+    try:
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1', flatten=False)
+    except:
+        fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
     if dtype == 'float16':
         fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
     return mx.sym.softmax(data=fc1, name='softmax')
......
@@ -36,13 +36,22 @@ def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
 def get_classifier(input_data, num_classes, **kwargs):
     flatten = mx.sym.Flatten(data=input_data, name="flatten")
-    fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
-    relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
-    drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
-    fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
-    relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
-    drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
-    fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
+    try:
+        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", flatten=False)
+        relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
+        drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
+        fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", flatten=False)
+        relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
+        drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
+        fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8", flatten=False)
+    except:
+        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
+        relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
+        drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
+        fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
+        relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
+        drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
+        fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
     return fc8
 
 def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
......
@@ -32,7 +32,19 @@ def test_resnet():
         nnvm_sym = model_zoo.nnvm_resnet[n]
         compare_graph(from_mx_sym, nnvm_sym)
 
+def test_multi_outputs():
+    def compose(F, **kwargs):
+        x = F.sym.Variable('x')
+        y = F.sym.Variable('y')
+        z = F.sym.split(x, **kwargs)
+        return F.sym.broadcast_sub(F.sym.broadcast_add(z[0], z[2]), y)
+    mx_sym = compose(mx, num_outputs=3, axis=1)
+    from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
+    nnvm_sym = compose(nnvm, indices_or_sections=3, axis=1)
+    compare_graph(from_mx_sym, nnvm_sym)
+
 if __name__ == '__main__':
     test_mlp()
     test_vgg()
     test_resnet()
+    test_multi_outputs()
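
Beyond the split-based test above, the group branch in _from_mxnet_impl (len(symbol.list_outputs()) > 1) implies a grouped MXNet symbol can be handed to from_mxnet directly. An illustrative sketch, assuming the grouped case is converted head by head as that branch suggests (variable names here are made up for the example):

    import mxnet as mx
    import nnvm

    a = mx.sym.Variable('a')
    out1 = mx.sym.Activation(data=a, act_type='relu')
    out2 = mx.sym.Activation(data=a, act_type='sigmoid')
    group = mx.sym.Group([out1, out2])   # a symbol with two outputs
    nnvm_sym, params = nnvm.frontend.from_mxnet(group)
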