Commit ba6f194b by Jared Roesch Committed by Tianqi Chen

Fix bug in ONNX importer (#3084)

parent a706ad16
...@@ -944,7 +944,10 @@ class GraphProto(object): ...@@ -944,7 +944,10 @@ class GraphProto(object):
dtype=self._params[i_name].dtype) dtype=self._params[i_name].dtype)
else: else:
self._num_input += 1 self._num_input += 1
tshape = self._shape[i_name] if i_name in self._shape else () if i_name in self._shape:
tshape = self._shape[i_name]
else:
raise ValueError("Must provide an input shape for `{0}`.".format(i_name))
if isinstance(self._dtype, dict): if isinstance(self._dtype, dict):
dtype = self._dtype[i_name] if i_name in self._dtype else d_type dtype = self._dtype[i_name] if i_name in self._dtype else d_type
else: else:
......
...@@ -724,10 +724,15 @@ def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs): ...@@ -724,10 +724,15 @@ def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
else: else:
fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs) fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)
if is_shape == True:
inputs = []
else:
inputs = [helper.make_tensor_value_info("input_a",
TensorProto.FLOAT, list(input_dim))]
graph = helper.make_graph([fill_node], graph = helper.make_graph([fill_node],
"fill_test", "fill_test",
inputs = [helper.make_tensor_value_info("input_a", inputs,
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out", outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out.shape))]) TensorProto.FLOAT, list(out.shape))])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment