Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
803db5d1
Commit
803db5d1
authored
Aug 26, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Revert "Change function def to Node ref for more flexiblity" (#29)
parent
98a67d9b
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
54 additions
and
111 deletions
+54
-111
nnvm/example/src/operator.cc
+8
-9
nnvm/include/dmlc/base.h
+0
-13
nnvm/include/dmlc/json.h
+3
-6
nnvm/include/dmlc/parameter.h
+1
-2
nnvm/include/dmlc/registry.h
+2
-3
nnvm/include/nnvm/node.h
+5
-2
nnvm/include/nnvm/op.h
+14
-12
nnvm/include/nnvm/op_attr_types.h
+11
-25
nnvm/include/nnvm/pass.h
+1
-1
nnvm/include/nnvm/pass_functions.h
+0
-29
nnvm/src/core/graph.cc
+1
-1
nnvm/src/core/symbolic.cc
+4
-4
nnvm/src/pass/infer_shape_type.cc
+1
-1
nnvm/src/pass/order_mutation.cc
+2
-2
nnvm/src/pass/plan_memory.cc
+1
-1
No files found.
nnvm/example/src/operator.cc
View file @
803db5d1
...
@@ -15,13 +15,12 @@ using nnvm::FMutateInputs;
...
@@ -15,13 +15,12 @@ using nnvm::FMutateInputs;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferType
;
using
nnvm
::
FInferType
;
using
nnvm
::
FInplaceOption
;
using
nnvm
::
FInplaceOption
;
using
nnvm
::
Node
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
TShape
;
using
nnvm
::
TShape
;
using
nnvm
::
array_view
;
using
nnvm
::
array_view
;
// simply return the shape as same
// simply return the shape as same
inline
bool
SameShape
(
const
Node
&
n
,
inline
bool
SameShape
(
const
Node
Attrs
&
attrs
,
std
::
vector
<
TShape
>
*
ishape
,
std
::
vector
<
TShape
>
*
ishape
,
std
::
vector
<
TShape
>
*
oshape
)
{
std
::
vector
<
TShape
>
*
oshape
)
{
if
(
ishape
->
size
()
==
0
||
(
*
ishape
)[
0
].
ndim
()
==
0
)
return
false
;
if
(
ishape
->
size
()
==
0
||
(
*
ishape
)[
0
].
ndim
()
==
0
)
return
false
;
...
@@ -34,7 +33,7 @@ inline bool SameShape(const Node& n,
...
@@ -34,7 +33,7 @@ inline bool SameShape(const Node& n,
return
true
;
return
true
;
}
}
inline
std
::
vector
<
std
::
pair
<
int
,
int
>
>
InplaceIn0Out0
(
const
Node
&
n
)
{
inline
std
::
vector
<
std
::
pair
<
int
,
int
>
>
InplaceIn0Out0
(
const
Node
Attrs
&
attrs
)
{
return
{{
0
,
0
}};
return
{{
0
,
0
}};
}
}
...
@@ -51,11 +50,11 @@ NNVM_REGISTER_OP(reshape)
...
@@ -51,11 +50,11 @@ NNVM_REGISTER_OP(reshape)
attrs
->
parsed
=
std
::
move
(
target
);
attrs
->
parsed
=
std
::
move
(
target
);
})
})
.
attr
<
FInferShape
>
(
.
attr
<
FInferShape
>
(
"FInferShape"
,
[]
(
const
Node
&
n
,
"FInferShape"
,
[]
(
const
Node
Attrs
&
attrs
,
std
::
vector
<
TShape
>
*
ishape
,
std
::
vector
<
TShape
>
*
ishape
,
std
::
vector
<
TShape
>
*
oshape
)
{
std
::
vector
<
TShape
>
*
oshape
)
{
// get parsed attribute
// get parsed attribute
const
TShape
&
target
=
nnvm
::
get
<
TShape
>
(
n
.
attrs
.
parsed
);
const
TShape
&
target
=
nnvm
::
get
<
TShape
>
(
attrs
.
parsed
);
(
*
oshape
)[
0
]
=
target
;
(
*
oshape
)[
0
]
=
target
;
if
((
*
ishape
)[
0
].
ndim
()
==
0
)
return
false
;
if
((
*
ishape
)[
0
].
ndim
()
==
0
)
return
false
;
CHECK_EQ
((
*
ishape
)[
0
].
Size
(),
target
.
Size
())
CHECK_EQ
((
*
ishape
)[
0
].
Size
(),
target
.
Size
())
...
@@ -78,10 +77,10 @@ NNVM_REGISTER_OP(cast)
...
@@ -78,10 +77,10 @@ NNVM_REGISTER_OP(cast)
})
})
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FInferType
>
(
.
attr
<
FInferType
>
(
"FInferType"
,
[](
const
Node
&
n
,
"FInferType"
,
[](
const
Node
Attrs
&
attrs
,
std
::
vector
<
int
>
*
itype
,
std
::
vector
<
int
>
*
itype
,
std
::
vector
<
int
>
*
otype
)
{
std
::
vector
<
int
>
*
otype
)
{
(
*
otype
)[
0
]
=
nnvm
::
get
<
int
>
(
n
.
attrs
.
parsed
);
(
*
otype
)[
0
]
=
nnvm
::
get
<
int
>
(
attrs
.
parsed
);
return
true
;
return
true
;
});
});
...
@@ -110,7 +109,7 @@ NNVM_REGISTER_OP(cross_device_copy)
...
@@ -110,7 +109,7 @@ NNVM_REGISTER_OP(cross_device_copy)
NNVM_REGISTER_OP
(
conv2d
)
NNVM_REGISTER_OP
(
conv2d
)
.
describe
(
"take conv of input"
)
.
describe
(
"take conv of input"
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
Node
&
n
)
{
.
attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
Node
Attrs
&
attrs
)
{
return
std
::
vector
<
std
::
string
>
{
"data"
,
"weight"
};
return
std
::
vector
<
std
::
string
>
{
"data"
,
"weight"
};
});
});
...
@@ -120,7 +119,7 @@ NNVM_REGISTER_OP(add)
...
@@ -120,7 +119,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP
(
assign
)
NNVM_REGISTER_OP
(
assign
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
set_num_outputs
(
1
)
.
attr
<
FMutateInputs
>
(
"FMutateInputs"
,
[](
const
Node
&
n
)
{
.
attr
<
FMutateInputs
>
(
"FMutateInputs"
,
[](
const
Node
Attrs
&
attrs
)
{
return
std
::
vector
<
uint32_t
>
{
0
};
return
std
::
vector
<
uint32_t
>
{
0
};
});
});
...
...
nnvm/include/dmlc/base.h
View file @
803db5d1
...
@@ -58,11 +58,6 @@
...
@@ -58,11 +58,6 @@
__cplusplus >= 201103L || defined(_MSC_VER))
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
#endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
...
@@ -74,7 +69,6 @@
...
@@ -74,7 +69,6 @@
#endif
#endif
#endif
#endif
/*!
/*!
* \brief Enable std::thread related modules,
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
* Used to disable some module in mingw compile.
...
@@ -88,13 +82,6 @@
...
@@ -88,13 +82,6 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
#endif
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
...
...
nnvm/include/dmlc/json.h
View file @
803db5d1
...
@@ -25,9 +25,7 @@
...
@@ -25,9 +25,7 @@
#include <typeindex>
#include <typeindex>
#include <typeinfo>
#include <typeinfo>
#include <unordered_map>
#include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h"
#include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11
#endif // DMLC_USE_CXX11
namespace
dmlc
{
namespace
dmlc
{
...
@@ -322,8 +320,7 @@ class JSONObjectReadHelper {
...
@@ -322,8 +320,7 @@ class JSONObjectReadHelper {
};
};
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
__make_AnyJSONType ## _ ## KeyName ## __
/*!
/*!
* \def DMLC_JSON_ENABLE_ANY
* \def DMLC_JSON_ENABLE_ANY
...
@@ -478,7 +475,7 @@ struct Handler {
...
@@ -478,7 +475,7 @@ struct Handler {
}
}
};
};
#if DMLC_
STRICT
_CXX11
#if DMLC_
USE
_CXX11
// Manager to store json serialization strategy.
// Manager to store json serialization strategy.
class
AnyJSONManager
{
class
AnyJSONManager
{
public
:
public
:
...
@@ -564,7 +561,7 @@ struct Handler<any> {
...
@@ -564,7 +561,7 @@ struct Handler<any> {
CHECK
(
!
reader
->
NextArrayItem
())
<<
"invalid any json format"
;
CHECK
(
!
reader
->
NextArrayItem
())
<<
"invalid any json format"
;
}
}
};
};
#endif // DMLC_
STRICT
_CXX11
#endif // DMLC_
USE
_CXX11
}
// namespace json
}
// namespace json
...
...
nnvm/include/dmlc/parameter.h
View file @
803db5d1
...
@@ -251,8 +251,7 @@ struct Parameter {
...
@@ -251,8 +251,7 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
return &inst.manager; \
} \
} \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \
(*PType::__MANAGER__()) \
//! \endcond
//! \endcond
...
...
nnvm/include/dmlc/registry.h
View file @
803db5d1
...
@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
...
@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase
* \sa FactoryRegistryEntryBase
*/
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static
DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ =
\
static
EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ =
\
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*!
/*!
...
@@ -272,7 +272,6 @@ class FunctionRegEntryBase {
...
@@ -272,7 +272,6 @@ class FunctionRegEntryBase {
*/
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
__dmlc_registry_file_tag_ ## UniqueTag ## __();
}
// namespace dmlc
}
// namespace dmlc
#endif // DMLC_REGISTRY_H_
#endif // DMLC_REGISTRY_H_
nnvm/include/nnvm/node.h
View file @
803db5d1
...
@@ -17,6 +17,7 @@ namespace nnvm {
...
@@ -17,6 +17,7 @@ namespace nnvm {
// Forward declare node.
// Forward declare node.
class
Node
;
class
Node
;
/*!
/*!
* \brief we always used NodePtr for a reference pointer
* \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case.
* to the node, so this alias can be changed in case.
...
@@ -47,6 +48,8 @@ struct NodeEntry {
...
@@ -47,6 +48,8 @@ struct NodeEntry {
struct
NodeAttrs
{
struct
NodeAttrs
{
/*! \brief name of the node */
/*! \brief name of the node */
std
::
string
name
;
std
::
string
name
;
/*! \brief Vector representation of positional attributes */
std
::
vector
<
double
>
scalars
;
/*! \brief The dictionary representation of attributes */
/*! \brief The dictionary representation of attributes */
std
::
unordered_map
<
std
::
string
,
std
::
string
>
dict
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
dict
;
/*!
/*!
...
@@ -105,7 +108,7 @@ inline uint32_t Node::num_outputs() const {
...
@@ -105,7 +108,7 @@ inline uint32_t Node::num_outputs() const {
if
(
this
->
op
->
get_num_outputs
==
nullptr
)
{
if
(
this
->
op
->
get_num_outputs
==
nullptr
)
{
return
this
->
op
->
num_outputs
;
return
this
->
op
->
num_outputs
;
}
else
{
}
else
{
return
this
->
op
->
get_num_outputs
(
*
thi
s
);
return
this
->
op
->
get_num_outputs
(
this
->
attr
s
);
}
}
}
}
...
@@ -114,7 +117,7 @@ inline uint32_t Node::num_inputs() const {
...
@@ -114,7 +117,7 @@ inline uint32_t Node::num_inputs() const {
if
(
this
->
op
->
get_num_inputs
==
nullptr
)
{
if
(
this
->
op
->
get_num_inputs
==
nullptr
)
{
return
this
->
op
->
num_inputs
;
return
this
->
op
->
num_inputs
;
}
else
{
}
else
{
return
this
->
op
->
get_num_inputs
(
*
thi
s
);
return
this
->
op
->
get_num_inputs
(
this
->
attr
s
);
}
}
}
}
...
...
nnvm/include/nnvm/op.h
View file @
803db5d1
...
@@ -102,16 +102,16 @@ class Op {
...
@@ -102,16 +102,16 @@ class Op {
uint32_t
num_outputs
=
1
;
uint32_t
num_outputs
=
1
;
/*!
/*!
* \brief get number of outputs given information about the node.
* \brief get number of outputs given information about the node.
* \param
n T
he node
* \param
attrs The attribute of t
he node
* \return number of outputs.
* \return number of outputs.
*/
*/
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
get_num_outputs
=
nullptr
;
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attrs
)
>
get_num_outputs
=
nullptr
;
/*!
/*!
* \brief get number of inputs given information about the node.
* \brief get number of inputs given information about the node.
* \param
n T
he node
* \param
attrs The attribute of t
he node
* \return number of inputs
* \return number of inputs
*/
*/
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
get_num_inputs
=
nullptr
;
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attrs
)
>
get_num_inputs
=
nullptr
;
/*!
/*!
* \brief Attribute parser to parse the NodeAttrs information.
* \brief Attribute parser to parse the NodeAttrs information.
*
*
...
@@ -136,11 +136,11 @@ class Op {
...
@@ -136,11 +136,11 @@ class Op {
* attrs->parsed = std::move(param);
* attrs->parsed = std::move(param);
* }
* }
* // The other function that can utilize the parsed result.
* // The other function that can utilize the parsed result.
* TShape SumInferShape(const Node
Ptr& ptr
,
* TShape SumInferShape(const Node
Attrs& attrs
,
* const std::vector<TShape>& ishapes) {
* const std::vector<TShape>& ishapes) {
* // we can use the parsed version of param
* // we can use the parsed version of param
* // without repeatively parsing the parameter
* // without repeatively parsing the parameter
* const SumParam& param = nnvm::get<SumParam>(
ptr->
attrs.parsed);
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed);
* }
* }
* \endcode
* \endcode
*/
*/
...
@@ -180,7 +180,7 @@ class Op {
...
@@ -180,7 +180,7 @@ class Op {
* \param fn The function to be set.
* \param fn The function to be set.
* \return reference to self.
* \return reference to self.
*/
*/
inline
Op
&
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
);
// NOLINT(*)
inline
Op
&
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
/*!
* \brief Set the num_outputs
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \param n The number of outputs to be set.
...
@@ -192,7 +192,7 @@ class Op {
...
@@ -192,7 +192,7 @@ class Op {
* \param fn The function to be set.
* \param fn The function to be set.
* \return reference to self.
* \return reference to self.
*/
*/
inline
Op
&
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
);
// NOLINT(*)
inline
Op
&
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
/*!
* \brief Set the attr_parser function.
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
* \param fn The number of outputs to be set.
...
@@ -279,8 +279,10 @@ class OpMap {
...
@@ -279,8 +279,10 @@ class OpMap {
};
};
// internal macros to make
// internal macros to make
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
#define NNVM_REGISTER_VAR_DEF(OpName) \
#define NNVM_REGISTER_VAR_DEF(OpName) \
static
DMLC_ATTRIBUTE_UNUSED
::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
/*!
/*!
* \def NNVM_REGISTER_OP
* \def NNVM_REGISTER_OP
...
@@ -298,7 +300,7 @@ class OpMap {
...
@@ -298,7 +300,7 @@ class OpMap {
* \endcode
* \endcode
*/
*/
#define NNVM_REGISTER_OP(OpName) \
#define NNVM_REGISTER_OP(OpName) \
DMLC
_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
NNVM
_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
// implementations of template functions after this.
// implementations of template functions after this.
...
@@ -375,7 +377,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
...
@@ -375,7 +377,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return
*
this
;
return
*
this
;
}
}
inline
Op
&
Op
::
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_inputs
=
fn
;
this
->
get_num_inputs
=
fn
;
return
*
this
;
return
*
this
;
}
}
...
@@ -385,7 +387,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
...
@@ -385,7 +387,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return
*
this
;
return
*
this
;
}
}
inline
Op
&
Op
::
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_outputs
=
fn
;
this
->
get_num_outputs
=
fn
;
return
*
this
;
return
*
this
;
}
}
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
803db5d1
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
#include <functional>
#include <functional>
#include "./base.h"
#include "./base.h"
#include "./tuple.h"
#include "./tuple.h"
#include "./node.h"
namespace
nnvm
{
namespace
nnvm
{
...
@@ -22,34 +21,34 @@ namespace nnvm {
...
@@ -22,34 +21,34 @@ namespace nnvm {
/*!
/*!
* \brief Return list of input arguments names of each operator.
* \brief Return list of input arguments names of each operator.
*
*
* \param
n T
he node.
* \param
attrs The attributes of t
he node.
* \return list of inputs
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
* \note Register under "FListInputNames", default return {"data"}.
*
*
* FListInputNames enables automatic variable creation for missing arguments.
* FListInputNames enables automatic variable creation for missing arguments.
*/
*/
using
FListInputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
&
n
)
>
;
using
FListInputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
Attrs
&
attrs
)
>
;
/*!
/*!
* \brief Return list of output arguments names of each operator.
* \brief Return list of output arguments names of each operator.
*
*
* \param
n T
he node.
* \param
attrs The attributes of t
he node.
* \return list of inputs
* \return list of inputs
* \note Register under "FListOutputNames", default return {"outputs"}.
* \note Register under "FListOutputNames", default return {"outputs"}.
*
*
* FListOutputNames customized naming for operator outputs.
* FListOutputNames customized naming for operator outputs.
*/
*/
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
&
n
)
>
;
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
Attrs
&
attrs
)
>
;
/*!
/*!
* \brief Check whether operator will mutate k-th input.
* \brief Check whether operator will mutate k-th input.
* \param
n T
he node.
* \param
attrs The attributes of t
he node.
* \return list of input indices it mutates.
* \return list of input indices it mutates.
*
*
* \note Register under "FMutateInputs", default return false
* \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly.
* FMutateInputs enables mutation order handling correctly.
*/
*/
using
FMutateInputs
=
std
::
function
<
std
::
vector
<
uint32_t
>
(
const
Node
&
n
)
>
;
using
FMutateInputs
=
std
::
function
<
std
::
vector
<
uint32_t
>
(
const
Node
Attrs
&
attrs
)
>
;
/*!
/*!
* \brief Inference function of certain type.
* \brief Inference function of certain type.
...
@@ -57,9 +56,9 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const Node& n)>;
...
@@ -57,9 +56,9 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const Node& n)>;
* \return whether all attributes are inferred.
* \return whether all attributes are inferred.
*/
*/
template
<
typename
AttrType
>
template
<
typename
AttrType
>
using
FInferNodeEntryAttr
=
std
::
function
<
bool
(
const
Node
&
n
,
using
FInferNodeEntryAttr
=
std
::
function
<
bool
(
const
Node
Attrs
&
attrs
,
std
::
vector
<
AttrType
>
*
in_
ptr
,
std
::
vector
<
AttrType
>
*
in_
attrs
,
std
::
vector
<
AttrType
>
*
out_
ptr
)
>
;
std
::
vector
<
AttrType
>
*
out_
attrs
)
>
;
/*!
/*!
* \brief Shape inference function.
* \brief Shape inference function.
* Update the shapes given the input shape information.
* Update the shapes given the input shape information.
...
@@ -97,7 +96,7 @@ using TIsBackwardOp = bool;
...
@@ -97,7 +96,7 @@ using TIsBackwardOp = bool;
/*!
/*!
* \brief Get possible inplace options.
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* This function enables optimization to reuse memory of inputs in output.
* \param
n T
he node
* \param
attrs The attributes of t
he node
* \param in_data The input data.
* \param in_data The input data.
* \param out_data The output data.
* \param out_data The output data.
* \return list of pair of that maps input->output,
* \return list of pair of that maps input->output,
...
@@ -106,20 +105,7 @@ using TIsBackwardOp = bool;
...
@@ -106,20 +105,7 @@ using TIsBackwardOp = bool;
* \note Register under "FInplaceOption", by default no inplace can happen.
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
*/
using
FInplaceOption
=
std
::
function
<
using
FInplaceOption
=
std
::
function
<
std
::
vector
<
std
::
pair
<
int
,
int
>
>
(
const
Node
&
n
)
>
;
std
::
vector
<
std
::
pair
<
int
,
int
>
>
(
const
NodeAttrs
&
attrs
)
>
;
/*!
* \brief Get the gradient node of the op node
* This function generates the backward graph of the node
* \param nodeptr The node to take gradient
* \param out_grads Gradient of current node's outputs
* \return gradients of the inputs
*
* \note Register under "FGradient"
*/
using
FGradient
=
std
::
function
<
std
::
vector
<
NodeEntry
>
(
const
NodePtr
&
nodeptr
,
const
std
::
vector
<
NodeEntry
>&
out_grads
)
>
;
}
// namespace nnvm
}
// namespace nnvm
...
...
nnvm/include/nnvm/pass.h
View file @
803db5d1
...
@@ -23,7 +23,7 @@ namespace nnvm {
...
@@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed.
* \param src The graph to be transformed.
* \return The generated graph.
* \return The generated graph.
*/
*/
using
PassFunction
=
std
::
function
<
Graph
(
Graph
src
)
>
;
typedef
std
::
function
<
Graph
(
Graph
src
)
>
PassFunction
;
/*!
/*!
* \brief Apply a series of pass transformations on g.
* \brief Apply a series of pass transformations on g.
...
...
nnvm/include/nnvm/pass_functions.h
View file @
803db5d1
...
@@ -11,11 +11,9 @@
...
@@ -11,11 +11,9 @@
#define NNVM_PASS_FUNCTIONS_H_
#define NNVM_PASS_FUNCTIONS_H_
#include <string>
#include <string>
#include <vector>
#include <memory>
#include <memory>
#include "./base.h"
#include "./base.h"
#include "./pass.h"
#include "./pass.h"
#include "./node.h"
#include "./graph_attr_types.h"
#include "./graph_attr_types.h"
namespace
nnvm
{
namespace
nnvm
{
...
@@ -111,33 +109,6 @@ inline Graph PlaceDevice(Graph graph,
...
@@ -111,33 +109,6 @@ inline Graph PlaceDevice(Graph graph,
return
ApplyPass
(
std
::
move
(
graph
),
{
"PlaceDevice"
});
return
ApplyPass
(
std
::
move
(
graph
),
{
"PlaceDevice"
});
}
}
/*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph source graph
* \param ys The entries we want to take gradient from.
* \param xs The input we want to
* \param aggregate_fun aggregation function applied to aggregate the inputs
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \return A new graph, whose outputs corresponds to inputs of xs.
*/
inline
Graph
Gradient
(
Graph
graph
,
std
::
vector
<
NodeEntry
>
ys
,
std
::
vector
<
NodeEntry
>
xs
,
std
::
function
<
NodeEntry
(
std
::
vector
<
NodeEntry
>&&
inputs
)
>
aggregate_fun
=
nullptr
,
std
::
function
<
int
(
const
Node
&
node
)
>
mirror_fun
=
nullptr
)
{
graph
.
attrs
[
"grad_ys"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
ys
));
graph
.
attrs
[
"grad_xs"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
xs
));
if
(
aggregate_fun
!=
nullptr
)
{
graph
.
attrs
[
"grad_aggregate_fun"
]
=
std
::
make_shared
<
any
>
(
aggregate_fun
);
}
if
(
mirror_fun
!=
nullptr
)
{
graph
.
attrs
[
"grad_mirror_fun"
]
=
std
::
make_shared
<
any
>
(
mirror_fun
);
}
return
ApplyPass
(
std
::
move
(
graph
),
{
"Gradient"
});
}
}
// namespace pass
}
// namespace pass
}
// namespace nnvm
}
// namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
#endif // NNVM_PASS_FUNCTIONS_H_
nnvm/src/core/graph.cc
View file @
803db5d1
...
@@ -68,7 +68,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
...
@@ -68,7 +68,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
iptr
+
inputs_rptr
[
nid
],
iptr
+
inputs_rptr
[
nid
+
1
]);
iptr
+
inputs_rptr
[
nid
],
iptr
+
inputs_rptr
[
nid
+
1
]);
if
(
nodes_
[
nid
].
source
->
op
!=
nullptr
&&
if
(
nodes_
[
nid
].
source
->
op
!=
nullptr
&&
fmutate_inputs
.
count
(
nodes_
[
nid
].
source
->
op
))
{
fmutate_inputs
.
count
(
nodes_
[
nid
].
source
->
op
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
](
*
(
nodes_
[
nid
].
source
)
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
](
nodes_
[
nid
].
source
->
attrs
))
{
mutable_input_nodes_
.
insert
(
nodes_
[
nid
].
inputs
[
i
].
node_id
);
mutable_input_nodes_
.
insert
(
nodes_
[
nid
].
inputs
[
i
].
node_id
);
}
}
}
}
...
...
nnvm/src/core/symbolic.cc
View file @
803db5d1
...
@@ -38,7 +38,7 @@ inline void UpdateNodeVersion(Node *n) {
...
@@ -38,7 +38,7 @@ inline void UpdateNodeVersion(Node *n) {
}
}
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
*
n
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
n
->
attrs
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
<<
"Mutation target can only be Variable"
;
...
@@ -197,7 +197,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
...
@@ -197,7 +197,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
if
(
node
->
is_variable
())
{
if
(
node
->
is_variable
())
{
vlist
.
push_back
(
node
.
get
());
vlist
.
push_back
(
node
.
get
());
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
*
node
)){
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
node
->
attrs
)){
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
}
}
}
}
...
@@ -223,7 +223,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
...
@@ -223,7 +223,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
std
::
string
rname
;
std
::
string
rname
;
FListOutputNames
fn
=
flist_ouputs
.
get
(
head
.
node
->
op
,
nullptr
);
FListOutputNames
fn
=
flist_ouputs
.
get
(
head
.
node
->
op
,
nullptr
);
if
(
fn
!=
nullptr
)
{
if
(
fn
!=
nullptr
)
{
rname
=
fn
(
*
head
.
node
)[
head
.
index
];
rname
=
fn
(
head
.
node
->
attrs
)[
head
.
index
];
}
else
{
}
else
{
rname
=
"output"
;
rname
=
"output"
;
if
(
head
.
node
->
num_outputs
()
!=
1
)
{
if
(
head
.
node
->
num_outputs
()
!=
1
)
{
...
@@ -279,7 +279,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
...
@@ -279,7 +279,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
// switch to keyword argument matching
// switch to keyword argument matching
if
(
args
.
size
()
!=
n_req
)
{
if
(
args
.
size
()
!=
n_req
)
{
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
*
n
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
if
(
arg_names
.
size
()
!=
n_req
)
{
if
(
arg_names
.
size
()
!=
n_req
)
{
LOG
(
FATAL
)
<<
"Not enough argument to call operator "
<<
outputs
[
0
].
node
->
op
->
name
;
LOG
(
FATAL
)
<<
"Not enough argument to call operator "
<<
outputs
[
0
].
node
->
op
->
name
;
}
}
...
...
nnvm/src/pass/infer_shape_type.cc
View file @
803db5d1
...
@@ -75,7 +75,7 @@ Graph InferAttr(Graph &&ret,
...
@@ -75,7 +75,7 @@ Graph InferAttr(Graph &&ret,
oshape
[
i
]
=
rshape
[
idx
.
entry_id
(
nid
,
i
)];
oshape
[
i
]
=
rshape
[
idx
.
entry_id
(
nid
,
i
)];
}
}
num_unknown
+=
num_unknown
+=
!
(
finfer_shape
[
inode
.
source
->
op
](
*
inode
.
source
,
&
ishape
,
&
oshape
));
!
(
finfer_shape
[
inode
.
source
->
op
](
inode
.
source
->
attrs
,
&
ishape
,
&
oshape
));
for
(
uint32_t
i
=
0
;
i
<
num_inputs
;
++
i
)
{
for
(
uint32_t
i
=
0
;
i
<
num_inputs
;
++
i
)
{
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])]
=
ishape
[
i
];
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])]
=
ishape
[
i
];
}
}
...
...
nnvm/src/pass/order_mutation.cc
View file @
803db5d1
...
@@ -44,7 +44,7 @@ Graph OrderMutation(const Graph& src) {
...
@@ -44,7 +44,7 @@ Graph OrderMutation(const Graph& src) {
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
))
{
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
*
n
);
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
n
->
attrs
);
}
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
@@ -102,7 +102,7 @@ Graph OrderMutation(const Graph& src) {
...
@@ -102,7 +102,7 @@ Graph OrderMutation(const Graph& src) {
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
))
{
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
*
kv
.
first
);
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
kv
.
first
->
attrs
);
}
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
...
nnvm/src/pass/plan_memory.cc
View file @
803db5d1
...
@@ -169,7 +169,7 @@ Graph PlanMemory(Graph ret) {
...
@@ -169,7 +169,7 @@ Graph PlanMemory(Graph ret) {
if
(
inode
.
source
->
is_variable
())
continue
;
if
(
inode
.
source
->
is_variable
())
continue
;
// check inplace option
// check inplace option
if
(
finplace_option
.
count
(
inode
.
source
->
op
)
!=
0
)
{
if
(
finplace_option
.
count
(
inode
.
source
->
op
)
!=
0
)
{
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
](
*
inode
.
source
);
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
](
inode
.
source
->
attrs
);
for
(
auto
&
kv
:
inplace_pairs
)
{
for
(
auto
&
kv
:
inplace_pairs
)
{
uint32_t
eid_out
=
idx
.
entry_id
(
nid
,
kv
.
second
);
uint32_t
eid_out
=
idx
.
entry_id
(
nid
,
kv
.
second
);
uint32_t
eid_in
=
idx
.
entry_id
(
inode
.
inputs
[
kv
.
first
]);
uint32_t
eid_in
=
idx
.
entry_id
(
inode
.
inputs
[
kv
.
first
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment