Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e4820d34
Commit
e4820d34
authored
Nov 21, 2016
by
Eric Junyuan Xie
Committed by
Tianqi Chen
May 29, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix json parsing behavior (#83)
parent
98fd6bd0
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
12 deletions
+25
-12
nnvm/include/dmlc/parameter.h
+14
-6
nnvm/src/core/symbolic.cc
+1
-1
nnvm/src/pass/saveload_json.cc
+9
-4
nnvm/tests/python/test_symbol.py
+1
-1
No files found.
nnvm/include/dmlc/parameter.h
View file @
e4820d34
...
@@ -63,7 +63,9 @@ enum ParamInitOption {
...
@@ -63,7 +63,9 @@ enum ParamInitOption {
/*! \brief allow unknown parameters */
/*! \brief allow unknown parameters */
kAllowUnknown
,
kAllowUnknown
,
/*! \brief need to match exact parameters */
/*! \brief need to match exact parameters */
kAllMatch
kAllMatch
,
/*! \brief allow unmatched hidden field with format __*__ */
kAllowHidden
};
};
}
// namespace parameter
}
// namespace parameter
/*!
/*!
...
@@ -122,11 +124,11 @@ struct Parameter {
...
@@ -122,11 +124,11 @@ struct Parameter {
*/
*/
template
<
typename
Container
>
template
<
typename
Container
>
inline
void
Init
(
const
Container
&
kwargs
,
inline
void
Init
(
const
Container
&
kwargs
,
parameter
::
ParamInitOption
option
=
parameter
::
kAllow
Unknow
n
)
{
parameter
::
ParamInitOption
option
=
parameter
::
kAllow
Hidde
n
)
{
PType
::
__MANAGER__
()
->
RunInit
(
static_cast
<
PType
*>
(
this
),
PType
::
__MANAGER__
()
->
RunInit
(
static_cast
<
PType
*>
(
this
),
kwargs
.
begin
(),
kwargs
.
end
(),
kwargs
.
begin
(),
kwargs
.
end
(),
NULL
,
NULL
,
option
==
parameter
::
kAllowUnknown
);
option
);
}
}
/*!
/*!
* \brief initialize the parameter by keyword arguments.
* \brief initialize the parameter by keyword arguments.
...
@@ -143,7 +145,7 @@ struct Parameter {
...
@@ -143,7 +145,7 @@ struct Parameter {
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>
unknown
;
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>
unknown
;
PType
::
__MANAGER__
()
->
RunInit
(
static_cast
<
PType
*>
(
this
),
PType
::
__MANAGER__
()
->
RunInit
(
static_cast
<
PType
*>
(
this
),
kwargs
.
begin
(),
kwargs
.
end
(),
kwargs
.
begin
(),
kwargs
.
end
(),
&
unknown
,
true
);
&
unknown
,
parameter
::
kAllowUnknown
);
return
unknown
;
return
unknown
;
}
}
/*!
/*!
...
@@ -369,7 +371,7 @@ class ParamManager {
...
@@ -369,7 +371,7 @@ class ParamManager {
RandomAccessIterator
begin
,
RandomAccessIterator
begin
,
RandomAccessIterator
end
,
RandomAccessIterator
end
,
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>
*
unknown_args
,
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>
*
unknown_args
,
bool
allow_unknow
n
)
const
{
parameter
::
ParamInitOption
optio
n
)
const
{
std
::
set
<
FieldAccessEntry
*>
selected_args
;
std
::
set
<
FieldAccessEntry
*>
selected_args
;
for
(
RandomAccessIterator
it
=
begin
;
it
!=
end
;
++
it
)
{
for
(
RandomAccessIterator
it
=
begin
;
it
!=
end
;
++
it
)
{
FieldAccessEntry
*
e
=
Find
(
it
->
first
);
FieldAccessEntry
*
e
=
Find
(
it
->
first
);
...
@@ -381,7 +383,13 @@ class ParamManager {
...
@@ -381,7 +383,13 @@ class ParamManager {
if
(
unknown_args
!=
NULL
)
{
if
(
unknown_args
!=
NULL
)
{
unknown_args
->
push_back
(
*
it
);
unknown_args
->
push_back
(
*
it
);
}
else
{
}
else
{
if
(
!
allow_unknown
)
{
if
(
option
!=
parameter
::
kAllowUnknown
)
{
if
(
option
==
parameter
::
kAllowHidden
&&
it
->
first
.
length
()
>
4
&&
it
->
first
.
find
(
"__"
)
==
0
&&
it
->
first
.
rfind
(
"__"
)
==
it
->
first
.
length
()
-
2
)
{
continue
;
}
std
::
ostringstream
os
;
std
::
ostringstream
os
;
os
<<
"Cannot find argument
\'
"
<<
it
->
first
<<
"
\'
, Possible Arguments:
\n
"
;
os
<<
"Cannot find argument
\'
"
<<
it
->
first
<<
"
\'
, Possible Arguments:
\n
"
;
os
<<
"----------------
\n
"
;
os
<<
"----------------
\n
"
;
...
...
nnvm/src/core/symbolic.cc
View file @
e4820d34
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
nnvm
{
namespace
nnvm
{
namespace
symbol_constants
{
namespace
symbol_constants
{
const
char
*
kNamespaceSeparator
=
"
_
"
;
const
char
*
kNamespaceSeparator
=
"
$
"
;
}
// namespace symbol_constants
}
// namespace symbol_constants
// auxililary version attribute in variable.
// auxililary version attribute in variable.
...
...
nnvm/src/pass/saveload_json.cc
View file @
e4820d34
...
@@ -109,10 +109,6 @@ struct JSONNode {
...
@@ -109,10 +109,6 @@ struct JSONNode {
if
(
op_type_str
!=
"null"
)
{
if
(
op_type_str
!=
"null"
)
{
try
{
try
{
node
->
attrs
.
op
=
Op
::
Get
(
op_type_str
);
node
->
attrs
.
op
=
Op
::
Get
(
op_type_str
);
// rebuild attribute parser
if
(
node
->
op
()
->
attr_parser
!=
nullptr
)
{
node
->
op
()
->
attr_parser
(
&
(
node
->
attrs
));
}
}
catch
(
const
dmlc
::
Error
&
err
)
{
}
catch
(
const
dmlc
::
Error
&
err
)
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
os
<<
"Failed loading Op "
<<
node
->
attrs
.
name
os
<<
"Failed loading Op "
<<
node
->
attrs
.
name
...
@@ -163,6 +159,10 @@ Graph LoadJSON(Graph src) {
...
@@ -163,6 +159,10 @@ Graph LoadJSON(Graph src) {
<<
"Load JSON require json to be presented."
;
<<
"Load JSON require json to be presented."
;
const
std
::
string
&
json_str
=
const
std
::
string
&
json_str
=
nnvm
::
get
<
std
::
string
>
(
*
src
.
attrs
.
at
(
"json"
));
nnvm
::
get
<
std
::
string
>
(
*
src
.
attrs
.
at
(
"json"
));
bool
no_parse
=
false
;
if
(
src
.
attrs
.
count
(
"load_json_no_parse"
))
{
no_parse
=
nnvm
::
get
<
bool
>
(
*
src
.
attrs
.
at
(
"load_json_no_parse"
));
}
std
::
istringstream
is
(
json_str
);
std
::
istringstream
is
(
json_str
);
dmlc
::
JSONReader
reader
(
&
is
);
dmlc
::
JSONReader
reader
(
&
is
);
JSONGraph
jgraph
;
JSONGraph
jgraph
;
...
@@ -179,6 +179,11 @@ Graph LoadJSON(Graph src) {
...
@@ -179,6 +179,11 @@ Graph LoadJSON(Graph src) {
for
(
uint32_t
nid
:
n
.
control_deps
)
{
for
(
uint32_t
nid
:
n
.
control_deps
)
{
n
.
node
->
control_deps
.
push_back
(
jgraph
.
nodes
[
nid
].
node
);
n
.
node
->
control_deps
.
push_back
(
jgraph
.
nodes
[
nid
].
node
);
}
}
// rebuild attribute parser
if
(
!
no_parse
&&
n
.
node
->
op
()
!=
nullptr
&&
n
.
node
->
op
()
->
attr_parser
!=
nullptr
)
{
n
.
node
->
op
()
->
attr_parser
(
&
(
n
.
node
->
attrs
));
}
}
}
// consistent check
// consistent check
for
(
uint32_t
nid
:
jgraph
.
arg_nodes
)
{
for
(
uint32_t
nid
:
jgraph
.
arg_nodes
)
{
...
...
nnvm/tests/python/test_symbol.py
View file @
e4820d34
...
@@ -12,7 +12,7 @@ def test_compose():
...
@@ -12,7 +12,7 @@ def test_compose():
assert
y
.
list_attr
()[
'gpu'
]
==
'1'
assert
y
.
list_attr
()[
'gpu'
]
==
'1'
z
=
y
.
get_internals
()
z
=
y
.
get_internals
()
assert
z
[
'add_output'
]
.
list_output_names
()
==
[
'add_output'
]
assert
z
[
'add_output'
]
.
list_output_names
()
==
[
'add_output'
]
assert
y
.
list_attr
(
recursive
=
True
)[
'add
_
gpu'
]
==
'2'
assert
y
.
list_attr
(
recursive
=
True
)[
'add
$
gpu'
]
==
'2'
def
test_default_input
():
def
test_default_input
():
x
=
sym
.
Variable
(
'x'
)
x
=
sym
.
Variable
(
'x'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment