Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6a377f77
Commit
6a377f77
authored
Sep 07, 2019
by
雾雨魔理沙
Committed by
Wuwei Lin
Sep 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Training] Add gradient for cast (#3894)
save fix fix grad
parent
184fa484
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
87 additions
and
0 deletions
+87
-0
python/tvm/relay/op/_tensor_grad.py
+7
-0
python/tvm/relay/op/_transform.py
+1
-0
python/tvm/relay/op/transform.py
+17
-0
src/relay/op/tensor/transform.cc
+57
-0
tests/python/relay/test_op_grad_level3.py
+5
-0
No files found.
python/tvm/relay/op/_tensor_grad.py
View file @
6a377f77
...
@@ -29,6 +29,7 @@ from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
...
@@ -29,6 +29,7 @@ from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
from
.transform
import
(
from
.transform
import
(
broadcast_to_like
,
broadcast_to_like
,
collapse_sum_like
,
collapse_sum_like
,
cast_like
,
reshape
,
reshape
,
reshape_like
,
reshape_like
,
strided_slice
,
strided_slice
,
...
@@ -296,6 +297,12 @@ def reshape_grad(orig, grad):
...
@@ -296,6 +297,12 @@ def reshape_grad(orig, grad):
return
[
reshape_like
(
grad
,
orig
.
args
[
0
])]
return
[
reshape_like
(
grad
,
orig
.
args
[
0
])]
@register_gradient
(
"cast"
)
def
cast_grad
(
orig
,
grad
):
x
=
orig
.
args
[
0
]
return
[
cast_like
(
grad
,
x
)]
@register_gradient
(
"nn.batch_flatten"
)
@register_gradient
(
"nn.batch_flatten"
)
def
batch_flatten_grad
(
orig
,
grad
):
def
batch_flatten_grad
(
orig
,
grad
):
"""Returns grad reshaped to data dims"""
"""Returns grad reshaped to data dims"""
...
...
python/tvm/relay/op/_transform.py
View file @
6a377f77
...
@@ -43,6 +43,7 @@ _reg.register_schedule("reverse", schedule_injective)
...
@@ -43,6 +43,7 @@ _reg.register_schedule("reverse", schedule_injective)
_reg
.
register_schedule
(
"repeat"
,
schedule_broadcast
)
_reg
.
register_schedule
(
"repeat"
,
schedule_broadcast
)
_reg
.
register_schedule
(
"tile"
,
schedule_broadcast
)
_reg
.
register_schedule
(
"tile"
,
schedule_broadcast
)
_reg
.
register_schedule
(
"cast"
,
schedule_injective
)
_reg
.
register_schedule
(
"cast"
,
schedule_injective
)
_reg
.
register_schedule
(
"cast_like"
,
schedule_injective
)
_reg
.
register_schedule
(
"reinterpret"
,
schedule_injective
)
_reg
.
register_schedule
(
"reinterpret"
,
schedule_injective
)
_reg
.
register_schedule
(
"strided_slice"
,
schedule_injective
)
_reg
.
register_schedule
(
"strided_slice"
,
schedule_injective
)
_reg
.
register_schedule
(
"slice_like"
,
schedule_injective
)
_reg
.
register_schedule
(
"slice_like"
,
schedule_injective
)
...
...
python/tvm/relay/op/transform.py
View file @
6a377f77
...
@@ -40,6 +40,23 @@ def cast(data, dtype):
...
@@ -40,6 +40,23 @@ def cast(data, dtype):
return
_relay_make
.
cast
(
data
,
dtype
)
return
_relay_make
.
cast
(
data
,
dtype
)
def
cast_like
(
data
,
dtype_like
):
"""Cast input tensor to data type of another tensor.
Parameters
----------
data : relay.Expr
The input data to the operator.
dtype_like: relay.Expr
The tensor to cast to.
Returns
-------
result : relay.Expr
The casted result.
"""
from
..
import
_make
as
_relay_make
return
_relay_make
.
cast_like
(
data
,
dtype_like
)
def
reinterpret
(
data
,
dtype
):
def
reinterpret
(
data
,
dtype
):
"""Reinterpret input tensor to data type.
"""Reinterpret input tensor to data type.
...
...
src/relay/op/tensor/transform.cc
View file @
6a377f77
...
@@ -98,6 +98,63 @@ RELAY_REGISTER_OP("cast")
...
@@ -98,6 +98,63 @@ RELAY_REGISTER_OP("cast")
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kElemWise
)
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kElemWise
)
.
set_attr
<
FInferCorrectLayout
>
(
"FInferCorrectLayout"
,
ElemwiseArbitraryLayout
);
.
set_attr
<
FInferCorrectLayout
>
(
"FInferCorrectLayout"
,
ElemwiseArbitraryLayout
);
// relay.cast_like
bool
CastLikeRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
3
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
if
(
data
==
nullptr
)
{
CHECK
(
types
[
0
].
as
<
IncompleteTypeNode
>
())
<<
"cast: expect input type to be TensorType but get "
<<
types
[
0
];
return
false
;
}
const
auto
*
dtype_like
=
types
[
1
].
as
<
TensorTypeNode
>
();
if
(
dtype_like
==
nullptr
)
{
CHECK
(
types
[
1
].
as
<
IncompleteTypeNode
>
())
<<
"cast: expect input type to be TensorType but get "
<<
types
[
1
];
return
false
;
}
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
data
->
shape
,
dtype_like
->
dtype
));
return
true
;
}
Array
<
Tensor
>
CastLikeCompute
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Type
&
out_type
,
const
Target
&
target
)
{
return
{
topi
::
cast
(
inputs
[
0
],
inputs
[
1
]
->
dtype
)
};
}
Expr
MakeCastLike
(
Expr
data
,
Expr
dtype_like
)
{
static
const
Op
&
op
=
Op
::
Get
(
"cast_like"
);
return
CallNode
::
make
(
op
,
{
data
,
dtype_like
},
Attrs
(),
{});
}
TVM_REGISTER_API
(
"relay._make.cast_like"
)
.
set_body_typed
(
MakeCastLike
);
RELAY_REGISTER_OP
(
"cast_like"
)
.
describe
(
R"code(Cast the data into the type of another tensor.
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"dtype_like"
,
"Tensor"
,
"The tensor to cast to."
)
.
set_support_level
(
3
)
.
add_type_rel
(
"CastLike"
,
CastLikeRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
CastLikeCompute
)
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kElemWise
)
.
set_attr
<
FInferCorrectLayout
>
(
"FInferCorrectLayout"
,
ElemwiseArbitraryLayout
);
Array
<
Tensor
>
ReinterpretCompute
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
Array
<
Tensor
>
ReinterpretCompute
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Type
&
out_type
,
const
Target
&
target
)
{
const
Type
&
out_type
,
const
Target
&
target
)
{
const
CastAttrs
*
param
=
attrs
.
as
<
CastAttrs
>
();
const
CastAttrs
*
param
=
attrs
.
as
<
CastAttrs
>
();
...
...
tests/python/relay/test_op_grad_level3.py
View file @
6a377f77
...
@@ -58,5 +58,10 @@ def test_negative_grad():
...
@@ -58,5 +58,10 @@ def test_negative_grad():
check_grad
(
fwd_func
)
check_grad
(
fwd_func
)
def
test_cast_grad
():
data
=
relay
.
var
(
"data"
,
relay
.
TensorType
((
10
,
4
),
"float32"
))
fwd_func
=
relay
.
Function
([
data
],
relay
.
cast
(
data
,
"float64"
))
check_grad
(
fwd_func
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
()
pytest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment