Commit e9e12f03 by Zhi Committed by Tianqi Chen

[Relay][doc] Update the description of returns in mxnet.py (#2309)

parent 10a9df22
...@@ -343,7 +343,7 @@ _convert_map.update({k : _rename(k) for k in _identity_list}) ...@@ -343,7 +343,7 @@ _convert_map.update({k : _rename(k) for k in _identity_list})
def _from_mxnet_impl(symbol, shape_dict, dtype_info): def _from_mxnet_impl(symbol, shape_dict, dtype_info):
"""Convert mxnet symbol to nnvm implementation. """Convert mxnet symbol to compatible relay Function.
Reconstruct a relay Function by traversing the mxnet symbol. Reconstruct a relay Function by traversing the mxnet symbol.
...@@ -361,15 +361,14 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): ...@@ -361,15 +361,14 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
Returns: Returns:
------- -------
nnvm.sym.Symbol func : tvm.relay.Function
Converted symbol Converted relay Function
""" """
assert symbol is not None assert symbol is not None
jgraph = json.loads(symbol.tojson()) jgraph = json.loads(symbol.tojson())
jnodes = jgraph["nodes"] jnodes = jgraph["nodes"]
node_map = {} node_map = {}
for nid, node in enumerate(jnodes): for nid, node in enumerate(jnodes):
children = [node_map[e[0]][e[1]] for e in node["inputs"]] children = [node_map[e[0]][e[1]] for e in node["inputs"]]
attrs = StrAttrsDict(node.get("attrs", {})) attrs = StrAttrsDict(node.get("attrs", {}))
...@@ -444,8 +443,8 @@ def from_mxnet(symbol, ...@@ -444,8 +443,8 @@ def from_mxnet(symbol,
Returns Returns
------- -------
sym : nnvm.Symbol sym : tvm.relay.Function
Compatible nnvm symbol Compatible relay Function
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm The parameter dict to be used by nnvm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment