Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ac070f83
Commit
ac070f83
authored
Aug 27, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Add gradient pass (#28)
parent
803db5d1
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
357 additions
and
20 deletions
+357
-20
nnvm/example/src/operator.cc
+69
-8
nnvm/include/dmlc/base.h
+13
-0
nnvm/include/dmlc/json.h
+6
-3
nnvm/include/dmlc/parameter.h
+2
-1
nnvm/include/dmlc/registry.h
+3
-2
nnvm/include/nnvm/c_api.h
+17
-0
nnvm/include/nnvm/op.h
+2
-4
nnvm/include/nnvm/op_attr_types.h
+14
-0
nnvm/include/nnvm/pass_functions.h
+31
-0
nnvm/python/nnvm/graph.py
+21
-2
nnvm/src/c_api/c_api_graph.cc
+11
-0
nnvm/src/pass/gradient.cc
+144
-0
nnvm/tests/python/test_gradient.py
+24
-0
No files found.
nnvm/example/src/operator.cc
View file @
ac070f83
...
...
@@ -15,6 +15,10 @@ using nnvm::FMutateInputs;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferType
;
using
nnvm
::
FInplaceOption
;
using
nnvm
::
Node
;
using
nnvm
::
NodePtr
;
using
nnvm
::
NodeEntry
;
using
nnvm
::
FGradient
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
TShape
;
using
nnvm
::
array_view
;
...
...
@@ -37,6 +41,17 @@ inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs)
return
{{
0
,
0
}};
}
// Convenience factory: build a single-output node running operator `op_name`
// over `inputs`, and return the entry referring to its first output.
inline NodeEntry MakeNode(const char* op_name,
                          std::string node_name,
                          std::vector<NodeEntry> inputs) {
  NodePtr node = Node::Create();
  node->attrs.name = std::move(node_name);
  node->inputs = std::move(inputs);
  node->op = nnvm::Op::Get(op_name);
  // {node, output index 0, version 0}
  return NodeEntry{node, 0, 0};
}
// simple demonstration of reshape.
NNVM_REGISTER_OP
(
reshape
)
.
describe
(
"reshape source to target shape"
)
...
...
@@ -84,21 +99,67 @@ NNVM_REGISTER_OP(cast)
return
true
;
});
NNVM_REGISTER_OP
(
exp
)
.
describe
(
"take exponential"
)
.
set_num_inputs
(
1
)
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"mul"
,
n
->
attrs
.
name
+
"_grad"
,
{
ograds
[
0
],
NodeEntry
{
n
,
0
,
0
}})
};
});
NNVM_REGISTER_OP
(
identity
)
.
describe
(
"identity function"
)
.
set_num_inputs
(
1
)
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
return
std
::
vector
<
NodeEntry
>
{
ograds
[
0
]};
});
NNVM_REGISTER_OP
(
add
)
.
describe
(
"add two data together"
)
.
set_num_inputs
(
2
)
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
);
.
attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
){
return
std
::
vector
<
NodeEntry
>
{
ograds
[
0
],
ograds
[
0
]};
});
NNVM_REGISTER_OP
(
__add_symbol__
)
.
describe
(
"Alias of add"
)
.
set_num_inputs
(
2
);
NNVM_REGISTER_OP
(
mul
)
.
describe
(
"multiply two data together"
)
.
set_num_inputs
(
2
)
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
){
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"mul"
,
n
->
attrs
.
name
+
"_grad_0"
,
{
ograds
[
0
],
n
->
inputs
[
1
]}),
MakeNode
(
"mul"
,
n
->
attrs
.
name
+
"_grad_1"
,
{
ograds
[
0
],
n
->
inputs
[
0
]})
};
});
NNVM_REGISTER_OP
(
exp
)
.
describe
(
"take exponential"
)
.
set_num_inputs
(
1
)
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
);
NNVM_REGISTER_OP
(
__ewise_sum__
)
.
describe
(
"elementwise sum"
)
.
set_num_inputs
(
nnvm
::
kVarg
);
NNVM_REGISTER_OP
(
__zero__
)
.
describe
(
"set output to zero"
)
.
set_num_inputs
(
0
);
NNVM_REGISTER_OP
(
__one__
)
.
describe
(
"set output to one"
)
.
set_num_inputs
(
0
);
NNVM_REGISTER_OP
(
cross_device_copy
)
.
describe
(
"Copy data across device."
)
...
...
nnvm/include/dmlc/base.h
View file @
ac070f83
...
...
@@ -58,6 +58,11 @@
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
...
...
@@ -69,6 +74,7 @@
#endif
#endif
/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
...
...
@@ -82,6 +88,13 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
...
...
nnvm/include/dmlc/json.h
View file @
ac070f83
...
...
@@ -25,7 +25,9 @@
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11
namespace
dmlc
{
...
...
@@ -320,7 +322,8 @@ class JSONObjectReadHelper {
};
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
__make_AnyJSONType ## _ ## KeyName ## __
/*!
* \def DMLC_JSON_ENABLE_ANY
...
...
@@ -475,7 +478,7 @@ struct Handler {
}
};
#if DMLC_
USE
_CXX11
#if DMLC_
STRICT
_CXX11
// Manager to store json serialization strategy.
class
AnyJSONManager
{
public
:
...
...
@@ -561,7 +564,7 @@ struct Handler<any> {
CHECK
(
!
reader
->
NextArrayItem
())
<<
"invalid any json format"
;
}
};
#endif // DMLC_
USE
_CXX11
#endif // DMLC_
STRICT
_CXX11
}
// namespace json
...
...
nnvm/include/dmlc/parameter.h
View file @
ac070f83
...
...
@@ -251,7 +251,8 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \
//! \endcond
...
...
nnvm/include/dmlc/registry.h
View file @
ac070f83
...
...
@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static
EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ =
\
static
DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ =
\
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*!
...
...
@@ -272,6 +272,7 @@ class FunctionRegEntryBase {
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
__dmlc_registry_file_tag_ ## UniqueTag ## __();
}
// namespace dmlc
#endif // DMLC_REGISTRY_H_
nnvm/include/nnvm/c_api.h
View file @
ac070f83
...
...
@@ -260,6 +260,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL
int
NNGraphGetSymbol
(
GraphHandle
graph
,
SymbolHandle
*
symbol
);
/*!
* \brief Get Set a attribute in json format.
* This feature allows pass graph attributes back and forth in reasonable speed.
...
...
@@ -273,6 +274,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
NNVM_DLL
int
NNGraphSetJSONAttr
(
GraphHandle
handle
,
const
char
*
key
,
const
char
*
json_value
);
/*!
* \brief Get a serialized attribute from graph.
* This feature allows pass graph attributes back and forth in reasonable speed.
...
...
@@ -289,6 +291,21 @@ NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
const
char
*
key
,
const
char
**
json_out
,
int
*
success
);
/*!
* \brief Set a attribute whose type is std::vector<NodeEntry> in c++
* This feature allows pass List of symbolic variables for gradient request.
*
* \note This is a beta feature only used for test purposes
*
* \param handle The graph handle.
* \param key The key to the attribute.
* \param list The symbol whose outputs represents the list of NodeEntry to be passed.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL
int
NNGraphSetNodeEntryListAttr_
(
GraphHandle
handle
,
const
char
*
key
,
SymbolHandle
list
);
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
...
...
nnvm/include/nnvm/op.h
View file @
ac070f83
...
...
@@ -279,10 +279,8 @@ class OpMap {
};
// internal macros to make
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
#define NNVM_REGISTER_VAR_DEF(OpName) \
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
static
DMLC_ATTRIBUTE_UNUSED
::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
/*!
* \def NNVM_REGISTER_OP
...
...
@@ -300,7 +298,7 @@ class OpMap {
* \endcode
*/
#define NNVM_REGISTER_OP(OpName) \
NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) =
\
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) =
\
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
// implementations of template functions after this.
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
ac070f83
...
...
@@ -11,6 +11,7 @@
#include <utility>
#include <functional>
#include "./base.h"
#include "./node.h"
#include "./tuple.h"
namespace
nnvm
{
...
...
@@ -107,6 +108,19 @@ using TIsBackwardOp = bool;
using
FInplaceOption
=
std
::
function
<
std
::
vector
<
std
::
pair
<
int
,
int
>
>
(
const
NodeAttrs
&
attrs
)
>
;
/*!
* \brief Get the gradient node of the op node
* This function generates the backward graph of the node
* \param nodeptr The node to take gradient
* \param out_grads Gradient of current node's outputs
* \return gradients of the inputs
*
* \note Register under "FGradient"
*/
using
FGradient
=
std
::
function
<
std
::
vector
<
NodeEntry
>
(
const
NodePtr
&
nodeptr
,
const
std
::
vector
<
NodeEntry
>&
out_grads
)
>
;
}
// namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
nnvm/include/nnvm/pass_functions.h
View file @
ac070f83
...
...
@@ -109,6 +109,37 @@ inline Graph PlaceDevice(Graph graph,
return
ApplyPass
(
std
::
move
(
graph
),
{
"PlaceDevice"
});
}
/*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph source graph
* \param ys The entries we want to take gradient from.
* \param xs The input to take gradient with respect to.
* \param ys_out_grad The symbol for additional gradient to be propagate back to y.
* \param aggregate_fun aggregation function applied to aggregate the inputs
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \return A new graph, whose outputs corresponds to inputs of xs.
*/
inline
Graph
Gradient
(
Graph
graph
,
std
::
vector
<
NodeEntry
>
ys
,
std
::
vector
<
NodeEntry
>
xs
,
std
::
vector
<
NodeEntry
>
ys_out_grad
,
std
::
function
<
NodeEntry
(
std
::
vector
<
NodeEntry
>&&
inputs
)
>
aggregate_fun
=
nullptr
,
std
::
function
<
int
(
const
Node
&
node
)
>
mirror_fun
=
nullptr
)
{
graph
.
attrs
[
"grad_ys"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
ys
));
graph
.
attrs
[
"grad_xs"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
xs
));
graph
.
attrs
[
"grad_ys_out_grad"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
ys_out_grad
));
if
(
aggregate_fun
!=
nullptr
)
{
graph
.
attrs
[
"grad_aggregate_fun"
]
=
std
::
make_shared
<
any
>
(
aggregate_fun
);
}
if
(
mirror_fun
!=
nullptr
)
{
graph
.
attrs
[
"grad_mirror_fun"
]
=
std
::
make_shared
<
any
>
(
mirror_fun
);
}
return
ApplyPass
(
std
::
move
(
graph
),
{
"Gradient"
});
}
}
// namespace pass
}
// namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
nnvm/python/nnvm/graph.py
View file @
ac070f83
...
...
@@ -10,7 +10,7 @@ from ._base import _LIB
from
._base
import
c_array
,
c_str
,
nn_uint
,
py_str
,
string_types
from
._base
import
GraphHandle
,
SymbolHandle
from
._base
import
check_call
from
.symbol
import
Symbol
from
.symbol
import
Symbol
,
Group
as
_Group
class
Graph
(
object
):
...
...
@@ -56,8 +56,27 @@ class Graph(object):
else
:
return
None
def
_set_symbol_list_attr
(
self
,
key
,
value
):
"""Set the attribute of the graph.
Parameters
----------
key : string
The key of the attribute
value : value
The any type that can be dumped to json
type_name : string
The typename registered on c++ side.
"""
if
isinstance
(
value
,
list
):
value
=
_Group
(
value
)
if
not
isinstance
(
value
,
Symbol
):
raise
ValueError
(
"value need to be grouped symbol"
)
check_call
(
_LIB
.
NNGraphSetNodeEntryListAttr_
(
self
.
handle
,
c_str
(
key
),
value
.
handle
))
def
_set_json_attr
(
self
,
key
,
value
,
type_name
=
None
):
"""Set the attribute of the
symbol
.
"""Set the attribute of the
graph
.
Parameters
----------
...
...
nnvm/src/c_api/c_api_graph.cc
View file @
ac070f83
...
...
@@ -35,6 +35,17 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
API_END_HANDLE_ERROR
(
delete
s
);
}
int
NNGraphSetNodeEntryListAttr_
(
GraphHandle
handle
,
const
char
*
key
,
SymbolHandle
list
)
{
API_BEGIN
();
Symbol
*
s
=
static_cast
<
Symbol
*>
(
list
);
Graph
*
g
=
static_cast
<
Graph
*>
(
handle
);
g
->
attrs
[
std
::
string
(
key
)]
=
std
::
make_shared
<
any
>
(
s
->
outputs
);
API_END
();
}
int
NNGraphSetJSONAttr
(
GraphHandle
handle
,
const
char
*
key
,
const
char
*
json_value
)
{
...
...
nnvm/src/pass/gradient.cc
0 → 100644
View file @
ac070f83
/*!
 * Copyright (c) 2016 by Contributors
 * \file gradient.cc
 * \brief Passes that take the gradient of the graph
 * This code was modified based on the mxnet codebase by Min Lin
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <algorithm>
#include <functional>
namespace
nnvm
{
namespace
pass
{
namespace
{
// Default gradient-aggregation function.
// Requires the operators __zero__ and __ewise_sum__ to be registered.
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
  // No incoming gradients: emit a node producing zeros.
  if (v.size() == 0) {
    NodePtr node = Node::Create();
    node->op = Op::Get("__zero__");
    return NodeEntry{node, 0, 0};
  }
  // A single gradient needs no aggregation; forward it.
  if (v.size() == 1) {
    return std::move(v[0]);
  }
  // Two or more gradients: sum them elementwise.
  NodePtr node = Node::Create();
  node->op = Op::Get("__ewise_sum__");
  node->inputs = std::move(v);
  return NodeEntry{node, 0, 0};
}
// helper entry
// Per-output gradient bookkeeping used by the Gradient pass:
//  - `grads` collects every gradient entry that flows into one output;
//  - `sum` caches the aggregated result once the aggregation function has
//    run (sum.node == nullptr means "not aggregated yet").
struct GradEntry {
  NodeEntry sum{nullptr, 0, 0};
  std::vector<NodeEntry> grads;
};
// Build the gradient graph of src: outputs are the gradients of the entries
// in attrs["grad_xs"] with respect to the entries in attrs["grad_ys"], seeded
// by attrs["grad_ys_out_grad"]. Optional attrs: "grad_aggregate_fun" (custom
// gradient aggregation) and "grad_mirror_fun" (recompute-to-save-memory
// mirroring predicate).
Graph Gradient(Graph src) {
  using nnvm::FGradient;
  using MirrorFun = std::function<int (const Node& node)>;

  // The three mandatory attributes must be attached by the caller
  // (see pass_functions.h Gradient()).
  CHECK_NE(src.attrs.count("grad_ys"), 0)
      << "Gradient require grad_ys to be presented.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0)
      << "Gradient require grad_ys_out_grad to be presented.";
  CHECK_NE(src.attrs.count("grad_xs"), 0)
      << "Gradient require grad_xs to be presented.";
  const std::vector<NodeEntry>& ys =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys");
  const std::vector<NodeEntry>& ys_out_grad =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
  const std::vector<NodeEntry>& xs =
      src.GetAttr<std::vector<NodeEntry> >("grad_xs");
  using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
  // Fall back to the default (__zero__/__ewise_sum__) aggregator unless
  // the caller supplied one.
  AggFun agg_fun = DefaultAggregateGradient;
  if (src.attrs.count("grad_aggregate_fun") != 0) {
    agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
  }
  MirrorFun mirror_fun = nullptr;
  if (src.attrs.count("grad_mirror_fun") != 0) {
    mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
  }

  // topo sort: DFS from ys yields nodes with inputs before consumers;
  // also allocate one GradEntry slot per output of every reachable node.
  std::vector<NodePtr> topo_order;
  std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
  DFSVisit(ys, [&](const NodePtr& node) {
      if (output_grads.count(node.get()) == 0) {
        output_grads[node.get()].resize(node->num_outputs());
      }
      topo_order.push_back(node);
    });

  // Seed each requested output y with its externally supplied head gradient.
  CHECK_EQ(ys.size(), ys_out_grad.size());
  for (size_t i = 0; i < ys.size(); ++i) {
    output_grads[ys[i].node.get()][ys[i].index].grads = { ys_out_grad[i] };
  }

  // construct mirror (reduce-memory) strategy if needed: for nodes selected
  // by mirror_fun, clone the forward node (rewiring its inputs/control deps
  // to the already-built mirrors) so the backward graph recomputes it
  // instead of keeping the original output alive.
  // NOTE(review): relies on topo_order so that mirror_map.at() of every
  // predecessor already exists when a node is cloned.
  std::unordered_map<Node*, NodePtr> mirror_map;
  if (mirror_fun != nullptr) {
    for (const NodePtr& n : topo_order) {
      if (mirror_fun(*n)) {
        NodePtr new_node = Node::Create();
        *new_node = *n;
        new_node->attrs.name += "_mirror";
        for (auto& e : new_node->inputs) {
          e.node = mirror_map.at(e.node.get());
        }
        for (auto& n : new_node->control_deps) {
          n = mirror_map.at(n.get());
        }
        mirror_map[n.get()] = std::move(new_node);
      } else {
        mirror_map[n.get()] = n;
      }
    }
  }

  // traverse backward: walk topo order in reverse, aggregate each node's
  // output gradients, and scatter the node's FGradient results onto the
  // grad lists of its inputs.
  static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
  std::vector<NodeEntry> out_agg_grads;
  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
    const NodePtr& ptr = *rit;
    if (ptr->is_variable()) continue;  // variables have no FGradient
    out_agg_grads.clear();
    for (GradEntry& e : output_grads.at(ptr.get())) {
      e.sum = agg_fun(std::move(e.grads));
      out_agg_grads.push_back(e.sum);
    }
    // Pass the mirrored node (if mirroring is enabled) so recomputation
    // happens in the backward graph.
    std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op]
        (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()),
         out_agg_grads);
    // FGradient returns one entry per input, in input order.
    auto git = input_grads.begin();
    for (auto it = (*rit)->inputs.begin();
         it != (*rit)->inputs.end(); ++it, ++git) {
      output_grads[it->node.get()][it->index].grads.emplace_back(
          std::move(*git));
    }
  }
  // take out the xs' grads
  Graph ret;
  ret.outputs.reserve(xs.size());
  for (const NodeEntry& e : xs) {
    GradEntry& entry = output_grads[e.node.get()][e.index];
    // aggregate now if this entry was never visited in the backward pass
    // (e.g. an x that is itself a requested leaf).
    if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
    }
    ret.outputs.emplace_back(std::move(entry.sum));
  }
  return ret;
}
// register pass
// Declares the graph attributes the pass reads so the pass system can
// validate preconditions before invoking the body.
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");
}
// namespace
}
// namespace pass
}
// namespace nnvm
nnvm/tests/python/test_gradient.py
0 → 100644
View file @
ac070f83
import
json
import
nnvm.symbol
as
sym
import
nnvm.graph
as
graph
def grad(ys, xs, ys_grads):
    """Build the gradient graph of ``ys`` w.r.t. ``xs``.

    Creates a graph from ``ys``, attaches the three attributes the
    C++ Gradient pass requires, and applies the pass.

    Parameters
    ----------
    ys : Symbol or list of Symbol
        Outputs to differentiate.
    xs : Symbol or list of Symbol
        Inputs to take gradients with respect to.
    ys_grads : Symbol or list of Symbol
        Head gradients propagated into ``ys``.
    """
    g = graph.create(ys)
    for key, value in (('grad_ys', ys),
                       ('grad_xs', xs),
                       ('grad_ys_out_grad', ys_grads)):
        g._set_symbol_list_attr(key, value)
    return g.apply('Gradient')
def test_graph_gradient():
    """Smoke test: build grad(exp(x0 * x1)) w.r.t. x0 and dump both graphs."""
    x0 = sym.Variable('x0')
    x1 = sym.Variable('x1')
    yg = sym.Variable('yg')
    y = sym.exp(sym.mul(x0, x1))
    grad_graph = grad(y, [x0], yg)
    print("Original graph")
    print(y.debug_str())
    print("Gradient graph")
    # Fixed: was a Python-2 `print expr` statement, inconsistent with the
    # print() calls above and a SyntaxError under Python 3.
    print(grad_graph.symbol.debug_str())
# Allow running this file directly as a standalone smoke test.
if __name__ == "__main__":
    test_graph_gradient()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment