Commit 1ad6a2af by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] cross_entropy_with_logits and its gradient (#4075)

* save

* lint
parent 493c98d3
......@@ -37,3 +37,4 @@ _reg.register_schedule("prod", _schedule_reduce)
# Reduction-style ops share the generic reduce schedule. Both cross-entropy
# variants lower to a sum reduction (see their compute definitions), so they
# reuse it as well.
_reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
......@@ -449,3 +449,12 @@ def cross_entropy_grad(orig, grad):
batch_size = take(shape, const(0, dtype='int32'), axis=0)
grad = grad / batch_size.astype('float32')
return [-grad * y / x, -grad * log(x)]
@register_gradient("nn.cross_entropy_with_logits")
def cross_entropy_with_logits_grad(orig, grad):
    """Gradient of nn.cross_entropy_with_logits.

    The forward computation is -sum(x * y) / batch_size, so the partials are
    d/dx = -y * grad / batch_size and d/dy = -x * grad / batch_size.
    """
    logits, labels = orig.args
    # batch_size is the leading dimension of the predictions tensor.
    batch_size = take(shape_of(logits), const(0, dtype='int32'), axis=0)
    scaled_grad = grad / batch_size.astype('float32')
    return [-scaled_grad * labels, -scaled_grad * logits]
......@@ -770,3 +770,12 @@ reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    """Compute definition of cross_entropy: -sum(log(x) * y) / batch_size."""
    predictions, targets = inputs
    batch_size = predictions.shape[0]
    return [-topi.sum(topi.log(predictions) * targets) / batch_size]
reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)


@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
    """Compute definition of cross_entropy_with_logits: -sum(x * y) / batch_size.

    Unlike nn.cross_entropy, no log is applied to the predictions.
    """
    logits, labels = inputs
    return [-topi.sum(logits * labels) / logits.shape[0]]
......@@ -1807,3 +1807,22 @@ def cross_entropy(predictions, targets):
The computed result.
"""
return _make.cross_entropy(predictions, targets)
def cross_entropy_with_logits(predictions, targets):
    """Computes cross entropy given predictions (as logits) and targets.

    Unlike :py:func:`cross_entropy`, no log is taken of the predictions.

    Parameters
    ----------
    predictions : tvm.relay.Expr
        The predictions (logits).

    targets : tvm.relay.Expr
        The targets.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.cross_entropy_with_logits(predictions, targets)
......@@ -910,7 +910,7 @@ bool CrossEntropyRel(const Array<Type>& types,
return true;
}
// Positional relay function to create batch_matmul operator used by frontend FFI.
// Positional relay function to create cross_entropy operator used by frontend FFI.
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
static const Op& op = Op::Get("nn.cross_entropy");
return CallNode::make(op, {predictions, targets}, Attrs(), {});
......@@ -933,5 +933,28 @@ Do log on the data - do not accept logits.
.add_type_rel("CrossEntropy", CrossEntropyRel);
// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
  // Look up the registered op once; Op::Get caches by name.
  static const Op& op = Op::Get("nn.cross_entropy_with_logits");
  return CallNode::make(op, {predictions, targets}, Attrs(), {});
}

// Expose the constructor to the Python frontend as
// relay.op.nn._make.cross_entropy_with_logits.
TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits")
.set_body_typed(MakeCrossEntropyWithLogits);
// Register nn.cross_entropy_with_logits. Unlike nn.cross_entropy, this
// variant accepts raw logits and applies no log to the predictions.
RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
.describe(R"code(
Computes cross entropy given predictions and targets.
Accept logits.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
// NOTE(review): the argument descriptions say "1D Tensor", but the compute
// definition divides by x.shape[0] and the tests use (2, 5) inputs, which
// suggests 2D (batch, category) tensors — confirm against CrossEntropyRel.
.add_argument("x", "1D Tensor", "Predictions.")
.add_argument("y", "1D Tensor", "Targets.")
.set_support_level(10)
// Shares the CrossEntropy type relation with nn.cross_entropy.
.add_type_rel("CrossEntropy", CrossEntropyRel);
} // namespace relay
} // namespace tvm
......@@ -14,15 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from tvm import relay
from tvm.relay.testing import check_grad
def test_cross_entropy_grad():
    """Numerically check the registered gradient of nn.cross_entropy."""
    # Fix: drop the dead duplicate x/y definitions with shape (1, 5) that were
    # immediately overwritten — diff residue from bumping the batch size to 2.
    # Batch of 2 exercises the 1/batch_size scaling in the gradient.
    x = relay.var("x", shape=(2, 5))
    y = relay.var("y", shape=(2, 5))
    # mean=1 keeps sampled predictions positive so log(x) stays well-defined.
    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
def test_cross_entropy_with_logits_grad():
    """Numerically check the registered gradient of nn.cross_entropy_with_logits."""
    logits = relay.var("x", shape=(2, 5))
    labels = relay.var("y", shape=(2, 5))
    fn = relay.Function([logits, labels],
                        relay.op.nn.cross_entropy_with_logits(logits, labels))
    check_grad(fn, eps=0.01, scale=0.1, mean=1)
if __name__ == "__main__":
    # Fix: remove the direct test_cross_entropy_grad() call left over from the
    # diff — pytest.main discovers and runs all tests in this file, so the
    # direct call executed that test twice.
    pytest.main([__file__])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment