Commit 3a9de905 by Joshua Z. Zhang Committed by Tianqi Chen

[Relay][ONNX] fix #3134 converter where initializers were not registered as nodes (#3143)

parent d4fb0a2d
......@@ -934,7 +934,7 @@ class GraphProto(object):
self._renames = {}
self._num_input = 0
self._num_param = 0
self._shape = shape
self._shape = shape if shape else {}
self._dtype = dtype
def from_onnx(self, graph, opset):
......@@ -966,6 +966,9 @@ class GraphProto(object):
if not init_tensor.name.strip():
raise ValueError("Tensor's name is required.")
self._params[init_tensor.name] = self._parse_array(init_tensor)
self._nodes[init_tensor.name] = new_var(init_tensor.name,
shape=self._params[init_tensor.name].shape,
dtype=self._params[init_tensor.name].dtype)
for i in graph.input:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
......@@ -1179,6 +1182,18 @@ def from_onnx(model,
params : dict of str to tvm.NDArray
The parameter dict to be used by relay
"""
try:
import onnx
if hasattr(onnx.checker, 'check_model'):
# try use onnx's own model checker before converting any model
try:
onnx.checker.check_model(model)
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
import warnings
# the checker is a bit violent about errors, so simply print warnings here
warnings.warn(str(e))
except ImportError:
pass
g = GraphProto(shape, dtype)
graph = model.graph
try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment