Commit 7d71dd8b by 雾雨魔理沙, committed by Wuwei Lin

[Relay][Training] Add gradient for Crossentropy (#3925)

* save

save

redo max test

save

address comment

fix

* address comment

* increase rtol

* address review comment
parent 59d8d400
@@ -36,3 +36,4 @@ _reg.register_schedule("min", _schedule_reduce)
 _reg.register_schedule("prod", _schedule_reduce)
 _reg.register_schedule("mean", _schedule_reduce)
 _reg.register_schedule("variance", _schedule_reduce)
+_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
@@ -25,7 +25,18 @@ from ..expr import Tuple, TupleGetItem, const
 from . import nn as _nn
 from .op import register_gradient
 from .reduce import sum as _sum
-from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
+from .tensor import (
+    cos,
+    exp,
+    less,
+    negative,
+    ones_like,
+    power,
+    sin,
+    zeros_like,
+    equal,
+    shape_of,
+    log)
 from .transform import (
     broadcast_to_like,
     collapse_sum_like,
@@ -33,6 +44,7 @@ from .transform import (
     reshape,
     reshape_like,
     strided_slice,
+    take,
     tile,
     transpose,
     where,
@@ -353,3 +365,12 @@ def sum_grad(orig, grad):
     """Returns grad broadcasted to data dims"""
     data = orig.args[0]
     return [broadcast_to_like(grad, data)]
+
+
+@register_gradient("nn.cross_entropy")
+def cross_entropy_grad(orig, grad):
+    x, y = orig.args
+    shape = shape_of(x)
+    batch_size = take(shape, const(0, dtype='int32'), axis=0)
+    grad = grad / batch_size.astype('float32')
+    return [-grad * y / x, -grad * log(x)]
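The two expressions returned here are the analytic gradients of -sum(y * log(x)) / batch_size with respect to x and y. A minimal NumPy sketch (not part of the commit; all names are illustrative) that verifies them against the same kind of two-sided finite-difference approximation check_grad uses:

import numpy as np

def cross_entropy(x, y):
    # Same formula as compute_cross_entropy below: -sum(y * log(x)) / batch.
    return -np.sum(y * np.log(x)) / x.shape[0]

rng = np.random.RandomState(0)
x = rng.uniform(0.5, 1.5, size=(1, 5))  # keep x well away from 0 for log
y = rng.uniform(0.5, 1.5, size=(1, 5))

# Analytic gradients from cross_entropy_grad, with upstream grad = 1.
dx = -y / (x * x.shape[0])
dy = -np.log(x) / x.shape[0]

# Two-sided numerical approximation of d cross_entropy / dx.
eps = 1e-6
num_dx = np.zeros_like(x)
for i in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[i] += eps
    xm[i] -= eps
    num_dx[i] = (cross_entropy(xp, y) - cross_entropy(xm, y)) / (2 * eps)

assert np.allclose(dx, num_dx, rtol=1e-3)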
@@ -745,3 +745,12 @@ def schedule_bitserial_dense(attrs, outputs, target):
 reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
+
+
+@reg.register_compute("nn.cross_entropy")
+def compute_cross_entropy(attrs, inputs, out_dtype, target):
+    x, y = inputs
+    return [-topi.sum(topi.log(x) * y) / x.shape[0]]
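As a quick sanity check of the reduction above (a NumPy sketch, not from the commit): with uniform predictions and targets over five classes and batch size 1, the op reduces to the entropy of the uniform distribution, log(5) ≈ 1.609.

import numpy as np

x = np.full((1, 5), 0.2)  # uniform predictions over 5 classes
y = np.full((1, 5), 0.2)  # uniform targets
# -sum(y * log(x)) / batch = -log(0.2) = log(5) ~= 1.609
print(-np.sum(y * np.log(x)) / x.shape[0])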
@@ -1758,3 +1758,22 @@ def bitserial_dense(data,
     """
     return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
                                  pack_dtype, out_dtype, unipolar)
+
+
+def cross_entropy(predictions, targets):
+    """CrossEntropy without logits.
+
+    Parameters
+    ----------
+    predictions : tvm.relay.Expr
+      The predictions.
+
+    targets : tvm.relay.Expr
+      The targets.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+      The computed result.
+    """
+    return _make.cross_entropy(predictions, targets)
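A usage sketch (not part of the commit) evaluating the new op with the Relay interpreter on CPU; the create_executor/evaluate/asnumpy calls mirror the ones used elsewhere in this diff:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 5))
y = relay.var("y", shape=(1, 5))
func = relay.Function([x, y], relay.op.nn.cross_entropy(x, y))

x_data = np.array([[0.1, 0.2, 0.3, 0.2, 0.2]], dtype="float32")
y_data = np.array([[0.0, 0.0, 1.0, 0.0, 0.0]], dtype="float32")

intrp = relay.create_executor(ctx=tvm.cpu(0), target="llvm")
# Expected: -sum(y * log(x)) / batch = -log(0.3) ~= 1.204
print(intrp.evaluate(func)(x_data, y_data).asnumpy())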
@@ -56,11 +56,11 @@ def run_infer_type(expr):
     return run_opt_pass(expr, transform.InferType())


-def _np_randn_from_type(t, scale=1):
-    return (scale * np.random.randn(*(int(d) for d in t.shape))).astype(t.dtype)
+def _np_randn_from_type(t, scale=1, mean=0):
+    return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)


-def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3):
+def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
     """Perform numerical gradient checking given a relay function.

     Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
@@ -86,15 +86,23 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3):
     The relative tolerance on difference between numerical and analytical gradients. Note that
     this needs to be scaled appropriately relative to the chosen eps.

+    scale: float
+        The standard deviation of the inputs.
+
+    mean: float
+        The mean of the inputs.
     """
     fwd_func = run_infer_type(func)
     bwd_func = run_infer_type(gradient(fwd_func))

+    if scale is None:
+        scale = 10 * eps
+
     if inputs is None:
         params = fwd_func.params
         # Generate random inputs on the same scale as epsilon to avoid numerical precision loss.
-        inputs = [_np_randn_from_type(x.checked_type, scale=(10 * eps)) for x in params]
+        inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params]

     for target, ctx in ctx_list():
         intrp = relay.create_executor(ctx=ctx, target=target)
...
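The new scale and mean knobs matter for ops like cross entropy, whose log requires strictly positive inputs: the default inputs centered at zero would land on the log's singularity, so the new test below draws inputs from roughly N(1, 0.1²) via mean=1, scale=0.1.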
@@ -817,5 +817,54 @@ are data in batch.
 .add_type_rel("BatchMatmul", BatchMatmulRel);
+
+// relay.nn.cross_entropy
+bool CrossEntropyRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* x = types[0].as<TensorTypeNode>();
+  const auto* y = types[1].as<TensorTypeNode>();
+  if (x == nullptr || y == nullptr) return false;
+  CHECK(x->shape.size() == 2 && y->shape.size() == 2)
+    << "CrossEntropy: x and y must be 2-D tensors, "
+    << "x shape = " << x->shape << ", "
+    << "y shape = " << y->shape;
+  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
+    << "CrossEntropy: shapes of x and y are inconsistent, "
+    << "x shape = " << x->shape << ", "
+    << "y shape = " << y->shape;
+  CHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
+    << "CrossEntropy: shapes of x and y are inconsistent, "
+    << "x shape = " << x->shape << ", "
+    << "y shape = " << y->shape;
+  // assign output type: cross entropy reduces to a scalar
+  reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
+  return true;
+}
+
+// Positional relay function to create cross_entropy operator used by frontend FFI.
+Expr MakeCrossEntropy(Expr predictions, Expr targets) {
+  static const Op& op = Op::Get("nn.cross_entropy");
+  return CallNode::make(op, {predictions, targets}, Attrs(), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.cross_entropy")
+.set_body_typed(MakeCrossEntropy);
+
+RELAY_REGISTER_OP("nn.cross_entropy")
+.describe(R"code(
+Computes cross entropy given predictions and targets;
+the log is taken on the predictions inside the op, so it does not accept logits.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("x", "2D Tensor", "Predictions.")
+.add_argument("y", "2D Tensor", "Targets.")
+.set_support_level(10)
+.add_type_rel("CrossEntropy", CrossEntropyRel);
+
 }  // namespace relay
 }  // namespace tvm
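CrossEntropyRel assigns a zero-rank (scalar) output type. A small sketch confirming the inferred return type, assuming the 0.6-era relay.Module API:

from tvm import relay

x = relay.var("x", shape=(2, 5))
y = relay.var("y", shape=(2, 5))
mod = relay.Module.from_expr(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # Tensor[(), float32] -- a scalar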
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing import check_grad


def test_cross_entropy_grad():
    x = relay.var("x", shape=(1, 5))
    y = relay.var("y", shape=(1, 5))
    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)),
               eps=0.01, scale=0.1, mean=1)


if __name__ == "__main__":
    test_cross_entropy_grad()
@@ -32,14 +32,14 @@ def test_sum_grad():

 def test_max_grad():
-    s = (5, 10)
+    s = (10, 10)
     t = relay.TensorType(s)
     x = relay.var("x", t)
     axis = 0
     z = relay.max(x, axis)
     fwd_func = relay.Function([x], z)
-    check_grad(fwd_func, eps=1e-7, rtol=1)
+    check_grad(fwd_func, scale=1e-3)


 if __name__ == "__main__":
...
@@ -67,7 +67,7 @@ def test_resize():
     for kind in ["graph", "debug"]:
         intrp = relay.create_executor(kind, ctx=ctx, target=target)
         op_res = intrp.evaluate(func)(x_data)
-        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
     for method in ["bilinear", "nearest_neighbor"]:
         for layout in ["NHWC", "NCHW"]:
             verify_resize((1, 4, 4, 4), 2, method, layout)
...