Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
52fde8f7
Commit
52fde8f7
authored
Aug 09, 2019
by
雾雨魔理沙
Committed by
Thierry Moreau
Aug 09, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] [Training] Fix ad for concatenate (#3729)
* reproduce error * fix * lint * lint
parent
45827220
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
108 additions
and
13 deletions
+108
-13
python/tvm/relay/op/_tensor_grad.py
+12
-1
src/relay/ir/alpha_equal.cc
+3
-0
src/relay/pass/gradient.cc
+77
-11
tests/python/relay/test_pass_gradient.py
+16
-1
No files found.
python/tvm/relay/op/_tensor_grad.py
View file @
52fde8f7
...
...
@@ -17,7 +17,7 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from
__future__
import
absolute_import
from
..expr
import
const
from
..expr
import
const
,
Tuple
,
TupleGetItem
from
.op
import
register_gradient
from
.transform
import
collapse_sum_like
,
broadcast_to_like
,
where
from
.tensor
import
exp
,
negative
,
power
,
less
,
cos
,
sin
...
...
@@ -176,3 +176,14 @@ def avg_pool2d_grad(orig, grad):
layout
=
attrs
.
layout
,
ceil_mode
=
attrs
.
ceil_mode
,
count_include_pad
=
attrs
.
count_include_pad
)
return
[
pool_grad
]
# not implemented, this is only for testing.
@register_gradient
(
"concatenate"
)
def
concatenate_grad
(
orig
,
grad
):
assert
len
(
orig
.
args
)
==
1
t
=
orig
.
args
[
0
]
x
=
TupleGetItem
(
t
,
0
)
y
=
TupleGetItem
(
t
,
1
)
# Assume only two element in tuple rn.
# In the real implementation, concatenate_grad probably need to be implemented by an operator.
return
[
Tuple
([
zeros_like
(
x
),
zeros_like
(
y
)])]
src/relay/ir/alpha_equal.cc
View file @
52fde8f7
...
...
@@ -117,9 +117,12 @@ class AlphaEqualHandler:
* \return the comparison result.
*/
bool
TypeEqual
(
const
Type
&
lhs
,
const
Type
&
rhs
)
{
auto
compute
=
[
&
](){
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
return
false
;
return
this
->
VisitType
(
lhs
,
rhs
);
};
return
Compare
(
compute
(),
lhs
,
rhs
);
}
bool
Compare
(
bool
result
,
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
...
...
src/relay/pass/gradient.cc
View file @
52fde8f7
...
...
@@ -29,6 +29,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "let_list.h"
#include "../ir/type_functor.h"
...
...
@@ -257,11 +258,79 @@ struct ReverseADType : TypeMutator {
}
};
Type
ReverseType
(
const
Type
&
t
)
{
return
ReverseADType
()(
t
);
}
/*! \brief Lift a function that transform Tensor to a function that also transform more type
* by doing a structure preserving map.
*/
Expr
LiftTensor
(
const
std
::
function
<
Expr
(
const
Expr
&
t
)
>&
f
,
const
Type
&
t
,
const
Expr
&
e
,
LetList
*
ll
)
{
CHECK
(
IsAtomic
(
e
))
<<
e
;
if
(
t
.
as
<
TensorTypeNode
>
())
{
return
f
(
e
);
}
else
if
(
auto
*
tt
=
t
.
as
<
TupleTypeNode
>
())
{
tvm
::
Array
<
Expr
>
fields
;
for
(
size_t
i
=
0
;
i
<
tt
->
fields
.
size
();
++
i
)
{
fields
.
push_back
(
LiftTensor
(
f
,
tt
->
fields
[
i
],
ll
->
Push
(
GetField
(
e
,
i
)),
ll
));
}
return
TupleNode
::
make
(
fields
);
}
else
{
LOG
(
FATAL
)
<<
"unsupported input/output type: "
<<
tt
;
throw
;
}
}
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr
GetRev
(
const
Type
&
t
,
const
Expr
&
e
,
LetList
*
ll
)
{
auto
rev
=
[
&
](
const
Expr
&
e
)
{
return
Pair
(
e
,
ll
->
Push
(
RefCreateNode
::
make
(
ZerosLike
(
e
))));
};
return
LiftTensor
(
rev
,
t
,
e
,
ll
);
}
/*! \brief ReverseType(t) -> t. Get the original value. */
Expr
GetValue
(
const
Type
&
t
,
const
Expr
&
e
,
LetList
*
ll
)
{
return
LiftTensor
([
&
](
const
Expr
&
e
)
{
return
GetField
(
e
,
0
);
},
t
,
e
,
ll
);
}
/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr
GetGrad
(
const
Type
&
t
,
const
Expr
&
e
,
LetList
*
ll
)
{
auto
grad
=
[
&
](
const
Expr
&
e
)
{
return
ll
->
Push
(
RefReadNode
::
make
(
GetField
(
e
,
1
)));
};
return
LiftTensor
(
grad
,
t
,
e
,
ll
);
}
void
UpdateGrad
(
const
Type
&
t
,
const
Expr
&
arg
,
const
Expr
&
grad
,
LetList
*
ll
)
{
if
(
t
.
as
<
TensorTypeNode
>
())
{
ll
->
Push
(
RefWriteNode
::
make
(
GetField
(
arg
,
1
),
Add
(
ll
->
Push
(
RefReadNode
::
make
(
GetField
(
arg
,
1
))),
grad
)));
}
else
if
(
auto
*
tt
=
t
.
as
<
TupleTypeNode
>
())
{
for
(
size_t
i
=
0
;
i
<
tt
->
fields
.
size
();
++
i
)
{
UpdateGrad
(
tt
->
fields
[
i
],
ll
->
Push
(
GetField
(
arg
,
i
)),
ll
->
Push
(
GetField
(
grad
,
i
)),
ll
);
}
}
else
{
LOG
(
FATAL
)
<<
"unsupported arg type of operator: "
<<
t
;
throw
;
}
}
struct
ReverseAD
:
ExprMutator
{
Var
bp
;
const
OpMap
<
FPrimalGradient
>
rev_map
=
Op
::
GetAttr
<
FPrimalGradient
>
(
"FPrimalGradient"
);
ReverseAD
(
const
Var
&
bp
)
:
bp
(
bp
)
{
}
/// NOLINT(*)
explicit
ReverseAD
(
const
Var
&
bp
)
:
bp
(
bp
)
{
}
Expr
VisitExpr_
(
const
OpNode
*
op
)
final
{
LOG
(
FATAL
)
<<
"op should only be inside call"
;
...
...
@@ -279,29 +348,26 @@ struct ReverseAD : ExprMutator {
args
.
push_back
(
ll
->
Push
(
VisitExpr
(
arg
)));
}
std
::
vector
<
Expr
>
orig_args
;
for
(
const
auto
&
arg
:
args
)
{
orig_args
.
push_back
(
Get
Field
(
arg
,
0
));
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
orig_args
.
push_back
(
Get
Value
(
op
->
args
[
i
]
->
checked_type
(),
args
[
i
],
ll
));
}
Expr
orig
=
CallNode
::
make
(
op
->
op
,
orig_args
,
op
->
attrs
,
op
->
type_args
);
Var
orig_var
=
ll
->
Push
(
orig
);
auto
ref
=
ll
->
Push
(
RefCreateNode
::
make
(
ZerosLike
(
orig_var
)));
auto
ret
=
ll
->
Push
(
GetRev
(
op
->
checked_type
(),
ll
->
Push
(
orig
),
ll
));
auto
bpv
=
ll
->
Push
(
RefReadNode
::
make
(
bp
));
Expr
nbp
=
FunctionNode
::
make
(
{},
LetList
::
With
([
&
](
LetList
*
ll
)
{
tvm
::
Array
<
Expr
>
rev
=
rev_map
[
op_ref
](
orig
,
ll
->
Push
(
RefReadNode
::
make
(
ref
)
));
tvm
::
Array
<
Expr
>
rev
=
rev_map
[
op_ref
](
orig
,
GetGrad
(
op
->
checked_type
(),
ret
,
ll
));
CHECK
(
args
.
size
()
==
rev
.
size
());
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
ll
->
Push
(
RefWriteNode
::
make
(
GetField
(
args
[
i
],
1
),
Add
(
ll
->
Push
(
RefReadNode
::
make
(
GetField
(
args
[
i
],
1
))),
rev
[
i
])));
UpdateGrad
(
op
->
args
[
i
]
->
checked_type
(),
args
[
i
],
rev
[
i
],
ll
);
}
return
CallNode
::
make
(
bpv
,
{});
}),
TupleTypeNode
::
make
({}),
{});
ll
->
Push
(
RefWriteNode
::
make
(
bp
,
nbp
));
return
Pair
(
orig_var
,
ref
)
;
return
ret
;
});
}
return
ExprMutator
::
VisitExpr_
(
op
);
...
...
@@ -319,7 +385,7 @@ struct ReverseAD : ExprMutator {
}
Type
VisitType
(
const
Type
&
t
)
final
{
return
t
.
defined
()
?
Reverse
ADType
()
(
t
)
:
t
;
return
t
.
defined
()
?
Reverse
Type
(
t
)
:
t
;
}
};
...
...
tests/python/relay/test_pass_gradient.py
View file @
52fde8f7
...
...
@@ -18,11 +18,12 @@ import numpy as np
import
tvm
from
tvm
import
relay
from
tvm.relay.analysis
import
free_vars
,
free_type_vars
from
tvm.relay.analysis
import
free_vars
,
free_type_vars
,
assert_alpha_equal
from
tvm.relay
import
create_executor
,
transform
from
tvm.relay.transform
import
gradient
from
tvm.relay.prelude
import
Prelude
from
tvm.relay.testing
import
add_nat_definitions
,
make_nat_expr
,
run_infer_type
,
check_grad
,
rand
import
tvm.relay.op
as
op
def
test_id
():
...
...
@@ -280,6 +281,20 @@ def test_grad_tuple():
tvm
.
testing
.
assert_allclose
(
grad
.
asnumpy
(),
4
*
np
.
ones_like
(
x
.
asnumpy
()))
def
test_concat
():
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
rt
=
relay
.
TensorType
((
10
,
20
),
dtype
)
x
=
relay
.
var
(
"x"
,
t
)
y
=
op
.
concatenate
([
x
,
x
],
axis
=
1
)
func
=
relay
.
Function
([
x
],
y
)
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
))
assert_alpha_equal
(
back_func
.
checked_type
,
relay
.
FuncType
([
t
],
relay
.
TupleType
([
rt
,
relay
.
TupleType
([
t
])])))
# no value validation as concatenate has dummy gradient right now.
if
__name__
==
"__main__"
:
test_id
()
test_add
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment