Unverified Commit de346493 by Jeremy Johnson Committed by GitHub

[Frontend][Torch] Check graph inputs match expected (#4992)

* [Frontend][Torch] Check graph inputs match expected

* error/warn when missing/unused graph inputs

* Change to use get_graph_input_names
parent de0869de
......@@ -905,6 +905,20 @@ def _report_missing_conversion(op_names):
msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg)
def _check_input_names(script_module, input_shapes):
""" Check the graph inputs match the inputs """
ir_inputs = get_graph_input_names(script_module)
for ir_input in ir_inputs:
if ir_input not in input_shapes:
msg = "Missing graph input {} in input_shapes".format(ir_input)
raise RuntimeError(msg)
for input_name in input_shapes:
if input_name not in ir_inputs:
msg = "Unused graph input {} in input_shapes".format(input_name)
logging.warning(msg)
def _getattr_attr_name(node):
attribute_names = node.attributeNames()
......@@ -1150,6 +1164,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
_check_input_names(script_module, input_shapes)
params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment