Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
803db5d1
Commit
803db5d1
authored
Aug 26, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Revert "Change function def to Node ref for more flexiblity" (#29)
parent
98a67d9b
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
54 additions
and
111 deletions
+54
-111
nnvm/example/src/operator.cc
+8
-9
nnvm/include/dmlc/base.h
+0
-13
nnvm/include/dmlc/json.h
+3
-6
nnvm/include/dmlc/parameter.h
+1
-2
nnvm/include/dmlc/registry.h
+2
-3
nnvm/include/nnvm/node.h
+5
-2
nnvm/include/nnvm/op.h
+14
-12
nnvm/include/nnvm/op_attr_types.h
+11
-25
nnvm/include/nnvm/pass.h
+1
-1
nnvm/include/nnvm/pass_functions.h
+0
-29
nnvm/src/core/graph.cc
+1
-1
nnvm/src/core/symbolic.cc
+4
-4
nnvm/src/pass/infer_shape_type.cc
+1
-1
nnvm/src/pass/order_mutation.cc
+2
-2
nnvm/src/pass/plan_memory.cc
+1
-1
No files found.
nnvm/example/src/operator.cc
View file @
803db5d1
...
...
@@ -15,13 +15,12 @@ using nnvm::FMutateInputs;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferType
;
using
nnvm
::
FInplaceOption
;
using
nnvm
::
Node
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
TShape
;
using
nnvm
::
array_view
;
// simply return the shape as same
inline
bool
SameShape
(
const
Node
&
n
,
inline
bool
SameShape
(
const
Node
Attrs
&
attrs
,
std
::
vector
<
TShape
>
*
ishape
,
std
::
vector
<
TShape
>
*
oshape
)
{
if
(
ishape
->
size
()
==
0
||
(
*
ishape
)[
0
].
ndim
()
==
0
)
return
false
;
...
...
@@ -34,7 +33,7 @@ inline bool SameShape(const Node& n,
return
true
;
}
inline
std
::
vector
<
std
::
pair
<
int
,
int
>
>
InplaceIn0Out0
(
const
Node
&
n
)
{
inline
std
::
vector
<
std
::
pair
<
int
,
int
>
>
InplaceIn0Out0
(
const
Node
Attrs
&
attrs
)
{
return
{{
0
,
0
}};
}
...
...
@@ -51,11 +50,11 @@ NNVM_REGISTER_OP(reshape)
attrs
->
parsed
=
std
::
move
(
target
);
})
.
attr
<
FInferShape
>
(
"FInferShape"
,
[]
(
const
Node
&
n
,
"FInferShape"
,
[]
(
const
Node
Attrs
&
attrs
,
std
::
vector
<
TShape
>
*
ishape
,
std
::
vector
<
TShape
>
*
oshape
)
{
// get parsed attribute
const
TShape
&
target
=
nnvm
::
get
<
TShape
>
(
n
.
attrs
.
parsed
);
const
TShape
&
target
=
nnvm
::
get
<
TShape
>
(
attrs
.
parsed
);
(
*
oshape
)[
0
]
=
target
;
if
((
*
ishape
)[
0
].
ndim
()
==
0
)
return
false
;
CHECK_EQ
((
*
ishape
)[
0
].
Size
(),
target
.
Size
())
...
...
@@ -78,10 +77,10 @@ NNVM_REGISTER_OP(cast)
})
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FInferType
>
(
"FInferType"
,
[](
const
Node
&
n
,
"FInferType"
,
[](
const
Node
Attrs
&
attrs
,
std
::
vector
<
int
>
*
itype
,
std
::
vector
<
int
>
*
otype
)
{
(
*
otype
)[
0
]
=
nnvm
::
get
<
int
>
(
n
.
attrs
.
parsed
);
(
*
otype
)[
0
]
=
nnvm
::
get
<
int
>
(
attrs
.
parsed
);
return
true
;
});
...
...
@@ -110,7 +109,7 @@ NNVM_REGISTER_OP(cross_device_copy)
NNVM_REGISTER_OP
(
conv2d
)
.
describe
(
"take conv of input"
)
.
set_num_inputs
(
2
)
.
attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
Node
&
n
)
{
.
attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
Node
Attrs
&
attrs
)
{
return
std
::
vector
<
std
::
string
>
{
"data"
,
"weight"
};
});
...
...
@@ -120,7 +119,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP
(
assign
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
attr
<
FMutateInputs
>
(
"FMutateInputs"
,
[](
const
Node
&
n
)
{
.
attr
<
FMutateInputs
>
(
"FMutateInputs"
,
[](
const
Node
Attrs
&
attrs
)
{
return
std
::
vector
<
uint32_t
>
{
0
};
});
...
...
nnvm/include/dmlc/base.h
View file @
803db5d1
...
...
@@ -58,11 +58,6 @@
__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
...
...
@@ -74,7 +69,6 @@
#endif
#endif
/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
...
...
@@ -88,13 +82,6 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif
/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
...
...
nnvm/include/dmlc/json.h
View file @
803db5d1
...
...
@@ -25,9 +25,7 @@
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11
namespace
dmlc
{
...
...
@@ -322,8 +320,7 @@ class JSONObjectReadHelper {
};
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
__make_AnyJSONType ## _ ## KeyName ## __
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
/*!
* \def DMLC_JSON_ENABLE_ANY
...
...
@@ -478,7 +475,7 @@ struct Handler {
}
};
#if DMLC_
STRICT
_CXX11
#if DMLC_
USE
_CXX11
// Manager to store json serialization strategy.
class
AnyJSONManager
{
public
:
...
...
@@ -564,7 +561,7 @@ struct Handler<any> {
CHECK
(
!
reader
->
NextArrayItem
())
<<
"invalid any json format"
;
}
};
#endif // DMLC_
STRICT
_CXX11
#endif // DMLC_
USE
_CXX11
}
// namespace json
...
...
nnvm/include/dmlc/parameter.h
View file @
803db5d1
...
...
@@ -251,8 +251,7 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \
//! \endcond
...
...
nnvm/include/dmlc/registry.h
View file @
803db5d1
...
...
@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static
DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ =
\
static
EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ =
\
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
/*!
...
...
@@ -272,7 +272,6 @@ class FunctionRegEntryBase {
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
__dmlc_registry_file_tag_ ## UniqueTag ## __();
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
}
// namespace dmlc
#endif // DMLC_REGISTRY_H_
nnvm/include/nnvm/node.h
View file @
803db5d1
...
...
@@ -17,6 +17,7 @@ namespace nnvm {
// Forward declare node.
class
Node
;
/*!
* \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case.
...
...
@@ -47,6 +48,8 @@ struct NodeEntry {
struct
NodeAttrs
{
/*! \brief name of the node */
std
::
string
name
;
/*! \brief Vector representation of positional attributes */
std
::
vector
<
double
>
scalars
;
/*! \brief The dictionary representation of attributes */
std
::
unordered_map
<
std
::
string
,
std
::
string
>
dict
;
/*!
...
...
@@ -105,7 +108,7 @@ inline uint32_t Node::num_outputs() const {
if
(
this
->
op
->
get_num_outputs
==
nullptr
)
{
return
this
->
op
->
num_outputs
;
}
else
{
return
this
->
op
->
get_num_outputs
(
*
thi
s
);
return
this
->
op
->
get_num_outputs
(
this
->
attr
s
);
}
}
...
...
@@ -114,7 +117,7 @@ inline uint32_t Node::num_inputs() const {
if
(
this
->
op
->
get_num_inputs
==
nullptr
)
{
return
this
->
op
->
num_inputs
;
}
else
{
return
this
->
op
->
get_num_inputs
(
*
thi
s
);
return
this
->
op
->
get_num_inputs
(
this
->
attr
s
);
}
}
...
...
nnvm/include/nnvm/op.h
View file @
803db5d1
...
...
@@ -102,16 +102,16 @@ class Op {
uint32_t
num_outputs
=
1
;
/*!
* \brief get number of outputs given information about the node.
* \param
n T
he node
* \param
attrs The attribute of t
he node
* \return number of outputs.
*/
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
get_num_outputs
=
nullptr
;
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attrs
)
>
get_num_outputs
=
nullptr
;
/*!
* \brief get number of inputs given information about the node.
* \param
n T
he node
* \param
attrs The attribute of t
he node
* \return number of inputs
*/
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
get_num_inputs
=
nullptr
;
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attrs
)
>
get_num_inputs
=
nullptr
;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
...
...
@@ -136,11 +136,11 @@ class Op {
* attrs->parsed = std::move(param);
* }
* // The other function that can utilize the parsed result.
* TShape SumInferShape(const Node
Ptr& ptr
,
* TShape SumInferShape(const Node
Attrs& attrs
,
* const std::vector<TShape>& ishapes) {
* // we can use the parsed version of param
* // without repeatively parsing the parameter
* const SumParam& param = nnvm::get<SumParam>(
ptr->
attrs.parsed);
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed);
* }
* \endcode
*/
...
...
@@ -180,7 +180,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline
Op
&
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
);
// NOLINT(*)
inline
Op
&
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
...
...
@@ -192,7 +192,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline
Op
&
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
);
// NOLINT(*)
inline
Op
&
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
...
...
@@ -279,8 +279,10 @@ class OpMap {
};
// internal macros to make
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
#define NNVM_REGISTER_VAR_DEF(OpName) \
static
DMLC_ATTRIBUTE_UNUSED
::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
/*!
* \def NNVM_REGISTER_OP
...
...
@@ -298,7 +300,7 @@ class OpMap {
* \endcode
*/
#define NNVM_REGISTER_OP(OpName) \
DMLC
_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
NNVM
_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
// implementations of template functions after this.
...
...
@@ -375,7 +377,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return
*
this
;
}
inline
Op
&
Op
::
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_inputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_inputs
=
fn
;
return
*
this
;
}
...
...
@@ -385,7 +387,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return
*
this
;
}
inline
Op
&
Op
::
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
&
n
)
>
fn
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_outputs
(
std
::
function
<
uint32_t
(
const
Node
Attrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_outputs
=
fn
;
return
*
this
;
}
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
803db5d1
...
...
@@ -12,7 +12,6 @@
#include <functional>
#include "./base.h"
#include "./tuple.h"
#include "./node.h"
namespace
nnvm
{
...
...
@@ -22,34 +21,34 @@ namespace nnvm {
/*!
* \brief Return list of input arguments names of each operator.
*
* \param
n T
he node.
* \param
attrs The attributes of t
he node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using
FListInputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
&
n
)
>
;
using
FListInputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
Attrs
&
attrs
)
>
;
/*!
* \brief Return list of output arguments names of each operator.
*
* \param
n T
he node.
* \param
attrs The attributes of t
he node.
* \return list of inputs
* \note Register under "FListOutputNames", default return {"outputs"}.
*
* FListOutputNames customized naming for operator outputs.
*/
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
&
n
)
>
;
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
Node
Attrs
&
attrs
)
>
;
/*!
* \brief Check whether operator will mutate k-th input.
* \param
n T
he node.
* \param
attrs The attributes of t
he node.
* \return list of input indices it mutates.
*
* \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using
FMutateInputs
=
std
::
function
<
std
::
vector
<
uint32_t
>
(
const
Node
&
n
)
>
;
using
FMutateInputs
=
std
::
function
<
std
::
vector
<
uint32_t
>
(
const
Node
Attrs
&
attrs
)
>
;
/*!
* \brief Inference function of certain type.
...
...
@@ -57,9 +56,9 @@ using FMutateInputs = std::function<std::vector<uint32_t> (const Node& n)>;
* \return whether all attributes are inferred.
*/
template
<
typename
AttrType
>
using
FInferNodeEntryAttr
=
std
::
function
<
bool
(
const
Node
&
n
,
std
::
vector
<
AttrType
>
*
in_
ptr
,
std
::
vector
<
AttrType
>
*
out_
ptr
)
>
;
using
FInferNodeEntryAttr
=
std
::
function
<
bool
(
const
Node
Attrs
&
attrs
,
std
::
vector
<
AttrType
>
*
in_
attrs
,
std
::
vector
<
AttrType
>
*
out_
attrs
)
>
;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
...
...
@@ -97,7 +96,7 @@ using TIsBackwardOp = bool;
/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param
n T
he node
* \param
attrs The attributes of t
he node
* \param in_data The input data.
* \param out_data The output data.
* \return list of pair of that maps input->output,
...
...
@@ -106,20 +105,7 @@ using TIsBackwardOp = bool;
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
using
FInplaceOption
=
std
::
function
<
std
::
vector
<
std
::
pair
<
int
,
int
>
>
(
const
Node
&
n
)
>
;
/*!
* \brief Get the gradient node of the op node
* This function generates the backward graph of the node
* \param nodeptr The node to take gradient
* \param out_grads Gradient of current node's outputs
* \return gradients of the inputs
*
* \note Register under "FGradient"
*/
using
FGradient
=
std
::
function
<
std
::
vector
<
NodeEntry
>
(
const
NodePtr
&
nodeptr
,
const
std
::
vector
<
NodeEntry
>&
out_grads
)
>
;
std
::
vector
<
std
::
pair
<
int
,
int
>
>
(
const
NodeAttrs
&
attrs
)
>
;
}
// namespace nnvm
...
...
nnvm/include/nnvm/pass.h
View file @
803db5d1
...
...
@@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed.
* \return The generated graph.
*/
using
PassFunction
=
std
::
function
<
Graph
(
Graph
src
)
>
;
typedef
std
::
function
<
Graph
(
Graph
src
)
>
PassFunction
;
/*!
* \brief Apply a series of pass transformations on g.
...
...
nnvm/include/nnvm/pass_functions.h
View file @
803db5d1
...
...
@@ -11,11 +11,9 @@
#define NNVM_PASS_FUNCTIONS_H_
#include <string>
#include <vector>
#include <memory>
#include "./base.h"
#include "./pass.h"
#include "./node.h"
#include "./graph_attr_types.h"
namespace
nnvm
{
...
...
@@ -111,33 +109,6 @@ inline Graph PlaceDevice(Graph graph,
return
ApplyPass
(
std
::
move
(
graph
),
{
"PlaceDevice"
});
}
/*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph source graph
* \param ys The entries we want to take gradient from.
* \param xs The input we want to
* \param aggregate_fun aggregation function applied to aggregate the inputs
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \return A new graph, whose outputs corresponds to inputs of xs.
*/
inline
Graph
Gradient
(
Graph
graph
,
std
::
vector
<
NodeEntry
>
ys
,
std
::
vector
<
NodeEntry
>
xs
,
std
::
function
<
NodeEntry
(
std
::
vector
<
NodeEntry
>&&
inputs
)
>
aggregate_fun
=
nullptr
,
std
::
function
<
int
(
const
Node
&
node
)
>
mirror_fun
=
nullptr
)
{
graph
.
attrs
[
"grad_ys"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
ys
));
graph
.
attrs
[
"grad_xs"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
xs
));
if
(
aggregate_fun
!=
nullptr
)
{
graph
.
attrs
[
"grad_aggregate_fun"
]
=
std
::
make_shared
<
any
>
(
aggregate_fun
);
}
if
(
mirror_fun
!=
nullptr
)
{
graph
.
attrs
[
"grad_mirror_fun"
]
=
std
::
make_shared
<
any
>
(
mirror_fun
);
}
return
ApplyPass
(
std
::
move
(
graph
),
{
"Gradient"
});
}
}
// namespace pass
}
// namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
nnvm/src/core/graph.cc
View file @
803db5d1
...
...
@@ -68,7 +68,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
iptr
+
inputs_rptr
[
nid
],
iptr
+
inputs_rptr
[
nid
+
1
]);
if
(
nodes_
[
nid
].
source
->
op
!=
nullptr
&&
fmutate_inputs
.
count
(
nodes_
[
nid
].
source
->
op
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
](
*
(
nodes_
[
nid
].
source
)
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
](
nodes_
[
nid
].
source
->
attrs
))
{
mutable_input_nodes_
.
insert
(
nodes_
[
nid
].
inputs
[
i
].
node_id
);
}
}
...
...
nnvm/src/core/symbolic.cc
View file @
803db5d1
...
...
@@ -38,7 +38,7 @@ inline void UpdateNodeVersion(Node *n) {
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
*
n
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
n
->
attrs
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
...
...
@@ -197,7 +197,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
if
(
node
->
is_variable
())
{
vlist
.
push_back
(
node
.
get
());
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
*
node
)){
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
node
->
attrs
)){
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
}
}
...
...
@@ -223,7 +223,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
std
::
string
rname
;
FListOutputNames
fn
=
flist_ouputs
.
get
(
head
.
node
->
op
,
nullptr
);
if
(
fn
!=
nullptr
)
{
rname
=
fn
(
*
head
.
node
)[
head
.
index
];
rname
=
fn
(
head
.
node
->
attrs
)[
head
.
index
];
}
else
{
rname
=
"output"
;
if
(
head
.
node
->
num_outputs
()
!=
1
)
{
...
...
@@ -279,7 +279,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
// switch to keyword argument matching
if
(
args
.
size
()
!=
n_req
)
{
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
*
n
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
if
(
arg_names
.
size
()
!=
n_req
)
{
LOG
(
FATAL
)
<<
"Not enough argument to call operator "
<<
outputs
[
0
].
node
->
op
->
name
;
}
...
...
nnvm/src/pass/infer_shape_type.cc
View file @
803db5d1
...
...
@@ -75,7 +75,7 @@ Graph InferAttr(Graph &&ret,
oshape
[
i
]
=
rshape
[
idx
.
entry_id
(
nid
,
i
)];
}
num_unknown
+=
!
(
finfer_shape
[
inode
.
source
->
op
](
*
inode
.
source
,
&
ishape
,
&
oshape
));
!
(
finfer_shape
[
inode
.
source
->
op
](
inode
.
source
->
attrs
,
&
ishape
,
&
oshape
));
for
(
uint32_t
i
=
0
;
i
<
num_inputs
;
++
i
)
{
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])]
=
ishape
[
i
];
}
...
...
nnvm/src/pass/order_mutation.cc
View file @
803db5d1
...
...
@@ -44,7 +44,7 @@ Graph OrderMutation(const Graph& src) {
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
*
n
);
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
n
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
...
@@ -102,7 +102,7 @@ Graph OrderMutation(const Graph& src) {
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
*
kv
.
first
);
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
kv
.
first
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
...
nnvm/src/pass/plan_memory.cc
View file @
803db5d1
...
...
@@ -169,7 +169,7 @@ Graph PlanMemory(Graph ret) {
if
(
inode
.
source
->
is_variable
())
continue
;
// check inplace option
if
(
finplace_option
.
count
(
inode
.
source
->
op
)
!=
0
)
{
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
](
*
inode
.
source
);
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
](
inode
.
source
->
attrs
);
for
(
auto
&
kv
:
inplace_pairs
)
{
uint32_t
eid_out
=
idx
.
entry_id
(
nid
,
kv
.
second
);
uint32_t
eid_in
=
idx
.
entry_id
(
inode
.
inputs
[
kv
.
first
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment