Commit 10b6e7e0 by Sergei Grechanik, committed by Tianqi Chen

[TVM] Move check_numerical_grads to tvm.testing (#2314)

parent 03872132
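For context, a minimal usage sketch of the helper at its new location (not part of this diff; the forward function and values below are purely illustrative):

import numpy as np
from tvm.testing import check_numerical_grads

# Illustrative example, not part of this commit: f(x) = sum(x^2), analytical gradient 2*x.
x = np.random.uniform(0.5, 2.0, size=(3, 4))
check_numerical_grads(lambda x: np.sum(x * x), {"x": x}, {"x": 2 * x})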
@@ -7,6 +7,7 @@ import numpy as np
import tvm
from tvm.contrib import graph_runtime
from tvm.testing import check_numerical_grads
import nnvm
from nnvm.compiler import graph_util
@@ -535,113 +536,3 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
if nothing_was_done:
logging.warning("Nothing was done in check_function. Check ctx_list().")
def check_numerical_grads(function, input_values, grad_values, function_value=None,
delta=1e-3, atol=1e-2, rtol=0.1):
"""A helper function that checks that numerical gradients of a function are equal to
gradients computed in some different way (analytical gradients).
Numerical gradients are computed using a finite difference approximation. To keep the number of
function evaluations low, more evaluation points (up to 5) are used only when the error is
too high.
Parameters
----------
function
A function that takes inputs as keyword arguments (like `function(**input_values)`) and
returns a scalar result. Should accept numpy ndarrays.
input_values : Dict[str, numpy.ndarray]
A dict assigning values to variables. Represents the point at which gradients should be
computed.
grad_values : Dict[str, numpy.ndarray]
Gradients computed using a different method.
function_value : float, optional
Should be equal to `function(**input_values)`.
delta : float, optional
A small number used for numerical computation of partial derivatives. The default 1e-3 is a
good choice for float32.
atol : float, optional
Absolute tolerance.
rtol : float, optional
Relative tolerance.
"""
if function_value is None:
function_value = function(**input_values)
# a helper to modify the j-th element of `val` by `a_delta`
def modify(val, j, a_delta):
val = val.copy()
val.reshape(-1)[j] = val.reshape(-1)[j] + a_delta
return val
# numerically compute a partial derivative with respect to the j-th element of the variable `x_name`
def derivative(x_name, j, a_delta):
modified_values = {n: modify(val, j, a_delta) if n == x_name else val
for n, val in input_values.items()}
return (function(**modified_values) - function_value)/a_delta
def compare_derivative(j, n_der, grad):
der = grad.reshape(-1)[j]
return np.abs(n_der - der) < atol + rtol*np.abs(n_der)
for x_name, grad in grad_values.items():
if grad.shape != input_values[x_name].shape:
raise AssertionError(
"Gradient wrt '{}' has unexpected shape {}, expected {} "
.format(x_name, grad.shape, input_values[x_name].shape))
ngrad = np.zeros_like(grad)
# compute partial derivatives for each position in this variable
for j in range(np.prod(grad.shape)):
# forward difference approximation
nder = derivative(x_name, j, delta)
# if the derivative is not equal to the analytical one, try to use more
# precise and expensive methods
if not compare_derivative(j, nder, grad):
# central difference approximation
nder = (derivative(x_name, j, -delta) + nder)/2
if not compare_derivative(j, nder, grad):
# central difference approximation using h = delta/2
cnder2 = (derivative(x_name, j, delta/2) + derivative(x_name, j, -delta/2))/2
# five-point derivative
nder = (4*cnder2 - nder)/3
ngrad.reshape(-1)[j] = nder
dist = np.sqrt(np.sum((ngrad - grad)**2))
grad_norm = np.sqrt(np.sum(ngrad**2))
if not (np.isfinite(dist) and np.isfinite(grad_norm)):
raise ValueError(
"NaN or infinity detected during numerical gradient checking wrt {}\n"
"analytical grad = {}\n numerical grad = {}\n"
.format(x_name, grad, ngrad))
# we multiply atol by this number to make it more universal for different sizes
sqrt_n = np.sqrt(float(np.prod(grad.shape)))
if dist > atol*sqrt_n + rtol*grad_norm:
raise AssertionError(
"Analytical and numerical grads wrt {} differ too much\n"
"analytical grad = {}\n numerical grad = {}\n"
"distance > atol*sqrt(n) + rtol*grad_norm\n"
"distance {} > {}*{} + {}*{}"
.format(x_name, grad, ngrad,
dist, atol, sqrt_n, rtol, grad_norm))
max_diff = np.max(np.abs(ngrad - grad))
avg_diff = np.mean(np.abs(ngrad - grad))
logging.info("Numerical grad test wrt %s of shape %s passes, "
"dist = %f, max_diff = %f, avg_diff = %f",
x_name, grad.shape, dist, max_diff, avg_diff)
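The refinement chain in the code above (forward difference, then central difference, then a Richardson-extrapolated five-point estimate, using at most 5 function evaluations per element) can be written out in isolation. The following sketch uses a hypothetical scalar function f and step delta for illustration only:

import numpy as np

def f(x):
    return np.sin(x)

x0, delta = 1.0, 1e-3
f0 = f(x0)

forward = (f(x0 + delta) - f0) / delta                       # O(delta) error
central = (f(x0 + delta) - f(x0 - delta)) / (2 * delta)      # O(delta^2) error
central_half = (f(x0 + delta/2) - f(x0 - delta/2)) / delta   # same formula with h = delta/2
five_point = (4 * central_half - central) / 3                # Richardson extrapolation, O(delta^4)

print(forward, central, five_point, np.cos(x0))  # all should approximate cos(1.0)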
""" TVM testing utilities """
import logging
import numpy as np
def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
@@ -10,3 +11,137 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
often allow `desired` to be close to zero, we generally want non-zero `atol`.
"""
np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
def check_numerical_grads(function, input_values, grad_values, function_value=None,
delta=1e-3, atol=1e-2, rtol=0.1):
"""A helper function that checks that numerical gradients of a function are
equal to gradients computed in some different way (analytical gradients).
Numerical gradients are computed using a finite difference approximation. To
keep the number of function evaluations low, more evaluation points (up to 5)
are used only when the error is too high.
Parameters
----------
function
A function that takes inputs either as positional or as keyword
arguments (either `function(*input_values)` or `function(**input_values)`
should be correct) and returns a scalar result. Should accept numpy
ndarrays.
input_values : Dict[str, numpy.ndarray] or List[numpy.ndarray]
A list of values or a dict assigning values to variables. Represents the
point at which gradients should be computed.
grad_values : Dict[str, numpy.ndarray] or List[numpy.ndarray]
Gradients computed using a different method.
function_value : float, optional
Should be equal to `function(**input_values)`.
delta : float, optional
A small number used for numerical computation of partial derivatives.
The default 1e-3 is a good choice for float32.
atol : float, optional
Absolute tolerance. Gets multiplied by `sqrt(n)` where n is the size of a
gradient.
rtol : float, optional
Relative tolerance.
"""
# If input_values is a list then function accepts positional arguments
# In this case transform it to a function taking kwargs of the form {"0": ..., "1": ...}
if not isinstance(input_values, dict):
input_len = len(input_values)
input_values = {str(idx): val for idx, val in enumerate(input_values)}
def _function(_input_len=input_len, _orig_function=function, **kwargs):
return _orig_function(*(kwargs[str(i)] for i in range(input_len)))
function = _function
grad_values = {str(idx): val for idx, val in enumerate(grad_values)}
if function_value is None:
function_value = function(**input_values)
# a helper to modify the j-th element of `val` by `a_delta`
def modify(val, j, a_delta):
val = val.copy()
val.reshape(-1)[j] = val.reshape(-1)[j] + a_delta
return val
# numerically compute a partial derivative with respect to the j-th element of the variable `x_name`
def derivative(x_name, j, a_delta):
modified_values = {n: modify(val, j, a_delta) if n == x_name else val
for n, val in input_values.items()}
return (function(**modified_values) - function_value)/a_delta
def compare_derivative(j, n_der, grad):
der = grad.reshape(-1)[j]
return np.abs(n_der - der) < atol + rtol*np.abs(n_der)
for x_name, grad in grad_values.items():
if grad.shape != input_values[x_name].shape:
raise AssertionError(
"Gradient wrt '{}' has unexpected shape {}, expected {} "
.format(x_name, grad.shape, input_values[x_name].shape))
ngrad = np.zeros_like(grad)
wrong_positions = []
# compute partial derivatives for each position in this variable
for j in range(np.prod(grad.shape)):
# forward difference approximation
nder = derivative(x_name, j, delta)
# if the derivative is not equal to the analytical one, try to use more
# precise and expensive methods
if not compare_derivative(j, nder, grad):
# central difference approximation
nder = (derivative(x_name, j, -delta) + nder)/2
if not compare_derivative(j, nder, grad):
# central difference approximation using h = delta/2
cnder2 = (derivative(x_name, j, delta/2) + derivative(x_name, j, -delta/2))/2
# five-point derivative
nder = (4*cnder2 - nder)/3
# if the derivatives still don't match, add this position to the
# list of wrong positions
if not compare_derivative(j, nder, grad):
wrong_positions.append(np.unravel_index(j, grad.shape))
ngrad.reshape(-1)[j] = nder
wrong_percentage = int(100*len(wrong_positions)/np.prod(grad.shape))
dist = np.sqrt(np.sum((ngrad - grad)**2))
grad_norm = np.sqrt(np.sum(ngrad**2))
if not (np.isfinite(dist) and np.isfinite(grad_norm)):
raise ValueError(
"NaN or infinity detected during numerical gradient checking wrt '{}'\n"
"analytical grad = {}\n numerical grad = {}\n"
.format(x_name, grad, ngrad))
# we multiply atol by this number to make it more universal for different sizes
sqrt_n = np.sqrt(float(np.prod(grad.shape)))
if dist > atol*sqrt_n + rtol*grad_norm:
raise AssertionError(
"Analytical and numerical grads wrt '{}' differ too much\n"
"analytical grad = {}\n numerical grad = {}\n"
"{}% of elements differ, first 10 of wrong positions: {}\n"
"distance > atol*sqrt(n) + rtol*grad_norm\n"
"distance {} > {}*{} + {}*{}"
.format(x_name, grad, ngrad, wrong_percentage, wrong_positions[:10],
dist, atol, sqrt_n, rtol, grad_norm))
max_diff = np.max(np.abs(ngrad - grad))
avg_diff = np.mean(np.abs(ngrad - grad))
logging.info("Numerical grad test wrt '%s' of shape %s passes, "
"dist = %f, max_diff = %f, avg_diff = %f",
x_name, grad.shape, dist, max_diff, avg_diff)
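For reference, the new positional-argument support added in this version can be exercised alongside the original keyword style. This is a sketch with illustrative values; the gradients of sum(x*y) are y with respect to x and x with respect to y:

import numpy as np
from tvm.testing import check_numerical_grads

# Illustrative sketch, not part of this commit.
a = np.random.uniform(0.5, 2.0, size=(2, 3))
b = np.random.uniform(0.5, 2.0, size=(2, 3))

# keyword style: inputs and gradients are dicts keyed by argument name
check_numerical_grads(lambda x, y: np.sum(x * y), {"x": a, "y": b}, {"x": b, "y": a})

# positional style: inputs and gradients are lists given in argument order
check_numerical_grads(lambda x, y: np.sum(x * y), [a, b], [b, a])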
import numpy as np
import tvm
from tvm.testing import check_numerical_grads
def test_check_numerical_grads():
# Functions and their derivatives
functions = [
lambda x: (x*x*x, 3*x*x),
lambda x: (x*x, 2*x),
lambda x: (np.abs(x), np.sign(x)),
lambda x: (np.log(np.abs(x)), 1/x),
lambda x: (np.sqrt(np.abs(x)), np.sign(x)/(2*np.sqrt(np.abs(x)))),
lambda x: (1/x, -1/(x*x)),
lambda x: (np.sign(np.sin(1/x)), np.zeros_like(x)),
lambda x: (x*np.sin(1/x), np.sin(1/x) - np.cos(1/x)/x),
lambda x: (np.sin(1/x), - np.cos(1/x)/(x*x)),
]
# Avoid values too close to 0 since our functions have singularities there
min_x = 0.5
for func in functions:
x_input = np.random.uniform(min_x, 10, size=(3, 4))
# We need a function returning a scalar, so sum the results
func_forw = lambda x: np.sum(func(x)[0])
grads = [func(x_input)[1]]
check_numerical_grads(func_forw, [x_input], grads)
# Check functions with multiple arguments
for f1 in functions:
for f2 in functions:
x_input = np.random.uniform(min_x, 10, size=(3, 4))
y_input = np.random.uniform(min_x, 10, size=(3, 4))
func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
grads = [f1(x_input)[1], f2(y_input)[1]]
check_numerical_grads(func_forw, [x_input, y_input], grads)
# Same thing but with keyword arguments
func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
grads = {'x': f1(x_input)[1], 'y': f2(y_input)[1]}
check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads)
def _noise1(x, atol=1e-2, rtol=0.1):
# We move in a random direction with twice the original tolerance to make sure
# this results in an error
sqrt_n = np.sqrt(float(np.prod(x.shape)))
tol = 2*(np.linalg.norm(x)*rtol + atol*sqrt_n)
noise = np.random.normal(size=x.shape)
noise = tol * noise / np.linalg.norm(noise)
return x + noise
def _noise2(x, atol=1e-2, rtol=0.1):
# This noise affects just a single component
sqrt_n = np.sqrt(float(np.prod(x.shape)))
tol = 2*(np.linalg.norm(x)*rtol + atol*sqrt_n)
n = np.random.randint(np.prod(x.shape))
noise = np.zeros_like(x)
noise.reshape(-1)[n] = tol
return x + noise
# Add noise to the gradients and check that the function raises an error
for f1 in functions:
for f2 in functions:
x_input = np.random.uniform(min_x, 10, size=(3, 4))
y_input = np.random.uniform(min_x, 10, size=(3, 4))
func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
grads = [_noise1(f1(x_input)[1]), _noise1(f2(y_input)[1])]
try:
check_numerical_grads(func_forw, [x_input, y_input], grads)
except AssertionError as e:
pass
else:
raise AssertionError("check_numerical_grads didn't raise an exception")
func_forw = lambda x, y: np.sum(f1(x)[0] + f2(y)[0])
grads = {'x': _noise2(f1(x_input)[1]), 'y': _noise2(f2(y_input)[1])}
try:
check_numerical_grads(func_forw, {'x': x_input, 'y': y_input}, grads)
except AssertionError as e:
pass
else:
raise AssertionError("check_numerical_grads didn't raise an exception")
if __name__ == "__main__":
test_check_numerical_grads()
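As an aside, the analytic derivatives in the table above can be spot-checked symbolically. This sketch assumes sympy is available, which is not a dependency of this test:

import sympy

# Symbolic spot-check of one derivative pair from the table above (illustrative only).
x = sympy.symbols('x', positive=True)
print(sympy.simplify(sympy.diff(x * sympy.sin(1 / x), x)))  # expect sin(1/x) - cos(1/x)/x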