Commit 6a3a9572 by Sergei Grechanik, committed by Tianqi Chen

[NNVM][TEST] Numgrad: fix nan and multioutput (#1754)

parent 06f91dd2
@@ -55,84 +55,84 @@ def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
     """
     # Preprocess input parameters
     if shape is None:
-        shape = {}
+        provided_shapes = {}
+    elif isinstance(shape, dict):
+        provided_shapes = shape
+    else:
+        provided_shapes = {x: shape for x in graph.symbol.list_input_variables()}
 
     if dtype is None:
-        dtype = {}
-    if not isinstance(shape, dict):
-        shape = {x: shape for x in graph.symbol.list_input_variables()}
-    if not isinstance(dtype, dict):
-        dtype = {x: dtype for x in graph.symbol.list_input_variables()}
+        provided_dtypes = {}
+    elif isinstance(dtype, dict):
+        provided_dtypes = dtype
+    else:
+        provided_dtypes = {x: dtype for x in graph.symbol.list_input_variables()}
 
-    shape = _dict_var_to_dict_str(shape)
-    dtype = _dict_var_to_dict_str(dtype)
+    provided_shapes = _dict_var_to_dict_str(provided_shapes)
+    provided_dtypes = _dict_var_to_dict_str(provided_dtypes)
 
     # The graph may already contain shape and dtype info, so extract it and merge with
     # the user-specified shapes and dtypes (use the user-specified one on contradiction)
-    all_initial_shapes = graph.json_attr('shape')
-    all_initial_dtypes = graph.json_attr('dtype')
+    preexisting_shapes = graph.json_attr('shape')
+    preexisting_dtypes = graph.json_attr('dtype')
 
-    if all_initial_shapes:
+    if preexisting_shapes:
         for x in graph.index.input_names:
-            if x not in shape:
-                x_shape = tuple(all_initial_shapes[graph.index.entry_id(x)])
-                shape[x] = x_shape
+            if x not in provided_shapes:
+                x_shape = tuple(preexisting_shapes[graph.index.entry_id(x)])
+                provided_shapes[x] = x_shape
 
-    if all_initial_dtypes:
+    if preexisting_dtypes:
         for x in graph.index.input_names:
-            if x not in dtype:
-                x_dtype = TCODE_TO_DTYPE[all_initial_dtypes[graph.index.entry_id(x)]]
-                dtype[x] = x_dtype
+            if x not in provided_dtypes:
+                x_dtype = TCODE_TO_DTYPE[preexisting_dtypes[graph.index.entry_id(x)]]
+                provided_dtypes[x] = x_dtype
 
     # Perform inference
-    nnvm.compiler.graph_attr.set_shape_inputs(graph, shape)
-    nnvm.compiler.graph_attr.set_dtype_inputs(graph, dtype)
+    nnvm.compiler.graph_attr.set_shape_inputs(graph, provided_shapes)
+    nnvm.compiler.graph_attr.set_dtype_inputs(graph, provided_dtypes)
 
     graph = graph.apply('InferShape').apply('InferType')
 
-    shapes = graph.json_attr('shape')
-    dtypes = graph.json_attr('dtype')
-    out_len = len(graph.symbol.list_output_names())
+    inferred_shapes = graph.json_attr('shape')
+    inferred_dtypes = graph.json_attr('dtype')
 
     index = graph.index
 
-    output_shapes = \
-        [tuple(shapes[index.entry_id(index.output_entries[i])]) for i in range(out_len)]
-    output_dtypes = \
-        [TCODE_TO_DTYPE[dtypes[index.entry_id(index.output_entries[i])]] for i in range(out_len)]
+    output_shapes = [tuple(inferred_shapes[index.entry_id(entry)])
+                     for entry in index.output_entries]
+    output_dtypes = [TCODE_TO_DTYPE[inferred_dtypes[index.entry_id(entry)]]
+                     for entry in index.output_entries]
 
     # Postprocess the results
-    input_shapes = shape.copy()
-    input_dtypes = dtype.copy()
+    input_shapes = provided_shapes.copy()
+    input_dtypes = provided_dtypes.copy()
 
     for x in graph.symbol.list_input_variables():
         x_name = x.attr('name')
-        x_node_id = graph.index.node_id(x_name)
-        input_shapes[x_name] = tuple(shapes[x_node_id])
-        input_dtypes[x_name] = TCODE_TO_DTYPE[dtypes[x_node_id]]
+        x_entry_id = graph.index.entry_id(x_name)
+        input_shapes[x_name] = tuple(inferred_shapes[x_entry_id])
+        input_dtypes[x_name] = TCODE_TO_DTYPE[inferred_dtypes[x_entry_id]]
 
     # Merge the original user-specified shapes in case some of them are specified for non-existing
     # variables
-    for x_name, x_shape in shape.items():
+    for x_name, x_shape in provided_shapes.items():
         x_shape = tuple(x_shape)
         if input_shapes.get(x_name, x_shape) != x_shape:
             raise RuntimeError("Inferred shape differs from the provided shape.\n"
                                "Provided shapes: {}\nInferred shapes: {}"
-                               .format(shapes, input_shapes))
+                               .format(provided_shapes, input_shapes))
         else:
             input_shapes[x_name] = x_shape
 
     # Merge the original user-specified dtypes
-    for x_name, x_dtype in dtype.items():
+    for x_name, x_dtype in provided_dtypes.items():
         if not isinstance(x_dtype, str):
             x_dtype = TCODE_TO_DTYPE[x_dtype]
         if input_dtypes.get(x_name, x_dtype) != x_dtype:
             raise RuntimeError("Inferred dtype differs from the provided dtype.\n"
                                "Provided dtypes: {}\nInferred dtypes: {}"
-                               .format(dtypes, input_dtypes))
+                               .format(provided_dtypes, input_dtypes))
         else:
             input_dtypes[x_name] = x_dtype
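Note on the hunk above: renaming shape/dtype to provided_shapes/provided_dtypes keeps the user-supplied values distinct from the inferred_shapes/inferred_dtypes read back after inference, and looking inputs and outputs up by entry id (index.entry_id, index.output_entries) instead of node id is what makes multi-output graphs resolve correctly. A minimal usage sketch, assuming the helper lives in nnvm.testing.check_computation and returns the four values computed above in this order (the return statement is outside this hunk):

    import nnvm.graph
    import nnvm.symbol as sym
    from nnvm.testing.check_computation import infer_shapes_dtypes

    x = sym.Variable("x")
    y = sym.Variable("y")
    g = nnvm.graph.create(x + 2*y)

    # A non-dict shape/dtype is applied to every input variable;
    # a dict pins a shape/dtype per variable.
    in_shapes, in_dtypes, out_shapes, out_dtypes = \
        infer_shapes_dtypes(g, shape=(2, 3), dtype="float32")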
@@ -622,6 +622,12 @@ def check_numerical_grads(function, input_values, grad_values, function_value=None,
         dist = np.sqrt(np.sum((ngrad - grad)**2))
         grad_norm = np.sqrt(np.sum(ngrad**2))
 
+        if not (np.isfinite(dist) and np.isfinite(grad_norm)):
+            raise ValueError(
+                "NaN or infinity detected during numerical gradient checking wrt {}\n"
+                "analytical grad = {}\n numerical grad = {}\n"
+                .format(x_name, grad, ngrad))
+
         # we multiply atol by this number to make it more universal for different sizes
         sqrt_n = np.sqrt(float(np.prod(grad.shape)))
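The guard added above aborts the check as soon as either quantity is non-finite; without it, a NaN distance makes every threshold comparison false, so a broken gradient would pass silently. A standalone numpy illustration of the condition (array values are hypothetical):

    import numpy as np

    grad = np.array([1.0, np.nan])  # analytical gradient containing a NaN
    ngrad = np.array([1.0, 2.0])    # numerical estimate

    dist = np.sqrt(np.sum((ngrad - grad)**2))  # NaN propagates through the norm
    grad_norm = np.sqrt(np.sum(ngrad**2))      # finite

    # Mirrors the new guard: a non-finite value triggers the ValueError path
    assert not (np.isfinite(dist) and np.isfinite(grad_norm))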
@@ -96,6 +96,7 @@ def test_check_function():
     _check_function_must_fail(sym.block_grad(x + 2*y), numerical_grads=True)
     _check_function_must_fail(x*x, numerical_grads=True,
                               numerical_grads_params={'atol': 0.0, 'rtol': 0.0})
+    _check_function_must_fail(sym.log(-x*x), numerical_grads=True, error=ValueError)
 
     # different styles of returning results from the forward function
     check_function(x + 2*y, lambda x, y: [x + 2*y], numerical_grads=False)
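The added test exercises the new guard end to end: for real x, -x*x is non-positive, so sym.log(-x*x) evaluates to NaN almost everywhere, and numerical gradient checking must now fail with ValueError rather than pass silently. A quick numpy illustration of the NaN source:

    import numpy as np

    x = np.array([1.0, -2.0, 3.0])
    with np.errstate(invalid="ignore"):  # suppress the "invalid value" warning
        v = np.log(-x*x)                 # log of a negative argument -> nan
    assert np.all(np.isnan(v))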