Commit e4a872d1 by Tianqi Chen

[PYTHON] Check in a symbolic construction interface in python, start … (#4)

* [PYTHON] Check in a symbolic construction interface in python, start add graph API

* Graph API
parent 39dfff8a
......@@ -30,8 +30,8 @@ typedef unsigned int nn_uint;
typedef void *AtomicSymbolCreator;
/*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle;
/*! \brief handle to a AtomicSymbol */
typedef void *AtomicSymbolHandle;
/*! \brief handle to Graph */
typedef void *GraphHandle;
/*!
* \brief return str message of the last error
......@@ -71,7 +71,7 @@ NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type = NULL);
const char **return_type);
/*!
* \brief Create an AtomicSymbol functor.
* \param creator the AtomicSymbolCreator
......@@ -123,7 +123,18 @@ NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
/*!
* \brief Get string attribute from symbol
* \param symbol the source symbol
* \param key The key of the symbol.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
/*!
* \brief Set string attribute from symbol.
* NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph.
......@@ -216,4 +227,59 @@ NNVM_DLL int NNSymbolCompose(SymbolHandle sym,
const char** keys,
SymbolHandle* args);
// Graph IR API
/*!
* \brief create a graph handle from symbol
* \param symbol The symbol representing the graph.
* \param graph The graph handle created.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph);
/*!
* \brief free the graph handle
* \param handle The handle to be freed.
*/
NNVM_DLL int NNGraphFree(GraphHandle handle);
/*!
* \brief Get a new symbol from the graph.
* \param graph The graph handle.
* \param symbol The corresponding symbol
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*!
* \brief Get Set a std::string typed attribute to graph.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param value The value to be exposed.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value);
/*!
* \brief Get Set a std::string typed attribute from graph attribute.
* \param handle The graph handle.
* \param key The key to the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle,
const char* key,
const char** out,
int *success);
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
* \param num_pass The number of pass to be applied.
* \param pass_names The names of the pass.
* \param dst The result graph.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);
#endif // NNVM_C_API_H_
......@@ -323,10 +323,10 @@ inline Op& Op::attr( // NOLINT(*)
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0 || p.first == value)
CHECK(p.second == 0)
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is already registered to a different value";
<< " is already registered.";
vec[index_] = std::make_pair(value, 1);
});
return *this;
......
......@@ -112,6 +112,15 @@ class Symbol {
*/
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
/*!
* \brief Get attributes from the symbol.
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
* \param key Key of the attribute. When key == "name", it returns the name attirbute.
* \param out the output value of the attribute.
* \return true if the attribute exists, false if the attribute do not exist.
*/
bool GetAttr(const std::string& key, std::string* out) const;
/*!
* \brief Get attribute dictionary from the symbol.
* For grouped sybmbol, an error will be raised.
* \param option If recursive is set, the attributes of all children are retrieved,
......
#!/usr/bin/env python
# coding: utf-8
"""NNVM python API for ease of use and help new framework establish python API. """
from __future__ import absolute_import
from . import base
from . import symbol as sym
from . import symbol
__version__ = base.__version__
# coding: utf-8
"""Attribute scoping support for symbolic API."""
from __future__ import absolute_import
from .base import string_types
class AttrScope(object):
"""Attribute manager for scoping.
User can also inherit this object to change naming behavior.
Parameters
----------
kwargs
The attributes to set for all symbol creations in the scope.
"""
current = None
def __init__(self, **kwargs):
self._old_scope = None
for value in kwargs.values():
if not isinstance(value, string_types):
raise ValueError("Attributes need to be string")
self._attr = kwargs
def get(self, attr):
"""
Get the attribute dict given the attribute set by the symbol.
Parameters
----------
attr : dict of string to string
The attribute passed in by user during symbol creation.
Returns
-------
attr : dict of string to string
Updated attributes to add other scope related attributes.
"""
if self._attr:
ret = self._attr.copy()
if attr:
ret.update(attr)
return ret
else:
return attr
def __enter__(self):
# pylint: disable=protected-access
self._old_scope = AttrScope.current
attr = AttrScope.current._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope.current = self._old_scope
AttrScope.current = AttrScope()
# coding: utf-8
# pylint: disable=invalid-name
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import
import sys
import ctypes
import numpy as np
from . import libinfo
__all__ = ['NNNetError']
#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
string_types = str,
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = basestring,
numeric_types = (float, int, long, np.float32, np.int32)
py_str = lambda x: x
class NNVMError(Exception):
"""Error that will be throwed by all nnvm functions"""
pass
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.cdll.LoadLibrary(lib_path[0])
# DMatrix functions
lib.NNGetLastError.restype = ctypes.c_char_p
return lib
# version number
__version__ = libinfo.__version__
# library instance of nnvm
_LIB = _load_lib()
# type definitions
nn_uint = ctypes.c_uint
SymbolCreatorHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p
#----------------------------
# helper function definition
#----------------------------
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise NNVMError(py_str(_LIB.NNGetLastError()))
def c_str(string):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Create ctypes array from a python array
Parameters
----------
ctype : ctypes data type
data type of the array we want to convert to
values : tuple or list
data content
Returns
-------
out : ctypes array
Created ctypes array
"""
return (ctype * len(values))(*values)
def ctypes2buffer(cptr, length):
"""Convert ctypes pointer to buffer type.
Parameters
----------
cptr : ctypes.POINTER(ctypes.c_char)
pointer to the raw memory region
length : int
the length of the buffer
Returns
-------
buffer : bytearray
The raw byte memory buffer
"""
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
raise TypeError('expected char pointer')
res = bytearray(length)
rptr = (ctypes.c_char * length).from_buffer(res)
if not ctypes.memmove(rptr, cptr, length):
raise RuntimeError('memmove failed')
return res
def ctypes2numpy_shared(cptr, shape):
"""Convert a ctypes pointer to a numpy array
The result numpy array shares the memory with the pointer
Parameters
----------
cptr : ctypes.POINTER(mx_float)
pointer to the memory region
shape : tuple
shape of target ndarray
Returns
-------
out : numpy_array
A numpy array : numpy array
"""
if not isinstance(cptr, ctypes.POINTER(mx_float)):
raise RuntimeError('expected float pointer')
size = 1
for s in shape:
size *= s
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
import ctypes
import sys
from .base import _LIB
from .base import c_array, c_str, nn_uint, py_str, string_types
from .base import GraphHandle, SymbolHandle
from .base import check_call
from .symbol import Symbol
class Graph(object):
"""Graph is the graph object that can be used to apply optimization pass.
It contains additional graphwise attribute besides the internal symbol.
"""
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : GraphHandle
the handle to the underlying C++ Graph
"""
self.handle = handle
def __del__(self):
check_call(_LIB.NNGraphFree(self.handle))
def attr(self, key):
"""Get attribute string from the graph.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNGraphGetStrAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None
def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
Parameters
----------
**kwargs
The attributes to set
"""
for k, v in kwargs.items():
check_call(_LIB.NNGraphSetStrAttr(
self.handle, c_str(k), c_str(v)))
@property
def symbol(self):
shandle = SymbolHandle()
check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle)))
return Symbol(shandle)
def apply(self, passes):
"""Apply passes to the graph
Parameters
----------
"""
if isinstance(passes, string_types):
passes = [passes]
cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes])
ghandle = GraphHandle()
npass = nn_uint(len(passes))
check_call(_LIB.NNGraphApplyPass(self.handle, npass, cpass, ctypes.byref(ghandle)))
return Graph(ghandle)
def create(symbol):
"""Create a new graph from symbol.
Parameters
----------
symbol : Symbol
The symbolic graph used to create Graph object.
Returns
-------
graph : Graph
A generated new graph object.
"""
ghandle = GraphHandle()
check_call(_LIB.NNGraphCreate(
symbol.handle, ctypes.byref(ghandle)))
return Graph(ghandle)
# coding: utf-8
"""Information about nnvm."""
from __future__ import absolute_import
import os
import platform
def find_lib_path():
"""Find NNNet dynamic library files.
Returns
-------
lib_path : list(string)
List of all found path to the libraries
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/')
dll_path = [curr_path, api_path, cmake_build_path]
if os.name == 'nt':
vs_configuration = 'Release'
if platform.architecture()[0] == '64bit':
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows/x64', vs_configuration))
else:
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows', vs_configuration))
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt':
dll_path = [os.path.join(p, 'libnnvm.dll') for p in dll_path]
else:
dll_path = [os.path.join(p, 'libnnvm.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path)))
return lib_path
# current version
__version__ = "0.7.0"
# coding: utf-8
"""Automatic naming support for symbolic API."""
from __future__ import absolute_import
class NameManager(object):
"""NameManager to do automatic naming.
User can also inherit this object to change naming behavior.
"""
current = None
def __init__(self):
self._counter = {}
self._old_manager = None
def get(self, name, hint):
"""Get the canonical name for a symbol.
This is default implementation.
When user specified a name,
the user specified name will be used.
When user did not, we will automatically generate a
name based on hint string.
Parameters
----------
name : str or None
The name user specified.
hint : str
A hint string, which can be used to generate name.
Returns
-------
full_name : str
A canonical name for the user.
"""
if name:
return name
if hint not in self._counter:
self._counter[hint] = 0
name = '%s%d' % (hint, self._counter[hint])
self._counter[hint] += 1
return name
def __enter__(self):
self._old_manager = NameManager.current
NameManager.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_manager
NameManager.current = self._old_manager
class Prefix(NameManager):
"""A name manager that always attach a prefix to all names.
Examples
--------
>>> import nnvm as nn
>>> data = nn.symbol.Variable('data')
>>> with nn.name.Prefix('mynet_'):
net = nn.symbol.FullyConnected(data, num_hidden=10, name='fc1')
>>> net.list_arguments()
['data', 'mynet_fc1_weight', 'mynet_fc1_bias']
"""
def __init__(self, prefix):
super(Prefix, self).__init__()
self._prefix = prefix
def get(self, name, hint):
name = super(Prefix, self).get(name, hint)
return self._prefix + name
# initialize the default name manager
NameManager.current = NameManager()
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
import copy
import ctypes
import sys
from .base import _LIB
from .base import c_array, c_str, nn_uint, py_str, string_types
from .base import SymbolHandle
from .base import check_call, ctypes2docstring
from .name import NameManager
from .attribute import AttrScope
__all__ = ["Symbol", "Variable"]
class Symbol(object):
"""Symbol is symbolic graph."""
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __del__(self):
check_call(_LIB.NNSymbolFree(self.handle))
def __copy__(self):
return copy.deepcopy(self)
def __deepcopy__(self, _):
handle = SymbolHandle()
check_call(_LIB.NNSymbolCopy(self.handle,
ctypes.byref(handle)))
return Symbol(handle)
def __call__(self, *args, **kwargs):
"""Invoke symbol as function on inputs.
Parameters
----------
args:
provide positional arguments
kwargs:
provide keyword arguments
Returns
-------
the resulting symbol
"""
s = copy.deepcopy(self)
s._compose(*args, **kwargs)
return s
def _compose(self, *args, **kwargs):
"""Compose symbol on inputs.
This call mutates the current symbol.
Parameters
----------
args:
provide positional arguments
kwargs:
provide keyword arguments
Returns
-------
the resulting symbol
"""
name = kwargs.pop('name', None)
if name:
name = c_str(name)
if len(args) != 0 and len(kwargs) != 0:
raise TypeError('compose only accept input Symbols \
either as positional or keyword arguments, not both')
for arg in args:
if not isinstance(arg, Symbol):
raise TypeError('Compose expect `Symbol` as arguments')
for val in kwargs.values():
if not isinstance(val, Symbol):
raise TypeError('Compose expect `Symbol` as arguments')
num_args = len(args) + len(kwargs)
if len(kwargs) != 0:
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
args = c_array(SymbolHandle, [s.handle for s in kwargs.values()])
else:
keys = None
args = c_array(SymbolHandle, [s.handle for s in args])
check_call(_LIB.NNSymbolCompose(
self.handle, name, num_args, keys, args))
def __getitem__(self, index):
if isinstance(index, string_types):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetOutput(
self.handle, nn_uint(index), ctypes.byref(handle)))
return Symbol(handle=handle)
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNSymbolGetAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size = nn_uint()
pairs = ctypes.POINTER(ctypes.c_char_p)()
option = ctypes.c_int(0) if recursive else ctypes.c_int(1)
check_call(_LIB.NNSymbolListAttrs(
self.handle, option, ctypes.byref(size), ctypes.byref(pairs)))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)}
def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
Parameters
----------
**kwargs
The attributes to set
"""
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()])
num_args = nn_uint(len(kwargs))
check_call(_LIB.NNSymbolSetAttrs(
self.handle, num_args, keys, vals))
def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetInternals(
self.handle, ctypes.byref(handle)))
return Symbol(handle=handle)
def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListArguments(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self):
"""List all outputs in the symbol.
Returns
-------
returns : list of string
List of all the outputs.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListOutputs(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]
def debug_str(self):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str = ctypes.c_char_p()
check_call(_LIB.NNSymbolPrint(
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, string_types):
raise TypeError('Expect a string for variable `name`')
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
ret = Symbol(handle)
attr = AttrScope.current.get(kwargs)
if attr:
ret._set_attr(**attr)
return ret
def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateGroup(
nn_uint(len(ihandles)),
c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
return Symbol(handle)
def _make_atomic_symbol_function(handle):
"""Create an atomic symbol function by handle and funciton name."""
name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = nn_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
arg_types = ctypes.POINTER(ctypes.c_char_p)()
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p()
check_call(_LIB.NNSymbolGetAtomicSymbolInfo(
handle, ctypes.byref(name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
ctypes.byref(arg_types),
ctypes.byref(arg_descs),
ctypes.byref(ret_type)))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name.value)
desc = py_str(desc.value)
doc_str = ('%s\n\n' +
'%s\n' +
'name : string, optional.\n' +
' Name of the resulting symbol.\n\n' +
'Returns\n' +
'-------\n' +
'symbol: Symbol\n' +
' The result symbol.')
doc_str = doc_str % (desc, param_str)
def creator(*args, **kwargs):
"""Activation Operator of Neural Net.
The parameters listed below can be passed in as keyword arguments.
Parameters
----------
name : string, required.
Name of the resulting symbol.
Returns
-------
symbol: Symbol
the resulting symbol
"""
param_keys = []
param_vals = []
symbol_kwargs = {}
name = kwargs.pop('name', None)
attr = kwargs.pop('attr', None)
for k, v in kwargs.items():
if isinstance(v, Symbol):
symbol_kwargs[k] = v
else:
param_keys.append(c_str(k))
param_vals.append(c_str(str(v)))
# create atomic symbol
param_keys = c_array(ctypes.c_char_p, param_keys)
param_vals = c_array(ctypes.c_char_p, param_vals)
sym_handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateAtomicSymbol(
handle,
nn_uint(len(param_keys)),
param_keys, param_vals,
ctypes.byref(sym_handle)))
if len(args) != 0 and len(symbol_kwargs) != 0:
raise TypeError(
'%s can only accept input'
'Symbols either as positional or keyword arguments, not both' % func_name)
s = Symbol(sym_handle)
attr = AttrScope.current.get(attr)
if attr:
s._set_attr(**attr)
hint = func_name.lower()
name = NameManager.current.get(name, hint)
s._compose(*args, name=name, **symbol_kwargs)
return s
creator.__name__ = func_name
creator.__doc__ = doc_str
return creator
def _init_symbol_module():
"""List and add all the atomic symbol functions to current module."""
plist = ctypes.POINTER(ctypes.c_void_p)()
size = ctypes.c_uint()
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
module_obj = sys.modules[__name__]
for i in range(size.value):
hdl = SymbolHandle(plist[i])
function = _make_atomic_symbol_function(hdl)
if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function))
else:
setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_graph.cc
* \brief C API related to Graph IR.
*/
#include <nnvm/c_api.h>
#include <nnvm/op.h>
#include <nnvm/symbolic.h>
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include "./c_api_common.h"
using namespace nnvm;
int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) {
Graph* g = new Graph();
API_BEGIN();
g->outputs = static_cast<Symbol*>(symbol)->outputs;
*graph = g;
API_END_HANDLE_ERROR(delete g);
}
int NNGraphFree(GraphHandle handle) {
API_BEGIN();
delete static_cast<Graph*>(handle);
API_END();
}
int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
Symbol* s = new Symbol();
API_BEGIN();
s->outputs = static_cast<Graph*>(graph)->outputs;
*symbol = s;
API_END_HANDLE_ERROR(delete s);
}
int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value) {
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
g->attrs[std::string(key)] = std::make_shared<any>(std::string(value));
API_END();
}
int NNGraphGetStrAttr(GraphHandle handle,
const char* key,
const char** out,
int *success) {
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
std::string skey(key);
auto it = g->attrs.find(skey);
if (it != g->attrs.end()) {
const std::string& str = nnvm::get<std::string>(*it->second.get());
*out = str.c_str();
*success = 1;
} else {
*success = 0;
}
API_END();
}
int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst) {
Graph* g = new Graph();
API_BEGIN();
std::vector<std::string> vpass;
for (nn_uint i = 0; i < num_pass; ++i) {
vpass.emplace_back(std::string(pass_names[i]));
}
*g = ApplyPass(*static_cast<Graph*>(src), vpass);
*dst = g;
API_END_HANDLE_ERROR(delete g);
}
......@@ -3,7 +3,6 @@
* \file c_api_symbolic.cc
* \brief C API related to symbolic graph compsition.
*/
#include <dmlc/logging.h>
#include <nnvm/c_api.h>
#include <nnvm/op.h>
#include <nnvm/symbolic.h>
......@@ -123,7 +122,24 @@ int NNSymbolPrint(SymbolHandle symbol, const char **out_str) {
API_END();
}
int MXSymbolSetAttrs(SymbolHandle symbol,
int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int* success) {
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
if (s->GetAttr(key, &(ret->ret_str))) {
*out = (ret->ret_str).c_str();
*success = 1;
} else {
*out = nullptr;
*success = 0;
}
API_END();
}
int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
const char** keys,
const char** vals) {
......
......@@ -196,8 +196,9 @@ void Symbol::Compose(const std::vector<Symbol>& args,
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
FListInputNames fn = flist_inputs.get(n->op, nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
CHECK_EQ(arg_names.size(), n_req);
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name;
}
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
......@@ -311,15 +312,37 @@ Symbol Symbol::GetInternals() const {
}
void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) {
CHECK_EQ(outputs.size(), 1)
<< "SetAttrs only works for nongrouped symbol";
Node* n = outputs[0].node.get();
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
CHECK(node == e.node.get())
<< "Symbol.SetAttrs only works for non-grouped symbol";
}
for (const auto& kv : attrs) {
n->attrs.dict[kv.first] = kv.second;
if (kv.first == "name") {
node->attrs.name = kv.second;
} else {
node->attrs.dict[kv.first] = kv.second;
}
}
if (n->op->attr_parser != nullptr) {
(*n->op->attr_parser)(&(n->attrs));
if (node->op != nullptr && node->op->attr_parser != nullptr) {
(*node->op->attr_parser)(&(node->attrs));
}
}
bool Symbol::GetAttr(const std::string& key, std::string* out) const {
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
CHECK(node == e.node.get())
<< "Symbol.SetAttrs only works for non-grouped symbol";
}
if (key == "name") {
*out = node->attrs.name;
return true;
}
auto it = node->attrs.dict.find(key);
if (it == node->attrs.dict.end()) return false;
*out = it->second;
return true;
}
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
......
......@@ -2,13 +2,30 @@
// This is an example on how we can register operator information to NNVM
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <utility>
using nnvm::FListInputNames;
using nnvm::NodeAttrs;
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr("inplace_pair", std::make_pair(0, 0));
NNVM_REGISTER_OP(exp)
.describe("take exponmential")
.set_num_inputs(1)
.attr("inplace_pair", std::make_pair(0, 0));
NNVM_REGISTER_OP(conv2d)
.describe("take conv of input")
.set_num_inputs(2)
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
});
NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");
......@@ -143,6 +143,7 @@ Graph LoadJSON(const Graph& src) {
for (uint32_t nid : jgraph.arg_nodes) {
CHECK(jgraph.nodes[nid].node->is_variable());
}
// return the graph
Graph ret;
ret.attrs = std::move(jgraph.attrs);
......@@ -177,6 +178,10 @@ Graph SaveJSON(const Graph& src) {
jgraph.nodes.emplace_back(std::move(jnode));
});
for (const NodeEntry& e : src.outputs) {
jgraph.heads.push_back(std::make_pair(node2index.at(e.node.get()), e.index));
}
std::ostringstream os;
dmlc::JSONWriter writer(&os);
jgraph.Save(&writer);
......
import nnvm.symbol as sym
import nnvm.graph as graph
def test_json_pass():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y)
ret = g.apply('SaveJSON')
g2 = ret.apply('LoadJSON')
assert g2.apply('SaveJSON').attr('json') == ret.attr('json')
if __name__ == "__main__":
test_json_pass()
import nnvm.symbol as sym
from nnvm.base import NNVMError
def test_compose():
x = sym.Variable('x')
z = sym.Variable('z')
y = sym.exp(sym.add(x, x, name='add', gpu=2),
name='exp', gpu=1, attr={"kk": "1"})
assert y.list_arguments() == ['x']
assert y.list_outputs() == ["exp_output"]
assert y.list_attr()['gpu'] == '1'
z = y.get_internals()
assert z['add_output'].list_outputs() == ['add_output']
assert y.list_attr(recursive=True)['add_gpu'] == '2'
def test_default_input():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
assert y.list_arguments() == ['x', 'conv_weight']
try:
z = sym.add(x)
assert False
except NNVMError:
pass
if __name__ == "__main__":
test_default_input()
test_compose()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment