Commit c92d63c7 by Tianqi Chen

Enable use json for graph attr exchange (#5)

parent e4a872d1
......@@ -248,27 +248,34 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*!
* \brief Get Set a std::string typed attribute to graph.
* \brief Get Set a attribute in json format.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle.
* \param key The key to the attribute.
* \param value The value to be exposed.
* \param json_value The value need to be in format [type_name, value],
* Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value);
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value);
/*!
* \brief Get Set a std::string typed attribute from graph attribute.
* \brief Get a serialized attrirbute from graph.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle.
* \param key The key to the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param json_out The result attribute, can be NULL if the attribute do not exist.
* The json_out is an array of [type_name, value].
* Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle,
const char* key,
const char** out,
int *success);
NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
const char* key,
const char** json_out,
int *success);
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
......
......@@ -47,6 +47,7 @@ SymbolCreatorHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p
#----------------------------
# helper function definition
#----------------------------
......
......@@ -5,12 +5,14 @@ from __future__ import absolute_import as _abs
import ctypes
import sys
import json
from .base import _LIB
from .base import c_array, c_str, nn_uint, py_str, string_types
from .base import GraphHandle, SymbolHandle
from .base import check_call
from .symbol import Symbol
class Graph(object):
"""Graph is the graph object that can be used to apply optimization pass.
It contains additional graphwise attribute besides the internal symbol.
......@@ -31,7 +33,7 @@ class Graph(object):
def __del__(self):
check_call(_LIB.NNGraphFree(self.handle))
def attr(self, key):
def json_attr(self, key):
"""Get attribute string from the graph.
Parameters
......@@ -46,24 +48,33 @@ class Graph(object):
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNGraphGetStrAttr(
check_call(_LIB.NNGraphGetJSONAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
json_str = py_str(ret.value)
return json.loads(json_str)[1]
else:
return None
def _set_attr(self, **kwargs):
def _set_json_attr(self, key, value, type_name=None):
"""Set the attribute of the symbol.
Parameters
----------
**kwargs
The attributes to set
key : string
The key of the attribute
value : value
The any type that can be dumped to json
type_name : string
The typename registered on c++ side.
"""
for k, v in kwargs.items():
check_call(_LIB.NNGraphSetStrAttr(
self.handle, c_str(k), c_str(v)))
if isinstance(value, string_types):
type_name = 'str'
elif type_name is None:
raise ValueError("Need to specify type_name")
json_value = json.dumps([type_name, value])
check_call(_LIB.NNGraphSetJSONAttr(
self.handle, c_str(key), c_str(json_value)))
@property
def symbol(self):
......
......@@ -8,6 +8,7 @@
#include <nnvm/symbolic.h>
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <dmlc/json.h>
#include "./c_api_common.h"
using namespace nnvm;
......@@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
API_END_HANDLE_ERROR(delete s);
}
int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value) {
int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value) {
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
g->attrs[std::string(key)] = std::make_shared<any>(std::string(value));
std::string temp(json_value);
std::istringstream is(temp);
dmlc::JSONReader reader(&is);
nnvm::any value;
reader.Read(&value);
g->attrs[std::string(key)] = std::make_shared<any>(std::move(value));
API_END();
}
int NNGraphGetStrAttr(GraphHandle handle,
int NNGraphGetJSONAttr(GraphHandle handle,
const char* key,
const char** out,
const char** json_out,
int *success) {
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
std::string skey(key);
auto it = g->attrs.find(skey);
if (it != g->attrs.end()) {
const std::string& str = nnvm::get<std::string>(*it->second.get());
*out = str.c_str();
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.Write(*it->second.get());
ret->ret_str = os.str();
*json_out = (ret->ret_str).c_str();
*success = 1;
} else {
*success = 0;
......
......@@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON)
.set_change_graph(true)
.provide_graph_attr("json");
DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
} // namespace pass
} // namespace nnvm
......@@ -6,9 +6,18 @@ def test_json_pass():
y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y)
ret = g.apply('SaveJSON')
ret._set_json_attr('json', ret.json_attr('json'))
g2 = ret.apply('LoadJSON')
assert g2.apply('SaveJSON').attr('json') == ret.attr('json')
assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
def test_graph_json_attr():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y)
g._set_json_attr('ilist', [1,2,3], 'list_int')
assert g.json_attr('ilist') == [1,2,3]
if __name__ == "__main__":
test_graph_json_attr()
test_json_pass()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment