Commit c92d63c7 by Tianqi Chen

Enable use json for graph attr exchange (#5)

parent e4a872d1
...@@ -248,26 +248,33 @@ NNVM_DLL int NNGraphFree(GraphHandle handle); ...@@ -248,26 +248,33 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
*/ */
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*! /*!
* \brief Get Set a std::string typed attribute to graph. * \brief Get Set a attribute in json format.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle. * \param handle The graph handle.
* \param key The key to the attribute. * \param key The key to the attribute.
* \param value The value to be exposed. * \param json_value The value need to be in format [type_name, value],
* Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle, NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
const char* key, const char* key,
const char* value); const char* json_value);
/*! /*!
* \brief Get Set a std::string typed attribute from graph attribute. * \brief Get a serialized attrirbute from graph.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle. * \param handle The graph handle.
* \param key The key to the attribute. * \param key The key to the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist. * \param json_out The result attribute, can be NULL if the attribute do not exist.
* The json_out is an array of [type_name, value].
* Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \param success Whether the result is contained in out. * \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle, NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
const char* key, const char* key,
const char** out, const char** json_out,
int *success); int *success);
/*! /*!
* \brief Apply pass on the src graph. * \brief Apply pass on the src graph.
......
...@@ -47,6 +47,7 @@ SymbolCreatorHandle = ctypes.c_void_p ...@@ -47,6 +47,7 @@ SymbolCreatorHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p GraphHandle = ctypes.c_void_p
#---------------------------- #----------------------------
# helper function definition # helper function definition
#---------------------------- #----------------------------
......
...@@ -5,12 +5,14 @@ from __future__ import absolute_import as _abs ...@@ -5,12 +5,14 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
import sys import sys
import json
from .base import _LIB from .base import _LIB
from .base import c_array, c_str, nn_uint, py_str, string_types from .base import c_array, c_str, nn_uint, py_str, string_types
from .base import GraphHandle, SymbolHandle from .base import GraphHandle, SymbolHandle
from .base import check_call from .base import check_call
from .symbol import Symbol from .symbol import Symbol
class Graph(object): class Graph(object):
"""Graph is the graph object that can be used to apply optimization pass. """Graph is the graph object that can be used to apply optimization pass.
It contains additional graphwise attribute besides the internal symbol. It contains additional graphwise attribute besides the internal symbol.
...@@ -31,7 +33,7 @@ class Graph(object): ...@@ -31,7 +33,7 @@ class Graph(object):
def __del__(self): def __del__(self):
check_call(_LIB.NNGraphFree(self.handle)) check_call(_LIB.NNGraphFree(self.handle))
def attr(self, key): def json_attr(self, key):
"""Get attribute string from the graph. """Get attribute string from the graph.
Parameters Parameters
...@@ -46,24 +48,33 @@ class Graph(object): ...@@ -46,24 +48,33 @@ class Graph(object):
""" """
ret = ctypes.c_char_p() ret = ctypes.c_char_p()
success = ctypes.c_int() success = ctypes.c_int()
check_call(_LIB.NNGraphGetStrAttr( check_call(_LIB.NNGraphGetJSONAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success))) self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0: if success.value != 0:
return py_str(ret.value) json_str = py_str(ret.value)
return json.loads(json_str)[1]
else: else:
return None return None
def _set_attr(self, **kwargs): def _set_json_attr(self, key, value, type_name=None):
"""Set the attribute of the symbol. """Set the attribute of the symbol.
Parameters Parameters
---------- ----------
**kwargs key : string
The attributes to set The key of the attribute
value : value
The any type that can be dumped to json
type_name : string
The typename registered on c++ side.
""" """
for k, v in kwargs.items(): if isinstance(value, string_types):
check_call(_LIB.NNGraphSetStrAttr( type_name = 'str'
self.handle, c_str(k), c_str(v))) elif type_name is None:
raise ValueError("Need to specify type_name")
json_value = json.dumps([type_name, value])
check_call(_LIB.NNGraphSetJSONAttr(
self.handle, c_str(key), c_str(json_value)))
@property @property
def symbol(self): def symbol(self):
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <nnvm/symbolic.h> #include <nnvm/symbolic.h>
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <nnvm/pass.h> #include <nnvm/pass.h>
#include <dmlc/json.h>
#include "./c_api_common.h" #include "./c_api_common.h"
using namespace nnvm; using namespace nnvm;
...@@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { ...@@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
API_END_HANDLE_ERROR(delete s); API_END_HANDLE_ERROR(delete s);
} }
int NNGraphSetStrAttr(GraphHandle handle, int NNGraphSetJSONAttr(GraphHandle handle,
const char* key, const char* key,
const char* value) { const char* json_value) {
API_BEGIN(); API_BEGIN();
Graph* g = static_cast<Graph*>(handle); Graph* g = static_cast<Graph*>(handle);
g->attrs[std::string(key)] = std::make_shared<any>(std::string(value)); std::string temp(json_value);
std::istringstream is(temp);
dmlc::JSONReader reader(&is);
nnvm::any value;
reader.Read(&value);
g->attrs[std::string(key)] = std::make_shared<any>(std::move(value));
API_END(); API_END();
} }
int NNGraphGetStrAttr(GraphHandle handle, int NNGraphGetJSONAttr(GraphHandle handle,
const char* key, const char* key,
const char** out, const char** json_out,
int *success) { int *success) {
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
Graph* g = static_cast<Graph*>(handle); Graph* g = static_cast<Graph*>(handle);
std::string skey(key); std::string skey(key);
auto it = g->attrs.find(skey); auto it = g->attrs.find(skey);
if (it != g->attrs.end()) { if (it != g->attrs.end()) {
const std::string& str = nnvm::get<std::string>(*it->second.get()); std::ostringstream os;
*out = str.c_str(); dmlc::JSONWriter writer(&os);
writer.Write(*it->second.get());
ret->ret_str = os.str();
*json_out = (ret->ret_str).c_str();
*success = 1; *success = 1;
} else { } else {
*success = 0; *success = 0;
......
...@@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON) ...@@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON)
.set_change_graph(true) .set_change_graph(true)
.provide_graph_attr("json"); .provide_graph_attr("json");
DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
} // namespace pass } // namespace pass
} // namespace nnvm } // namespace nnvm
...@@ -6,9 +6,18 @@ def test_json_pass(): ...@@ -6,9 +6,18 @@ def test_json_pass():
y = sym.conv2d(data=x, name='conv', stride=(2,2)) y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y) g = graph.create(y)
ret = g.apply('SaveJSON') ret = g.apply('SaveJSON')
ret._set_json_attr('json', ret.json_attr('json'))
g2 = ret.apply('LoadJSON') g2 = ret.apply('LoadJSON')
assert g2.apply('SaveJSON').attr('json') == ret.attr('json') assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
def test_graph_json_attr():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y)
g._set_json_attr('ilist', [1,2,3], 'list_int')
assert g.json_attr('ilist') == [1,2,3]
if __name__ == "__main__": if __name__ == "__main__":
test_graph_json_attr()
test_json_pass() test_json_pass()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment