Commit 6a3a9572 by Sergei Grechanik Committed by Tianqi Chen

[NNVM][TEST] Numgrad: fix nan and multioutput (#1754)

parent 06f91dd2
...@@ -55,84 +55,84 @@ def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None): ...@@ -55,84 +55,84 @@ def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
""" """
# Preprocess input parameters # Preprocess input parameters
if shape is None: if shape is None:
shape = {} provided_shapes = {}
elif isinstance(shape, dict):
provided_shapes = shape
else:
provided_shapes = {x: shape for x in graph.symbol.list_input_variables()}
if dtype is None: if dtype is None:
dtype = {} provided_dtypes = {}
elif isinstance(dtype, dict):
if not isinstance(shape, dict): provided_dtypes = dtype
shape = {x: shape for x in graph.symbol.list_input_variables()} else:
provided_dtypes = {x: dtype for x in graph.symbol.list_input_variables()}
if not isinstance(dtype, dict):
dtype = {x: dtype for x in graph.symbol.list_input_variables()}
shape = _dict_var_to_dict_str(shape) provided_shapes = _dict_var_to_dict_str(provided_shapes)
dtype = _dict_var_to_dict_str(dtype) provided_dtypes = _dict_var_to_dict_str(provided_dtypes)
# The graph may already contain shape and dtype info, so extract it and merge with # The graph may already contain shape and dtype info, so extract it and merge with
# the user-specified shapes and dtypes (use the user-specified one on contradiction) # the user-specified shapes and dtypes (use the user-specified one on contradiction)
all_initial_shapes = graph.json_attr('shape') preexisting_shapes = graph.json_attr('shape')
all_initial_dtypes = graph.json_attr('dtype') preexisting_dtypes = graph.json_attr('dtype')
if all_initial_shapes: if preexisting_shapes:
for x in graph.index.input_names: for x in graph.index.input_names:
if x not in shape: if x not in provided_shapes:
x_shape = tuple(all_initial_shapes[graph.index.entry_id(x)]) x_shape = tuple(preexisting_shapes[graph.index.entry_id(x)])
shape[x] = x_shape provided_shapes[x] = x_shape
if all_initial_dtypes: if preexisting_dtypes:
for x in graph.index.input_names: for x in graph.index.input_names:
if x not in dtype: if x not in provided_dtypes:
x_dtype = TCODE_TO_DTYPE[all_initial_dtypes[graph.index.entry_id(x)]] x_dtype = TCODE_TO_DTYPE[preexisting_dtypes[graph.index.entry_id(x)]]
dtype[x] = x_dtype provided_dtypes[x] = x_dtype
# Perform inference # Perform inference
nnvm.compiler.graph_attr.set_shape_inputs(graph, shape) nnvm.compiler.graph_attr.set_shape_inputs(graph, provided_shapes)
nnvm.compiler.graph_attr.set_dtype_inputs(graph, dtype) nnvm.compiler.graph_attr.set_dtype_inputs(graph, provided_dtypes)
graph = graph.apply('InferShape').apply('InferType') graph = graph.apply('InferShape').apply('InferType')
shapes = graph.json_attr('shape') inferred_shapes = graph.json_attr('shape')
dtypes = graph.json_attr('dtype') inferred_dtypes = graph.json_attr('dtype')
out_len = len(graph.symbol.list_output_names())
index = graph.index index = graph.index
output_shapes = \ output_shapes = [tuple(inferred_shapes[index.entry_id(entry)])
[tuple(shapes[index.entry_id(index.output_entries[i])]) for i in range(out_len)] for entry in index.output_entries]
output_dtypes = \ output_dtypes = [TCODE_TO_DTYPE[inferred_dtypes[index.entry_id(entry)]]
[TCODE_TO_DTYPE[dtypes[index.entry_id(index.output_entries[i])]] for i in range(out_len)] for entry in index.output_entries]
# Postprocess the results # Postprocess the results
input_shapes = shape.copy() input_shapes = provided_shapes.copy()
input_dtypes = dtype.copy() input_dtypes = provided_dtypes.copy()
for x in graph.symbol.list_input_variables(): for x in graph.symbol.list_input_variables():
x_name = x.attr('name') x_name = x.attr('name')
x_node_id = graph.index.node_id(x_name) x_entry_id = graph.index.entry_id(x_name)
input_shapes[x_name] = tuple(shapes[x_node_id]) input_shapes[x_name] = tuple(inferred_shapes[x_entry_id])
input_dtypes[x_name] = TCODE_TO_DTYPE[dtypes[x_node_id]] input_dtypes[x_name] = TCODE_TO_DTYPE[inferred_dtypes[x_entry_id]]
# Merge the original user-specified shapes in case some of them are specified for non-existing # Merge the original user-specified shapes in case some of them are specified for non-existing
# variables # variables
for x_name, x_shape in shape.items(): for x_name, x_shape in provided_shapes.items():
x_shape = tuple(x_shape) x_shape = tuple(x_shape)
if input_shapes.get(x_name, x_shape) != x_shape: if input_shapes.get(x_name, x_shape) != x_shape:
raise RuntimeError("Inferred shape differs from the provided shape.\n" raise RuntimeError("Inferred shape differs from the provided shape.\n"
"Provided shapes: {}\nInferred shapes: {}" "Provided shapes: {}\nInferred shapes: {}"
.format(shapes, input_shapes)) .format(provided_shapes, input_shapes))
else: else:
input_shapes[x_name] = x_shape input_shapes[x_name] = x_shape
# Merge the original user-specified dtypes # Merge the original user-specified dtypes
for x_name, x_dtype in dtype.items(): for x_name, x_dtype in provided_dtypes.items():
if not isinstance(x_dtype, str): if not isinstance(x_dtype, str):
x_dtype = TCODE_TO_DTYPE[x_dtype] x_dtype = TCODE_TO_DTYPE[x_dtype]
if input_dtypes.get(x_name, x_dtype) != x_dtype: if input_dtypes.get(x_name, x_dtype) != x_dtype:
raise RuntimeError("Inferred dtype differs from the provided dtype.\n" raise RuntimeError("Inferred dtype differs from the provided dtype.\n"
"Provided dtypes: {}\nInferred dtypes: {}" "Provided dtypes: {}\nInferred dtypes: {}"
.format(dtypes, input_dtypes)) .format(provided_dtypes, input_dtypes))
else: else:
input_dtypes[x_name] = x_dtype input_dtypes[x_name] = x_dtype
...@@ -622,6 +622,12 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No ...@@ -622,6 +622,12 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
dist = np.sqrt(np.sum((ngrad - grad)**2)) dist = np.sqrt(np.sum((ngrad - grad)**2))
grad_norm = np.sqrt(np.sum(ngrad**2)) grad_norm = np.sqrt(np.sum(ngrad**2))
if not (np.isfinite(dist) and np.isfinite(grad_norm)):
raise ValueError(
"NaN or infinity detected during numerical gradient checking wrt {}\n"
"analytical grad = {}\n numerical grad = {}\n"
.format(x_name, grad, ngrad))
# we multiply atol by this number to make it more universal for different sizes # we multiply atol by this number to make it more universal for different sizes
sqrt_n = np.sqrt(float(np.prod(grad.shape))) sqrt_n = np.sqrt(float(np.prod(grad.shape)))
......
...@@ -96,6 +96,7 @@ def test_check_function(): ...@@ -96,6 +96,7 @@ def test_check_function():
_check_function_must_fail(sym.block_grad(x + 2*y), numerical_grads=True) _check_function_must_fail(sym.block_grad(x + 2*y), numerical_grads=True)
_check_function_must_fail(x*x, numerical_grads=True, _check_function_must_fail(x*x, numerical_grads=True,
numerical_grads_params={'atol': 0.0, 'rtol': 0.0}) numerical_grads_params={'atol': 0.0, 'rtol': 0.0})
_check_function_must_fail(sym.log(-x*x), numerical_grads=True, error=ValueError)
# different styles of returning results from the forward function # different styles of returning results from the forward function
check_function(x + 2*y, lambda x, y: [x + 2*y], numerical_grads=False) check_function(x + 2*y, lambda x, y: [x + 2*y], numerical_grads=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment