Commit 6a377f77 by 雾雨魔理沙 Committed by Wuwei Lin

[Relay][Training] Add gradient for cast (#3894)

save

fix

fix grad
parent 184fa484
......@@ -29,6 +29,7 @@ from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
from .transform import (
broadcast_to_like,
collapse_sum_like,
cast_like,
reshape,
reshape_like,
strided_slice,
......@@ -296,6 +297,12 @@ def reshape_grad(orig, grad):
return [reshape_like(grad, orig.args[0])]
@register_gradient("cast")
def cast_grad(orig, grad):
    """Returns [grad cast back to the dtype of the forward input].

    The forward op casts x to a new dtype, so the gradient w.r.t. x is the
    incoming grad cast to x's original dtype (via cast_like on orig.args[0]).
    """
    x = orig.args[0]
    return [cast_like(grad, x)]
@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
"""Returns grad reshaped to data dims"""
......
......@@ -43,6 +43,7 @@ _reg.register_schedule("reverse", schedule_injective)
_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("cast_like", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
......
......@@ -40,6 +40,23 @@ def cast(data, dtype):
return _relay_make.cast(data, dtype)
def cast_like(data, dtype_like):
    """Cast input tensor to data type of another tensor.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    dtype_like : relay.Expr
        The tensor whose data type the result should have. Only its dtype
        is used; its values are never read.

    Returns
    -------
    result : relay.Expr
        The casted result.
    """
    # NOTE(review): function-local import — presumably avoids a circular
    # import at module load time; confirm against the package layout.
    from .. import _make as _relay_make
    return _relay_make.cast_like(data, dtype_like)
def reinterpret(data, dtype):
"""Reinterpret input tensor to data type.
......
......@@ -98,6 +98,63 @@ RELAY_REGISTER_OP("cast")
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
// relay.cast_like
// Type relation for cast_like: out = TensorType(data->shape, dtype_like->dtype).
// types = {data, dtype_like, result}; returns false (retry later) while either
// input type is still incomplete during inference.
bool CastLikeRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    // Fixed diagnostics: previously said "cast:" (copy-paste from CastRel),
    // which pointed users at the wrong operator.
    CHECK(types[0].as<IncompleteTypeNode>())
        << "cast_like: expect input type to be TensorType but get "
        << types[0];
    return false;
  }
  const auto* dtype_like = types[1].as<TensorTypeNode>();
  if (dtype_like == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "cast_like: expect dtype_like type to be TensorType but get "
        << types[1];
    return false;
  }
  // Result keeps data's shape and adopts dtype_like's dtype.
  reporter->Assign(types[2], TensorTypeNode::make(data->shape, dtype_like->dtype));
  return true;
}
// FTVMCompute for cast_like: lower to topi::cast, casting the data tensor
// (inputs[0]) to the dtype of the like tensor (inputs[1]). Only the like
// tensor's dtype is consulted; its values are never read.
Array<Tensor> CastLikeCompute(const Attrs& attrs,
                              const Array<Tensor>& inputs,
                              const Type& out_type,
                              const Target& target) {
  const auto& data = inputs[0];
  const auto& like = inputs[1];
  return Array<Tensor>{topi::cast(data, like->dtype)};
}
// Builds a relay call node for cast_like(data, dtype_like).
Expr MakeCastLike(Expr data,
                  Expr dtype_like) {
  // static: look up the operator handle once and reuse it across calls.
  static const Op& op = Op::Get("cast_like");
  // cast_like carries no attributes; dtype comes from the second argument.
  return CallNode::make(op, {data, dtype_like}, Attrs(), {});
}

// Expose the constructor to Python as relay._make.cast_like.
TVM_REGISTER_API("relay._make.cast_like")
.set_body_typed(MakeCastLike);
// Register the cast_like operator: two tensor inputs, elementwise pattern,
// layout passes through unchanged (same attrs as "cast" above).
RELAY_REGISTER_OP("cast_like")
.describe(R"code(Cast the data into the type of another tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("dtype_like", "Tensor", "The tensor to cast to.")
.set_support_level(3)
.add_type_rel("CastLike", CastLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", CastLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_type, const Target& target) {
const CastAttrs* param = attrs.as<CastAttrs>();
......
......@@ -58,5 +58,10 @@ def test_negative_grad():
check_grad(fwd_func)
def test_cast_grad():
    # Forward pass casts float32 -> float64; check_grad verifies the
    # registered gradient (cast_like back to the input's dtype) numerically.
    data = relay.var("data", relay.TensorType((10, 4), "float32"))
    fwd_func = relay.Function([data], relay.cast(data, "float64"))
    check_grad(fwd_func)
# Run the whole test file under pytest when executed directly.
if __name__ == "__main__":
    pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment