Unverified Commit 5ce2c296 by Josh Fromm Committed by GitHub

Add ability to have multiple copies of same input to onnx_inputs. (#5389)

parent 4cebb1c7
......@@ -57,8 +57,7 @@ class onnx_input():
if isinstance(item, int):
self.input_dict[self.input_keys[item]] = value
elif isinstance(item, str):
if item not in self.input_dict:
self.input_keys.append(item)
self.input_keys.append(item)
self.input_dict[item] = value
else:
raise ValueError("Only integer and string indexed writes allowed.")
......
......@@ -1366,16 +1366,16 @@ def test_binary_ops():
dtype = "float32"
out_shape = in_shape
def verify_binary_ops(op, x, y, out_np, broadcast=None):
def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=None):
if broadcast is None:
z = helper.make_node(op, ['in1', 'in2'], ['out'])
z = helper.make_node(op, [x_name, y_name], ['out'])
else:
z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
z = helper.make_node(op, [x_name, y_name], ['out'], broadcast=1)
graph = helper.make_graph([z],
'_test',
inputs=[helper.make_tensor_value_info("in1",
inputs=[helper.make_tensor_value_info(x_name,
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in2",
helper.make_tensor_value_info(y_name,
TensorProto.FLOAT, list(in_shape))],
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
......@@ -1393,6 +1393,7 @@ def test_binary_ops():
verify_binary_ops("Sub", x, z, x - z, broadcast=True)
verify_binary_ops("Mul", x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, z, x * z, broadcast=True)
verify_binary_ops("Mul", x, x, x * x, x_name='in1', y_name='in1', broadcast=None)
verify_binary_ops("Div", x, y, x / y, broadcast=None)
verify_binary_ops("Div", x, z, x / z, broadcast=True)
verify_binary_ops("Sum", x, y, x + y, broadcast=None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment