Commit 8e0aaa29 by 雾雨魔理沙 Committed by Wuwei Lin

[Relay] [Training] Add numerical gradient check. (#3630)

* add check_grad

* finish

* what does the fox say?

* lint lint lint lint lint lint lint lint lint
parent 87e18a44
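The check added here compares a reverse-mode directional derivative against a finite-difference estimate. As a rough standalone sketch of the same idea, in plain NumPy rather than Relay, with a hypothetical helper name numerical_grad_check and a made-up example function (not part of this commit):

import numpy as np

def numerical_grad_check(f, grad_f, x, epsilon=1e-6, tol=1e-4):
    """Compare an analytic directional derivative with a finite-difference estimate."""
    d = np.random.rand(*x.shape)                          # random direction
    reverse_mode = np.sum(grad_f(x) * d)                  # <grad f(x), d>
    finite_diff = (f(x + epsilon * d) - f(x)) / epsilon   # one-sided difference
    assert abs(reverse_mode - finite_diff) < tol

# Example: f(x) = sum(x * x), whose gradient is 2 * x.
numerical_grad_check(lambda x: np.sum(x * x), lambda x: 2 * x, np.random.rand(10, 10))

The Relay implementation below builds the same comparison as Relay functions inside a module and evaluates the difference with create_executor.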
@@ -14,11 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name
"""Utilities for testing and benchmarks""" """Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm
import tvm.relay as relay
import tvm.relay.op as op
from tvm.relay import transform
from tvm.relay import Function, GlobalVar, ScopeBuilder, Tuple, TupleGetItem, create_executor
from tvm.relay import TensorType, TupleType
import numpy as np
from . import mlp
from . import resnet
@@ -36,7 +42,7 @@ from .config import ctx_list
from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
from .py_converter import to_python, run_as_python
from ..transform import gradient
def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, transform.Pass)
@@ -48,3 +54,74 @@ def run_opt_pass(expr, opt_pass):
def run_infer_type(expr):
    return run_opt_pass(expr, transform.InferType())
def rand_from_type(t):
    return relay.Constant(rand(t.dtype, *[int(d) for d in t.shape]))


CHECK_GRAD_COUNTER = 0
def check_grad(func, mod=None):
"""
Test that directional gradient calculated by reverse mode
is close to the one calculated by finite difference.
"""
    global CHECK_GRAD_COUNTER
    if mod is None:
        mod = relay.Module()

    def make(name):
        return GlobalVar(name + str(CHECK_GRAD_COUNTER))
    func_name = make("func_")
    back_func_name = make("back_func_")
    finite_difference_func_name = make("finite_difference_")
    reverse_mode_func_name = make("reverse_mode_")
    check_func_name = make("check_func_")
    CHECK_GRAD_COUNTER = CHECK_GRAD_COUNTER + 1
    epsilon = relay.const(0.01)
    mod[func_name] = func
    mod[back_func_name] = gradient(mod[func_name], mod=mod)
    params = mod[func_name].params
    directions = [rand_from_type(x.checked_type) for x in params]
    ft = TensorType(())
    sb = ScopeBuilder()

    def get_reverse_mode_result(e, d, t):
        assert isinstance(t, TensorType)
        return op.cast(e * d, 'float32')
    bf = sb.let("bf", TupleGetItem(back_func_name(*params), 1))
    reverse_mode_results = [get_reverse_mode_result(TupleGetItem(bf, i),
                                                    directions[i],
                                                    x.checked_type)
                            for i, x in enumerate(params)]
    reverse_mode_result = relay.const(0.0)
    for x in reverse_mode_results:
        reverse_mode_result = reverse_mode_result + op.reduce.sum(x)
    sb.ret(reverse_mode_result)
    reverse_mode_result = sb.get()
    mod[reverse_mode_func_name] = Function(params,
                                           reverse_mode_result,
                                           ft,
                                           mod[func_name].type_params,
                                           mod[func_name].attrs)
    finite_difference_result = op.reduce.sum((func_name(*[x + epsilon * y for x, y in
                                                          zip(params, directions)]) -
                                              func_name(*params)) /
                                             epsilon)
    mod[finite_difference_func_name] = Function(params,
                                                finite_difference_result,
                                                ft,
                                                mod[func_name].type_params,
                                                mod[func_name].attrs)
    check_func_result = op.abs(reverse_mode_func_name(*params) -
                               finite_difference_func_name(*params))
    mod[check_func_name] = Function(params,
                                    check_func_result,
                                    ft,
                                    mod[func_name].type_params,
                                    mod[func_name].attrs)
    ex = create_executor(mod=mod)
    res = ex.evaluate(check_func_name(*[rand_from_type(x.checked_type) for x in params]))
    assert res.data.asnumpy() < 0.001

def rand(dtype, *shape):
    return tvm.nd.array(np.random.rand(*shape).astype(dtype))
@@ -22,11 +22,7 @@ from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
def rand(dtype='float32', *shape):
    return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id():
@@ -61,6 +57,16 @@ def test_add():
    tvm.testing.assert_allclose(grad.asnumpy(), 2 * np.ones_like(x.asnumpy()))

def test_check_grad():
    shape = (10, 10)
    dtype = 'float32'
    t = relay.TensorType(shape, dtype)
    x = relay.var("x", t)
    y = relay.var("y", t)
    func = relay.Function([x, y], x + y)
    check_grad(func)

def test_temp_add():
    shape = (10, 10)
    dtype = 'float32'
...