Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
486249e8
Commit
486249e8
authored
Aug 09, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update mutate function (#23)
parent
16a6db3a
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
65 additions
and
58 deletions
+65
-58
nnvm/docs/Doxyfile
+2
-2
nnvm/example/src/operator.cc
+3
-3
nnvm/include/nnvm/graph.h
+4
-4
nnvm/include/nnvm/op_attr_types.h
+3
-4
nnvm/include/nnvm/pass_functions.h
+12
-12
nnvm/src/core/graph.cc
+1
-1
nnvm/src/core/symbolic.cc
+10
-16
nnvm/src/pass/infer_shape_type.cc
+7
-7
nnvm/src/pass/order_mutation.cc
+23
-9
No files found.
nnvm/docs/Doxyfile
View file @
486249e8
...
@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
...
@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
# title of most generated pages and in a few other places.
# title of most generated pages and in a few other places.
# The default value is: My Project.
# The default value is: My Project.
PROJECT_NAME = "
mxnngraph
"
PROJECT_NAME = "
nnvm
"
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# could be handy for archiving the generated documentation or if some version
# could be handy for archiving the generated documentation or if some version
...
@@ -753,7 +753,7 @@ WARN_LOGFILE =
...
@@ -753,7 +753,7 @@ WARN_LOGFILE =
# spaces.
# spaces.
# Note: If this tag is empty the current directory is searched.
# Note: If this tag is empty the current directory is searched.
INPUT = include
INPUT = include
/nnvm
# This tag can be used to specify the character encoding of the source files
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
...
...
nnvm/example/src/operator.cc
View file @
486249e8
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
namespace
myproject
{
namespace
myproject
{
using
nnvm
::
FListInputNames
;
using
nnvm
::
FListInputNames
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
FMutateInput
s
;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferType
;
using
nnvm
::
FInferType
;
using
nnvm
::
FInplaceOption
;
using
nnvm
::
FInplaceOption
;
...
@@ -119,8 +119,8 @@ NNVM_REGISTER_OP(add)
...
@@ -119,8 +119,8 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP
(
assign
)
NNVM_REGISTER_OP
(
assign
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
set_num_outputs
(
1
)
.
attr
<
FMutateInput
>
(
"FMutateInput"
,
[](
const
NodeAttrs
&
attrs
,
uint32_t
index
)
{
.
attr
<
FMutateInput
s
>
(
"FMutateInputs"
,
[](
const
NodeAttrs
&
attrs
)
{
return
index
==
0
;
return
std
::
vector
<
uint32_t
>
{
0
}
;
});
});
}
// namespace myproject
}
// namespace myproject
nnvm/include/nnvm/graph.h
View file @
486249e8
...
@@ -144,8 +144,8 @@ class IndexedGraph {
...
@@ -144,8 +144,8 @@ class IndexedGraph {
return
nodes_
[
node_id
(
node
)];
return
nodes_
[
node_id
(
node
)];
}
}
/*! \return list of argument nodes */
/*! \return list of argument nodes */
inline
const
std
::
vector
<
uint32_t
>&
arg
_nodes
()
const
{
inline
const
std
::
vector
<
uint32_t
>&
input
_nodes
()
const
{
return
arg
_nodes_
;
return
input
_nodes_
;
}
}
/*! \return list of output entries */
/*! \return list of output entries */
inline
const
std
::
vector
<
NodeEntry
>&
outputs
()
const
{
inline
const
std
::
vector
<
NodeEntry
>&
outputs
()
const
{
...
@@ -161,8 +161,8 @@ class IndexedGraph {
...
@@ -161,8 +161,8 @@ class IndexedGraph {
explicit
IndexedGraph
(
const
Graph
&
other
);
explicit
IndexedGraph
(
const
Graph
&
other
);
// node pointers in CSR structure.
// node pointers in CSR structure.
std
::
vector
<
Node
>
nodes_
;
std
::
vector
<
Node
>
nodes_
;
// index to
argumen
t nodes
// index to
inpu
t nodes
std
::
vector
<
uint32_t
>
arg
_nodes_
;
std
::
vector
<
uint32_t
>
input
_nodes_
;
// space to store the outputs entries
// space to store the outputs entries
std
::
vector
<
NodeEntry
>
outputs_
;
std
::
vector
<
NodeEntry
>
outputs_
;
// mapping from node to index.
// mapping from node to index.
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
486249e8
...
@@ -43,13 +43,12 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
...
@@ -43,13 +43,12 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
/*!
/*!
* \brief Check whether operator will mutate k-th input.
* \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node.
* \param attrs The attributes of the node.
* \param index The input index
* \return list of input indices it mutates.
* \return Whether this operator will mutate index-th input.
*
*
* \note Register under "FMutateInput", default return false
* \note Register under "FMutateInput
s
", default return false
* FMutateInputs enables mutation order handling correctly.
* FMutateInputs enables mutation order handling correctly.
*/
*/
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
using
FMutateInput
s
=
std
::
function
<
std
::
vector
<
uint32_t
>
(
const
NodeAttrs
&
attrs
)
>
;
/*!
/*!
* \brief Inference function of certain type.
* \brief Inference function of certain type.
...
...
nnvm/include/nnvm/pass_functions.h
View file @
486249e8
...
@@ -54,16 +54,16 @@ inline Graph OrderMutation(Graph src) {
...
@@ -54,16 +54,16 @@ inline Graph OrderMutation(Graph src) {
/*!
/*!
* \brief Infer shapes in the graph given the information.
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param graph source graph
* \param shape_
arg
s The shapes of aruguments to the graph.
* \param shape_
input
s The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
*/
inline
Graph
InferShape
(
Graph
graph
,
inline
Graph
InferShape
(
Graph
graph
,
ShapeVector
shape_
arg
s
=
{},
ShapeVector
shape_
input
s
=
{},
std
::
string
shape_attr_key
=
""
)
{
std
::
string
shape_attr_key
=
""
)
{
if
(
shape_
arg
s
.
size
()
!=
0
)
{
if
(
shape_
input
s
.
size
()
!=
0
)
{
graph
.
attrs
[
"shape_
args"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_arg
s
));
graph
.
attrs
[
"shape_
inputs"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_input
s
));
}
}
if
(
shape_attr_key
.
length
()
!=
0
)
{
if
(
shape_attr_key
.
length
()
!=
0
)
{
graph
.
attrs
[
"shape_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_attr_key
));
graph
.
attrs
[
"shape_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_attr_key
));
...
@@ -74,19 +74,19 @@ inline Graph InferShape(Graph graph,
...
@@ -74,19 +74,19 @@ inline Graph InferShape(Graph graph,
/*!
/*!
* \brief Infer types in the graph given the information.
* \brief Infer types in the graph given the information.
* \param graph source graph
* \param graph source graph
* \param
shape_args The shapes of arugumen
ts to the graph.
* \param
dtype_inputs The shapes of inpu
ts to the graph.
* \param
sha
pe_attr_key The key to the node attribute that can indicate shape.
* \param
dty
pe_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
*/
inline
Graph
InferType
(
Graph
graph
,
inline
Graph
InferType
(
Graph
graph
,
DTypeVector
type_arg
s
=
{},
DTypeVector
dtype_input
s
=
{},
std
::
string
type_attr_key
=
""
)
{
std
::
string
d
type_attr_key
=
""
)
{
if
(
type_arg
s
.
size
()
!=
0
)
{
if
(
dtype_input
s
.
size
()
!=
0
)
{
graph
.
attrs
[
"dtype_
args"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
type_arg
s
));
graph
.
attrs
[
"dtype_
inputs"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
dtype_input
s
));
}
}
if
(
type_attr_key
.
length
()
!=
0
)
{
if
(
d
type_attr_key
.
length
()
!=
0
)
{
graph
.
attrs
[
"dtype_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
type_attr_key
));
graph
.
attrs
[
"dtype_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
d
type_attr_key
));
}
}
return
ApplyPass
(
std
::
move
(
graph
),
{
"InferType"
});
return
ApplyPass
(
std
::
move
(
graph
),
{
"InferType"
});
}
}
...
...
nnvm/src/core/graph.cc
View file @
486249e8
...
@@ -30,7 +30,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
...
@@ -30,7 +30,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
nodes_
.
emplace_back
(
std
::
move
(
new_node
));
nodes_
.
emplace_back
(
std
::
move
(
new_node
));
// arg_nodes_
// arg_nodes_
if
(
n
->
is_variable
())
{
if
(
n
->
is_variable
())
{
arg
_nodes_
.
push_back
(
nid
);
input
_nodes_
.
push_back
(
nid
);
}
}
// node2index_
// node2index_
node2index_
[
n
.
get
()]
=
nid
;
node2index_
[
n
.
get
()]
=
nid
;
...
...
nnvm/src/core/symbolic.cc
View file @
486249e8
...
@@ -31,22 +31,19 @@ NodePtr CreateVariableNode(const std::string& name) {
...
@@ -31,22 +31,19 @@ NodePtr CreateVariableNode(const std::string& name) {
// The version of that varaible will increase
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
// version is used to implicitly order the mutation sequences
inline
void
UpdateNodeVersion
(
Node
*
n
)
{
inline
void
UpdateNodeVersion
(
Node
*
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput
"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
s
>
(
"FMutateInputs
"
);
for
(
NodeEntry
&
e
:
n
->
inputs
)
{
for
(
NodeEntry
&
e
:
n
->
inputs
)
{
if
(
e
.
node
->
is_variable
())
{
if
(
e
.
node
->
is_variable
())
{
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
FMutateInput
fmutate
=
fmutate_inputs
[
n
->
op
];
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
n
->
attrs
))
{
for
(
uint32_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
if
(
fmutate
(
n
->
attrs
,
i
))
{
CHECK
(
e
.
node
->
is_variable
())
NodeEntry
&
e
=
n
->
inputs
[
i
];
<<
"Mutation target can only be Variable"
;
CHECK
(
e
.
node
->
is_variable
())
// increase the version of the variable.
<<
"Mutation target can only be Variable"
;
e
.
version
=
++
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
// increase the version of the variable.
e
.
version
=
++
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
}
}
}
}
}
...
@@ -192,16 +189,13 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
...
@@ -192,16 +189,13 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
}
else
{
}
else
{
std
::
unordered_set
<
Node
*>
mutable_set
;
std
::
unordered_set
<
Node
*>
mutable_set
;
std
::
vector
<
Node
*>
vlist
;
std
::
vector
<
Node
*>
vlist
;
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput
"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
s
>
(
"FMutateInputs
"
);
DFSVisit
(
this
->
outputs
,
[
&
ret
,
&
mutable_set
,
&
vlist
](
const
NodePtr
&
node
)
{
DFSVisit
(
this
->
outputs
,
[
&
ret
,
&
mutable_set
,
&
vlist
](
const
NodePtr
&
node
)
{
if
(
node
->
is_variable
())
{
if
(
node
->
is_variable
())
{
vlist
.
push_back
(
node
.
get
());
vlist
.
push_back
(
node
.
get
());
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
FMutateInput
fmutate
=
fmutate_inputs
[
node
->
op
];
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
node
->
attrs
)){
for
(
uint32_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
if
(
fmutate
(
node
->
attrs
,
i
))
{
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
}
}
}
}
}
});
});
...
...
nnvm/src/pass/infer_shape_type.cc
View file @
486249e8
...
@@ -15,7 +15,7 @@ template<typename AttrType, typename IsNone>
...
@@ -15,7 +15,7 @@ template<typename AttrType, typename IsNone>
Graph
InferAttr
(
Graph
&&
ret
,
Graph
InferAttr
(
Graph
&&
ret
,
const
AttrType
def_value
,
const
AttrType
def_value
,
const
char
*
infer_name
,
const
char
*
infer_name
,
const
char
*
arg
_name
,
const
char
*
input
_name
,
const
char
*
attr_key_name
,
const
char
*
attr_key_name
,
const
char
*
attr_name
,
const
char
*
attr_name
,
const
char
*
unknown_name
,
const
char
*
unknown_name
,
...
@@ -29,15 +29,15 @@ Graph InferAttr(Graph &&ret,
...
@@ -29,15 +29,15 @@ Graph InferAttr(Graph &&ret,
// reshape shape vector
// reshape shape vector
AttrVector
rshape
(
idx
.
num_node_entries
(),
def_value
);
AttrVector
rshape
(
idx
.
num_node_entries
(),
def_value
);
if
(
ret
.
attrs
.
count
(
arg
_name
)
!=
0
)
{
if
(
ret
.
attrs
.
count
(
input
_name
)
!=
0
)
{
const
AttrVector
&
shape_args
=
ret
.
GetAttr
<
AttrVector
>
(
arg
_name
);
const
AttrVector
&
shape_args
=
ret
.
GetAttr
<
AttrVector
>
(
input
_name
);
CHECK_LE
(
shape_args
.
size
(),
idx
.
arg
_nodes
().
size
())
CHECK_LE
(
shape_args
.
size
(),
idx
.
input
_nodes
().
size
())
<<
"shape args is more than number of arguments"
;
<<
"shape args is more than number of arguments"
;
for
(
size_t
i
=
0
;
i
<
shape_args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
shape_args
.
size
();
++
i
)
{
rshape
[
idx
.
entry_id
(
idx
.
arg
_nodes
()[
i
],
0
)]
=
shape_args
[
i
];
rshape
[
idx
.
entry_id
(
idx
.
input
_nodes
()[
i
],
0
)]
=
shape_args
[
i
];
}
}
// erase the provided arguments
// erase the provided arguments
ret
.
attrs
.
erase
(
arg
_name
);
ret
.
attrs
.
erase
(
input
_name
);
}
}
std
::
string
shape_attr_key
;
std
::
string
shape_attr_key
;
if
(
ret
.
attrs
.
count
(
attr_key_name
)
!=
0
)
{
if
(
ret
.
attrs
.
count
(
attr_key_name
)
!=
0
)
{
...
@@ -113,7 +113,7 @@ NNVM_REGISTER_PASS(InferType)
...
@@ -113,7 +113,7 @@ NNVM_REGISTER_PASS(InferType)
.
set_body
([](
Graph
ret
)
{
.
set_body
([](
Graph
ret
)
{
return
InferAttr
<
int
>
(
return
InferAttr
<
int
>
(
std
::
move
(
ret
),
0
,
std
::
move
(
ret
),
0
,
"FInferType"
,
"dtype_
arg
s"
,
"dtype_attr_key"
,
"FInferType"
,
"dtype_
input
s"
,
"dtype_attr_key"
,
"dtype"
,
"dtype_num_unknown_nodes"
,
"dtype"
,
"dtype_num_unknown_nodes"
,
[](
const
int
t
)
{
return
t
==
-
1
;
});
[](
const
int
t
)
{
return
t
==
-
1
;
});
})
})
...
...
nnvm/src/pass/order_mutation.cc
View file @
486249e8
...
@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
...
@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
return
def
;
return
def
;
}
}
inline
bool
IsMutate
(
const
std
::
vector
<
uint32_t
>&
mutate_inputs
,
uint32_t
i
)
{
if
(
mutate_inputs
.
size
()
==
0
)
return
false
;
auto
it
=
std
::
lower_bound
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
(),
i
);
return
(
it
!=
mutate_inputs
.
end
())
&&
(
*
it
==
i
);
}
Graph
OrderMutation
(
const
Graph
&
src
)
{
Graph
OrderMutation
(
const
Graph
&
src
)
{
std
::
unordered_map
<
Node
*
,
std
::
vector
<
NodeEntry
>
>
version_hist
;
std
::
unordered_map
<
Node
*
,
std
::
vector
<
NodeEntry
>
>
version_hist
;
DFSVisit
(
src
.
outputs
,
[
&
version_hist
](
const
NodePtr
&
n
)
{
DFSVisit
(
src
.
outputs
,
[
&
version_hist
](
const
NodePtr
&
n
)
{
...
@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) {
...
@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) {
// start preparing for remapping the nodes.
// start preparing for remapping the nodes.
std
::
unordered_map
<
Node
*
,
NodePtr
>
old_new
;
std
::
unordered_map
<
Node
*
,
NodePtr
>
old_new
;
auto
prepare
=
[
&
version_hist
,
&
old_new
]
(
const
NodePtr
&
n
)
{
auto
prepare
=
[
&
version_hist
,
&
old_new
]
(
const
NodePtr
&
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
n
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
bool
need_repl
=
false
;
bool
need_repl
=
false
;
for
(
size_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
const
NodeEntry
&
e
=
n
->
inputs
[
i
];
const
NodeEntry
&
e
=
n
->
inputs
[
i
];
...
@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) {
...
@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) {
auto
it
=
version_hist
.
find
(
e
.
node
.
get
());
auto
it
=
version_hist
.
find
(
e
.
node
.
get
());
if
(
it
!=
version_hist
.
end
())
{
if
(
it
!=
version_hist
.
end
())
{
std
::
vector
<
NodeEntry
>&
vec
=
it
->
second
;
std
::
vector
<
NodeEntry
>&
vec
=
it
->
second
;
uint32_t
is_mutate
=
vec
.
emplace_back
(
NodeEntry
{
n
,
IsMutate
(
mutate_inputs
,
i
),
e
.
version
});
fmutate_inputs
.
count
(
n
->
op
)
?
fmutate_inputs
[
n
->
op
](
n
->
attrs
,
i
)
:
0
;
vec
.
emplace_back
(
NodeEntry
{
n
,
is_mutate
,
e
.
version
});
}
}
}
else
{
}
else
{
if
(
old_new
.
count
(
e
.
node
.
get
())
!=
0
)
need_repl
=
true
;
if
(
old_new
.
count
(
e
.
node
.
get
())
!=
0
)
need_repl
=
true
;
...
@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) {
...
@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) {
get_with_default
(
old_new
,
p
.
get
(),
p
));
get_with_default
(
old_new
,
p
.
get
(),
p
));
}
}
// add control deps
// add control deps
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
kv
.
first
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
for
(
size_t
i
=
0
;
i
<
kv
.
first
->
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
kv
.
first
->
inputs
.
size
();
++
i
)
{
const
NodeEntry
&
e
=
kv
.
first
->
inputs
[
i
];
const
NodeEntry
&
e
=
kv
.
first
->
inputs
[
i
];
if
(
e
.
node
->
is_variable
()
&&
version_hist
.
count
(
e
.
node
.
get
())
!=
0
)
{
if
(
e
.
node
->
is_variable
()
&&
version_hist
.
count
(
e
.
node
.
get
())
!=
0
)
{
FMutateInput
fmutate
=
fmutate_inputs
.
get
(
kv
.
first
->
op
,
nullptr
);
uint32_t
is_mutate
=
(
fmutate
==
nullptr
)
?
0
:
fmutate
(
kv
.
first
->
attrs
,
i
);
std
::
vector
<
NodeEntry
>&
vec
=
version_hist
.
at
(
e
.
node
.
get
());
std
::
vector
<
NodeEntry
>&
vec
=
version_hist
.
at
(
e
.
node
.
get
());
auto
it
=
std
::
lower_bound
(
vec
.
begin
(),
vec
.
end
(),
auto
it
=
std
::
lower_bound
(
vec
.
begin
(),
vec
.
end
(),
NodeEntry
{
nullptr
,
1
,
e
.
version
},
NodeEntry
{
nullptr
,
1
,
e
.
version
},
comparator
);
comparator
);
if
(
is_mutate
!=
0
)
{
if
(
IsMutate
(
mutate_inputs
,
i
)
)
{
int
read_dep
=
0
;
int
read_dep
=
0
;
while
(
it
!=
vec
.
begin
())
{
while
(
it
!=
vec
.
begin
())
{
--
it
;
--
it
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment