Commit c1e48e1a by Tianqi Chen

[CYTHON] Make speedup component minimum (#13)

parent 31f9fc0a
"""Module space to register internal functions. Leave empty"""
...@@ -13,11 +13,9 @@ from .._base import check_call, ctypes2docstring ...@@ -13,11 +13,9 @@ from .._base import check_call, ctypes2docstring
from ..name import NameManager from ..name import NameManager
from ..attribute import AttrScope from ..attribute import AttrScope
__all__ = ["Symbol", "Variable"] class SymbolBase(object):
class Symbol(object):
"""Symbol is symbolic graph.""" """Symbol is symbolic graph."""
__slots__ = ["handle"]
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle): def __init__(self, handle):
"""Initialize the function with handle """Initialize the function with handle
...@@ -32,15 +30,6 @@ class Symbol(object): ...@@ -32,15 +30,6 @@ class Symbol(object):
def __del__(self): def __del__(self):
check_call(_LIB.NNSymbolFree(self.handle)) check_call(_LIB.NNSymbolFree(self.handle))
def __copy__(self):
return copy.deepcopy(self)
def __deepcopy__(self, _):
handle = SymbolHandle()
check_call(_LIB.NNSymbolCopy(self.handle,
ctypes.byref(handle)))
return Symbol(handle)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Invoke symbol as function on inputs. """Invoke symbol as function on inputs.
...@@ -85,10 +74,10 @@ class Symbol(object): ...@@ -85,10 +74,10 @@ class Symbol(object):
either as positional or keyword arguments, not both') either as positional or keyword arguments, not both')
for arg in args: for arg in args:
if not isinstance(arg, Symbol): if not isinstance(arg, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments') raise TypeError('Compose expect `Symbol` as arguments')
for val in kwargs.values(): for val in kwargs.values():
if not isinstance(val, Symbol): if not isinstance(val, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments') raise TypeError('Compose expect `Symbol` as arguments')
num_args = len(args) + len(kwargs) num_args = len(args) + len(kwargs)
...@@ -101,65 +90,6 @@ class Symbol(object): ...@@ -101,65 +90,6 @@ class Symbol(object):
check_call(_LIB.NNSymbolCompose( check_call(_LIB.NNSymbolCompose(
self.handle, name, num_args, keys, args)) self.handle, name, num_args, keys, args))
def __getitem__(self, index):
if isinstance(index, string_types):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetOutput(
self.handle, nn_uint(index), ctypes.byref(handle)))
return Symbol(handle=handle)
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNSymbolGetAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size = nn_uint()
pairs = ctypes.POINTER(ctypes.c_char_p)()
option = ctypes.c_int(0) if recursive else ctypes.c_int(1)
check_call(_LIB.NNSymbolListAttrs(
self.handle, option, ctypes.byref(size), ctypes.byref(pairs)))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)}
def _set_attr(self, **kwargs): def _set_attr(self, **kwargs):
"""Set the attribute of the symbol. """Set the attribute of the symbol.
...@@ -168,116 +98,20 @@ class Symbol(object): ...@@ -168,116 +98,20 @@ class Symbol(object):
**kwargs **kwargs
The attributes to set The attributes to set
""" """
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) keys = _base.c_array(_ctypes.c_char_p,
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) [_base.c_str(key) for key in kwargs.keys()])
num_args = nn_uint(len(kwargs)) vals = _base.c_array(_ctypes.c_char_p,
check_call(_LIB.NNSymbolSetAttrs( [_base.c_str(str(val)) for val in kwargs.values()])
num_args = _base.nn_uint(len(kwargs))
_check_call(_LIB.NNSymbolSetAttrs(
self.handle, num_args, keys, vals)) self.handle, num_args, keys, vals))
def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetInternals(
self.handle, ctypes.byref(handle)))
return Symbol(handle=handle)
def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListArguments(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self): _symbol_cls = SymbolBase
"""List all outputs in the symbol.
Returns def _set_symbol_class(cls):
------- global _symbol_cls
returns : list of string _symbol_cls = cls
List of all the outputs.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListOutputs(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]
def debug_str(self):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str = ctypes.c_char_p()
check_call(_LIB.NNSymbolPrint(
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, string_types):
raise TypeError('Expect a string for variable `name`')
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
ret = Symbol(handle)
attr = AttrScope.current.get(kwargs)
if attr:
ret._set_attr(**attr)
return ret
def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateGroup(
nn_uint(len(ihandles)),
c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
return Symbol(handle)
def _make_atomic_symbol_function(handle): def _make_atomic_symbol_function(handle):
...@@ -332,7 +166,7 @@ def _make_atomic_symbol_function(handle): ...@@ -332,7 +166,7 @@ def _make_atomic_symbol_function(handle):
attr = kwargs.pop('attr', None) attr = kwargs.pop('attr', None)
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, Symbol): if isinstance(v, SymbolBase):
symbol_kwargs[k] = v symbol_kwargs[k] = v
else: else:
param_keys.append(c_str(k)) param_keys.append(c_str(k))
...@@ -351,7 +185,7 @@ def _make_atomic_symbol_function(handle): ...@@ -351,7 +185,7 @@ def _make_atomic_symbol_function(handle):
raise TypeError( raise TypeError(
'%s can only accept input' '%s can only accept input'
'Symbols either as positional or keyword arguments, not both' % func_name) 'Symbols either as positional or keyword arguments, not both' % func_name)
s = Symbol(sym_handle) s = _symbol_cls(sym_handle)
attr = AttrScope.current.get(attr) attr = AttrScope.current.get(attr)
if attr: if attr:
s._set_attr(**attr) s._set_attr(**attr)
...@@ -373,11 +207,12 @@ def _init_symbol_module(): ...@@ -373,11 +207,12 @@ def _init_symbol_module():
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size), check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist))) ctypes.byref(plist)))
module_obj = sys.modules["nnvm.symbol"] module_obj = sys.modules["nnvm.symbol"]
module_internal = sys.modules["nnvm._symbol_internal"]
for i in range(size.value): for i in range(size.value):
hdl = SymbolHandle(plist[i]) hdl = SymbolHandle(plist[i])
function = _make_atomic_symbol_function(hdl) function = _make_atomic_symbol_function(hdl)
if function.__name__.startswith('_'): if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function)) setattr(module_internal, function.__name__, function)
else: else:
setattr(module_obj, function.__name__, function) setattr(module_obj, function.__name__, function)
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs ...@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
import sys as _sys import sys as _sys
import ctypes as _ctypes import ctypes as _ctypes
from numbers import Number as _Number
from .._base import NNVMError from .._base import NNVMError
from ..name import NameManager from ..name import NameManager
from ..attribute import AttrScope from ..attribute import AttrScope
...@@ -64,8 +65,7 @@ cdef extern from "nnvm/c_api.h": ...@@ -64,8 +65,7 @@ cdef extern from "nnvm/c_api.h":
const char** keys, const char** keys,
SymbolHandle* args); SymbolHandle* args);
cdef class SymbolBase:
cdef class Symbol:
"""Symbol is symbolic graph.""" """Symbol is symbolic graph."""
# handle for symbolic operator. # handle for symbolic operator.
cdef SymbolHandle handle cdef SymbolHandle handle
...@@ -85,76 +85,6 @@ cdef class Symbol: ...@@ -85,76 +85,6 @@ cdef class Symbol:
def handle(self): def handle(self):
return _ctypes.cast(<unsigned long>self.handle, _ctypes.c_void_p) return _ctypes.cast(<unsigned long>self.handle, _ctypes.c_void_p)
def __copy__(self):
return self.__deepcopy__()
def __deepcopy__(self, _ = None):
cdef SymbolHandle handle
CALL(NNSymbolCopy(self.handle, &handle))
return NewSymbol(handle)
def __getitem__(self, index):
if isinstance(index, str):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
cdef SymbolHandle handle
cdef nn_uint c_index = index
CALL(NNSymbolGetOutput(self.handle, c_index, &handle))
return NewSymbol(handle)
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
cdef const char* ret
cdef int success
key = c_str(key)
CALL(NNSymbolGetAttr(
self.handle, key, &ret, &success))
if success != 0:
return py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
cdef nn_uint size
cdef const char** pairs
cdef int option
option = 0 if recursive else 1
CALL(NNSymbolListAttrs(
self.handle, option, &size, &pairs))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size)}
def _set_attr(self, **kwargs): def _set_attr(self, **kwargs):
"""Set the attribute of the symbol. """Set the attribute of the symbol.
...@@ -165,49 +95,6 @@ cdef class Symbol: ...@@ -165,49 +95,6 @@ cdef class Symbol:
""" """
SymbolSetAttr(self.handle, kwargs) SymbolSetAttr(self.handle, kwargs)
def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
cdef SymbolHandle handle
CALL(NNSymbolGetInternals(self.handle, &handle))
return NewSymbol(handle)
def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
cdef nn_uint size
cdef const char ** sarr
CALL(NNSymbolListArguments(self.handle, &size, &sarr))
return [py_str(sarr[i]) for i in range(size)]
def list_outputs(self):
"""List all outputs in the symbol.
Returns
-------
returns : list of string
List of all the outputs.
"""
cdef nn_uint size
cdef const char ** sarr
CALL(NNSymbolListOutputs(self.handle, &size, &sarr))
return [py_str(sarr[i]) for i in range(size)]
def debug_str(self):
cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str))
return py_str(out_str)
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
cdef vector[string] sparam_keys cdef vector[string] sparam_keys
...@@ -224,34 +111,18 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): ...@@ -224,34 +111,18 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals))) handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
_symbol_cls = SymbolBase
def _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls
cdef NewSymbol(SymbolHandle handle): cdef NewSymbol(SymbolHandle handle):
"""Create a new symbol given handle""" """Create a new symbol given handle"""
sym = Symbol(None) sym = _symbol_cls(None)
sym.handle = handle (<SymbolBase>sym).handle = handle
return sym return sym
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
cdef SymbolHandle handle
name = c_str(name)
CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle)
cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
"""Create an atomic symbol function by handle and funciton name.""" """Create an atomic symbol function by handle and funciton name."""
cdef const char *name cdef const char *name
...@@ -292,9 +163,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -292,9 +163,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
if len(kwargs) != 0: if len(kwargs) != 0:
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, Symbol): if isinstance(v, SymbolBase):
ssymbol_keys.push_back(c_str(k)) ssymbol_keys.push_back(c_str(k))
symbol_args.push_back((<Symbol>v).handle) symbol_args.push_back((<SymbolBase>v).handle)
else: else:
sparam_keys.push_back(c_str(k)) sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v))) sparam_vals.push_back(c_str(str(v)))
...@@ -304,9 +175,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -304,9 +175,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
raise TypeError("compose only accept input Symbols\ raise TypeError("compose only accept input Symbols\
either as positional or keyword arguments, not both") either as positional or keyword arguments, not both")
for v in args: for v in args:
if not isinstance(v, Symbol): if not isinstance(v, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments') raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<Symbol>v).handle) symbol_args.push_back((<SymbolBase>v).handle)
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys) cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals) cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
...@@ -344,46 +215,20 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -344,46 +215,20 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
return creator return creator
def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
cdef vector[SymbolHandle] ihandles
cdef SymbolHandle handle
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError("Expect Symbols in the list input")
ihandles.push_back((<Symbol>sym).handle)
if ihandles.size() == 0:
raise ValueError("expect at least one element in the input")
CALL(NNSymbolCreateGroup(<nn_uint>ihandles.size(),
&ihandles[0], &handle))
return NewSymbol(handle)
def _init_symbol_module(): def _init_symbol_module():
"""List and add all the atomic symbol functions to current module.""" """List and add all the atomic symbol functions to current module."""
cdef AtomicSymbolCreator* plist cdef AtomicSymbolCreator* plist
cdef nn_uint size cdef nn_uint size
CALL(NNSymbolListAtomicSymbolCreators(&size, &plist)) CALL(NNSymbolListAtomicSymbolCreators(&size, &plist))
module_obj = _sys.modules["nnvm.symbol"] module_obj = _sys.modules["nnvm.symbol"]
module_internal = _sys.modules["nnvm._symbol_internal"]
for i in range(size): for i in range(size):
function = _make_atomic_symbol_function(plist[i]) function = _make_atomic_symbol_function(plist[i])
if function.__name__.startswith('_'): if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function)) setattr(module_internal, function.__name__, function)
else: else:
setattr(module_obj, function.__name__, function) setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups # Initialize the atomic symbol in startups
_init_symbol_module() _init_symbol_module()
...@@ -2,13 +2,214 @@ ...@@ -2,13 +2,214 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import sys as _sys import sys as _sys
import os as _os import os as _os
import ctypes as _ctypes
from numbers import Number as _Number
from . import _base
from ._base import _LIB, check_call as _check_call
from . import _symbol_internal as _internal
from .attribute import AttrScope
# Use different verison of SymbolBase
# When possible, use cython to speedup part of computation.
try: try:
if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0: if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0:
from .ctypes.symbol import Symbol, Variable from .ctypes.symbol import SymbolBase, _set_symbol_class
elif _sys.version_info >= (3, 0): elif _sys.version_info >= (3, 0):
from ._cy3.symbol import Symbol, Variable, Group from ._cy3.symbol import SymbolBase, _set_symbol_class
else: else:
from ._cy2.symbol import Symbol, Variable, Group from ._cy2.symbol import SymbolBase, _set_symbol_class
except: except:
from .ctypes.symbol import Symbol, Variable, Group from .ctypes.symbol import SymbolBase, _set_symbol_class
class Symbol(SymbolBase):
"""Symbol is basic operation unit for symbolic graph compostion."""
# disable dictionary storage, also do not have parent type.
__slots__ = []
def __add__(self, other):
if isinstance(other, Symbol):
return _internal.__add__symbol__(self, other)
elif isinstance(other, _Number):
return _internal.__add__scalar__(self, scalar=other)
else:
raise TypeError("type %s not supported" % str(type(other)))
def __copy__(self):
return self.__deepcopy__()
def __deepcopy__(self, _=None):
handle = _base.SymbolHandle()
_base.check_call(_LIB.NNSymbolCopy(self.handle,
_ctypes.byref(handle)))
return Symbol(handle)
def __getitem__(self, index):
if isinstance(index, _base.string_types):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolGetOutput(
self.handle, _base.nn_uint(index), _ctypes.byref(handle)))
return Symbol(handle=handle)
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = _ctypes.c_char_p()
success = _ctypes.c_int()
_check_call(_LIB.NNSymbolGetAttr(
self.handle, c_str(key), _ctypes.byref(ret), _ctypes.byref(success)))
if success.value != 0:
return _base.py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size = _base.nn_uint()
pairs = _ctypes.POINTER(_ctypes.c_char_p)()
option = _ctypes.c_int(0) if recursive else _ctypes.c_int(1)
_check_call(_LIB.NNSymbolListAttrs(
self.handle, option, _ctypes.byref(size), _ctypes.byref(pairs)))
return {_base.py_str(pairs[i*2]): _base.py_str(pairs[i*2+1]) for i in range(size.value)}
def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolGetInternals(
self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle)
def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
_check_call(_LIB.NNSymbolListArguments(
self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self):
"""List all outputs in the symbol.
Returns
-------
returns : list of string
List of all the outputs.
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
_check_call(_LIB.NNSymbolListOutputs(
self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
def debug_str(self):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str = _ctypes.c_char_p()
_check_call(_LIB.NNSymbolPrint(
self.handle, _ctypes.byref(debug_str)))
return _base.py_str(debug_str.value)
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, _base.string_types):
raise TypeError('Expect a string for variable `name`')
handle = _base.SymbolHandle()
_base.check_call(_LIB.NNSymbolCreateVariable(
_base.c_str(name), _ctypes.byref(handle)))
ret = Symbol(handle)
attr = AttrScope.current.get(kwargs)
if attr:
ret._set_attr(**attr)
return ret
def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolCreateGroup(
_base.nn_uint(len(ihandles)),
_base.c_array(_base.SymbolHandle, ihandles),
_ctypes.byref(handle)))
return Symbol(handle)
# Set the real symbol class to Symbol
_set_symbol_class(Symbol)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment