Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c92d63c7
Commit
c92d63c7
authored
Jul 11, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable use json for graph attr exchange (#5)
parent
e4a872d1
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
67 additions
and
25 deletions
+67
-25
nnvm/include/nnvm/c_api.h
+15
-8
nnvm/python/nnvm/base.py
+1
-0
nnvm/python/nnvm/graph.py
+20
-9
nnvm/src/c_api/c_api_graph.cc
+17
-7
nnvm/src/pass/saveload_json.cc
+4
-0
nnvm/tests/python/test_graph.py
+10
-1
No files found.
nnvm/include/nnvm/c_api.h
View file @
c92d63c7
...
@@ -248,26 +248,33 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
...
@@ -248,26 +248,33 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
*/
*/
NNVM_DLL
int
NNGraphGetSymbol
(
GraphHandle
graph
,
SymbolHandle
*
symbol
);
NNVM_DLL
int
NNGraphGetSymbol
(
GraphHandle
graph
,
SymbolHandle
*
symbol
);
/*!
/*!
* \brief Get Set a std::string typed attribute to graph.
* \brief Get Set a attribute in json format.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param key The key to the attribute.
* \param value The value to be exposed.
* \param json_value The value need to be in format [type_name, value],
* Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens
*/
*/
NNVM_DLL
int
NNGraphSet
Str
Attr
(
GraphHandle
handle
,
NNVM_DLL
int
NNGraphSet
JSON
Attr
(
GraphHandle
handle
,
const
char
*
key
,
const
char
*
key
,
const
char
*
value
);
const
char
*
json_
value
);
/*!
/*!
* \brief Get Set a std::string typed attribute from graph attribute.
* \brief Get a serialized attrirbute from graph.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param key The key to the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param json_out The result attribute, can be NULL if the attribute do not exist.
* The json_out is an array of [type_name, value].
* Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \param success Whether the result is contained in out.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens
*/
*/
NNVM_DLL
int
NNGraphGet
Str
Attr
(
SymbolHandle
handle
,
NNVM_DLL
int
NNGraphGet
JSON
Attr
(
SymbolHandle
handle
,
const
char
*
key
,
const
char
*
key
,
const
char
**
out
,
const
char
**
json_
out
,
int
*
success
);
int
*
success
);
/*!
/*!
* \brief Apply pass on the src graph.
* \brief Apply pass on the src graph.
...
...
nnvm/python/nnvm/base.py
View file @
c92d63c7
...
@@ -47,6 +47,7 @@ SymbolCreatorHandle = ctypes.c_void_p
...
@@ -47,6 +47,7 @@ SymbolCreatorHandle = ctypes.c_void_p
SymbolHandle
=
ctypes
.
c_void_p
SymbolHandle
=
ctypes
.
c_void_p
GraphHandle
=
ctypes
.
c_void_p
GraphHandle
=
ctypes
.
c_void_p
#----------------------------
#----------------------------
# helper function definition
# helper function definition
#----------------------------
#----------------------------
...
...
nnvm/python/nnvm/graph.py
View file @
c92d63c7
...
@@ -5,12 +5,14 @@ from __future__ import absolute_import as _abs
...
@@ -5,12 +5,14 @@ from __future__ import absolute_import as _abs
import
ctypes
import
ctypes
import
sys
import
sys
import
json
from
.base
import
_LIB
from
.base
import
_LIB
from
.base
import
c_array
,
c_str
,
nn_uint
,
py_str
,
string_types
from
.base
import
c_array
,
c_str
,
nn_uint
,
py_str
,
string_types
from
.base
import
GraphHandle
,
SymbolHandle
from
.base
import
GraphHandle
,
SymbolHandle
from
.base
import
check_call
from
.base
import
check_call
from
.symbol
import
Symbol
from
.symbol
import
Symbol
class
Graph
(
object
):
class
Graph
(
object
):
"""Graph is the graph object that can be used to apply optimization pass.
"""Graph is the graph object that can be used to apply optimization pass.
It contains additional graphwise attribute besides the internal symbol.
It contains additional graphwise attribute besides the internal symbol.
...
@@ -31,7 +33,7 @@ class Graph(object):
...
@@ -31,7 +33,7 @@ class Graph(object):
def
__del__
(
self
):
def
__del__
(
self
):
check_call
(
_LIB
.
NNGraphFree
(
self
.
handle
))
check_call
(
_LIB
.
NNGraphFree
(
self
.
handle
))
def
attr
(
self
,
key
):
def
json_
attr
(
self
,
key
):
"""Get attribute string from the graph.
"""Get attribute string from the graph.
Parameters
Parameters
...
@@ -46,24 +48,33 @@ class Graph(object):
...
@@ -46,24 +48,33 @@ class Graph(object):
"""
"""
ret
=
ctypes
.
c_char_p
()
ret
=
ctypes
.
c_char_p
()
success
=
ctypes
.
c_int
()
success
=
ctypes
.
c_int
()
check_call
(
_LIB
.
NNGraphGet
Str
Attr
(
check_call
(
_LIB
.
NNGraphGet
JSON
Attr
(
self
.
handle
,
c_str
(
key
),
ctypes
.
byref
(
ret
),
ctypes
.
byref
(
success
)))
self
.
handle
,
c_str
(
key
),
ctypes
.
byref
(
ret
),
ctypes
.
byref
(
success
)))
if
success
.
value
!=
0
:
if
success
.
value
!=
0
:
return
py_str
(
ret
.
value
)
json_str
=
py_str
(
ret
.
value
)
return
json
.
loads
(
json_str
)[
1
]
else
:
else
:
return
None
return
None
def
_set_
attr
(
self
,
**
kwargs
):
def
_set_
json_attr
(
self
,
key
,
value
,
type_name
=
None
):
"""Set the attribute of the symbol.
"""Set the attribute of the symbol.
Parameters
Parameters
----------
----------
**kwargs
key : string
The attributes to set
The key of the attribute
value : value
The any type that can be dumped to json
type_name : string
The typename registered on c++ side.
"""
"""
for
k
,
v
in
kwargs
.
items
():
if
isinstance
(
value
,
string_types
):
check_call
(
_LIB
.
NNGraphSetStrAttr
(
type_name
=
'str'
self
.
handle
,
c_str
(
k
),
c_str
(
v
)))
elif
type_name
is
None
:
raise
ValueError
(
"Need to specify type_name"
)
json_value
=
json
.
dumps
([
type_name
,
value
])
check_call
(
_LIB
.
NNGraphSetJSONAttr
(
self
.
handle
,
c_str
(
key
),
c_str
(
json_value
)))
@property
@property
def
symbol
(
self
):
def
symbol
(
self
):
...
...
nnvm/src/c_api/c_api_graph.cc
View file @
c92d63c7
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <nnvm/symbolic.h>
#include <nnvm/symbolic.h>
#include <nnvm/graph.h>
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <nnvm/pass.h>
#include <dmlc/json.h>
#include "./c_api_common.h"
#include "./c_api_common.h"
using
namespace
nnvm
;
using
namespace
nnvm
;
...
@@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
...
@@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
API_END_HANDLE_ERROR
(
delete
s
);
API_END_HANDLE_ERROR
(
delete
s
);
}
}
int
NNGraphSet
Str
Attr
(
GraphHandle
handle
,
int
NNGraphSet
JSON
Attr
(
GraphHandle
handle
,
const
char
*
key
,
const
char
*
key
,
const
char
*
value
)
{
const
char
*
json_
value
)
{
API_BEGIN
();
API_BEGIN
();
Graph
*
g
=
static_cast
<
Graph
*>
(
handle
);
Graph
*
g
=
static_cast
<
Graph
*>
(
handle
);
g
->
attrs
[
std
::
string
(
key
)]
=
std
::
make_shared
<
any
>
(
std
::
string
(
value
));
std
::
string
temp
(
json_value
);
std
::
istringstream
is
(
temp
);
dmlc
::
JSONReader
reader
(
&
is
);
nnvm
::
any
value
;
reader
.
Read
(
&
value
);
g
->
attrs
[
std
::
string
(
key
)]
=
std
::
make_shared
<
any
>
(
std
::
move
(
value
));
API_END
();
API_END
();
}
}
int
NNGraphGet
Str
Attr
(
GraphHandle
handle
,
int
NNGraphGet
JSON
Attr
(
GraphHandle
handle
,
const
char
*
key
,
const
char
*
key
,
const
char
**
out
,
const
char
**
json_
out
,
int
*
success
)
{
int
*
success
)
{
NNAPIThreadLocalEntry
*
ret
=
NNAPIThreadLocalStore
::
Get
();
API_BEGIN
();
API_BEGIN
();
Graph
*
g
=
static_cast
<
Graph
*>
(
handle
);
Graph
*
g
=
static_cast
<
Graph
*>
(
handle
);
std
::
string
skey
(
key
);
std
::
string
skey
(
key
);
auto
it
=
g
->
attrs
.
find
(
skey
);
auto
it
=
g
->
attrs
.
find
(
skey
);
if
(
it
!=
g
->
attrs
.
end
())
{
if
(
it
!=
g
->
attrs
.
end
())
{
const
std
::
string
&
str
=
nnvm
::
get
<
std
::
string
>
(
*
it
->
second
.
get
());
std
::
ostringstream
os
;
*
out
=
str
.
c_str
();
dmlc
::
JSONWriter
writer
(
&
os
);
writer
.
Write
(
*
it
->
second
.
get
());
ret
->
ret_str
=
os
.
str
();
*
json_out
=
(
ret
->
ret_str
).
c_str
();
*
success
=
1
;
*
success
=
1
;
}
else
{
}
else
{
*
success
=
0
;
*
success
=
0
;
...
...
nnvm/src/pass/saveload_json.cc
View file @
c92d63c7
...
@@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON)
...
@@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON)
.
set_change_graph
(
true
)
.
set_change_graph
(
true
)
.
provide_graph_attr
(
"json"
);
.
provide_graph_attr
(
"json"
);
DMLC_JSON_ENABLE_ANY
(
std
::
string
,
str
);
DMLC_JSON_ENABLE_ANY
(
std
::
vector
<
int
>
,
list_int
);
}
// namespace pass
}
// namespace pass
}
// namespace nnvm
}
// namespace nnvm
nnvm/tests/python/test_graph.py
View file @
c92d63c7
...
@@ -6,9 +6,18 @@ def test_json_pass():
...
@@ -6,9 +6,18 @@ def test_json_pass():
y
=
sym
.
conv2d
(
data
=
x
,
name
=
'conv'
,
stride
=
(
2
,
2
))
y
=
sym
.
conv2d
(
data
=
x
,
name
=
'conv'
,
stride
=
(
2
,
2
))
g
=
graph
.
create
(
y
)
g
=
graph
.
create
(
y
)
ret
=
g
.
apply
(
'SaveJSON'
)
ret
=
g
.
apply
(
'SaveJSON'
)
ret
.
_set_json_attr
(
'json'
,
ret
.
json_attr
(
'json'
))
g2
=
ret
.
apply
(
'LoadJSON'
)
g2
=
ret
.
apply
(
'LoadJSON'
)
assert
g2
.
apply
(
'SaveJSON'
)
.
attr
(
'json'
)
==
ret
.
attr
(
'json'
)
assert
g2
.
apply
(
'SaveJSON'
)
.
json_attr
(
'json'
)
==
ret
.
json_attr
(
'json'
)
def
test_graph_json_attr
():
x
=
sym
.
Variable
(
'x'
)
y
=
sym
.
conv2d
(
data
=
x
,
name
=
'conv'
,
stride
=
(
2
,
2
))
g
=
graph
.
create
(
y
)
g
.
_set_json_attr
(
'ilist'
,
[
1
,
2
,
3
],
'list_int'
)
assert
g
.
json_attr
(
'ilist'
)
==
[
1
,
2
,
3
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_graph_json_attr
()
test_json_pass
()
test_json_pass
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment