Commit 25e4dc51 by Haichen Shen Committed by Tianqi Chen

[Frontend][MXNet] Change mxnet graph traversal from recursion to iteration (#2007)

parent 92f82c8e
......@@ -381,6 +381,55 @@ def _as_list(arr):
return arr
return [arr]
def _topo_sort(symbol):
"""Sort all symbols in the mxnet graph in topological order.
Parameters
----------
symbol : mxnet.sym.Symbol
Returns:
-------
list
List of mxnet symbol
"""
queue = []
symbol_map = {}
deps = {}
dep_cnts = {}
for s in symbol:
symbol_map[s.attr('name')] = s
queue.append(s)
while queue:
sym = queue.pop(0)
name = sym.attr('name')
childs = sym.get_children()
if childs is None:
dep_cnts[name] = 0
else:
dep_cnts[name] = len(set([c.attr('name') for c in childs]))
for child in childs:
child_name = child.attr('name')
if child_name not in deps:
deps[child_name] = set()
deps[child_name].add(name)
if child_name not in symbol_map:
symbol_map[child_name] = child
queue.append(child)
order = []
while dep_cnts:
remove = []
for name in dep_cnts:
if dep_cnts[name] == 0:
order.append(symbol_map[name])
remove.append(name)
if name in deps:
for other in deps[name]:
dep_cnts[other] -= 1
for name in remove:
del dep_cnts[name]
return order
def _from_mxnet_impl(symbol, graph):
"""Convert mxnet symbol to nnvm implementation.
Reconstruct a nnvm symbol by traversing the mxnet symbol.
......@@ -398,28 +447,37 @@ def _from_mxnet_impl(symbol, graph):
nnvm.sym.Symbol
Converted symbol
"""
if len(symbol.list_outputs()) > 1:
return [_from_mxnet_impl(s, graph) for s in symbol]
name = symbol.attr('name')
output_index = json.loads(symbol.tojson())['heads'][0][1]
node = graph.get(name, None)
if node:
return node[output_index]
attr = symbol.list_attr()
op_name = symbol.attr('op_name')
childs = symbol.get_children()
def get_node(sym):
name = sym.attr('name')
if name not in graph:
return None
output_index = json.loads(sym.tojson())['heads'][0][1]
return graph[name][output_index]
assert symbol is not None
# Traverse all symbols in topological order
for sym in _topo_sort(symbol):
name = sym.attr('name')
attr = sym.list_attr()
op_name = sym.attr('op_name')
childs = sym.get_children()
if childs is not None:
childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
childs = [get_node(child) for child in childs]
childs = [x for y in childs for x in _as_list(y)]
node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr) # no input symbol
node = _convert_symbol(op_name, [], attr)
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr)
graph[name] = node
return node[output_index]
nodes = []
for sym in symbol:
node = get_node(sym)
assert node is not None
nodes.append(node)
if len(nodes) > 1:
return _sym.Group(nodes)
return nodes[0]
def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment