Commit c1e48e1a by Tianqi Chen

[CYTHON] Make speedup component minimum (#13)

parent 31f9fc0a
"""Module space to register internal functions. Leave empty"""
......@@ -13,11 +13,9 @@ from .._base import check_call, ctypes2docstring
from ..name import NameManager
from ..attribute import AttrScope
__all__ = ["Symbol", "Variable"]
class Symbol(object):
class SymbolBase(object):
"""Symbol is symbolic graph."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
......@@ -32,15 +30,6 @@ class Symbol(object):
def __del__(self):
check_call(_LIB.NNSymbolFree(self.handle))
def __copy__(self):
return copy.deepcopy(self)
def __deepcopy__(self, _):
handle = SymbolHandle()
check_call(_LIB.NNSymbolCopy(self.handle,
ctypes.byref(handle)))
return Symbol(handle)
def __call__(self, *args, **kwargs):
"""Invoke symbol as function on inputs.
......@@ -85,10 +74,10 @@ class Symbol(object):
either as positional or keyword arguments, not both')
for arg in args:
if not isinstance(arg, Symbol):
if not isinstance(arg, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')
for val in kwargs.values():
if not isinstance(val, Symbol):
if not isinstance(val, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')
num_args = len(args) + len(kwargs)
......@@ -101,65 +90,6 @@ class Symbol(object):
check_call(_LIB.NNSymbolCompose(
self.handle, name, num_args, keys, args))
def __getitem__(self, index):
if isinstance(index, string_types):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetOutput(
self.handle, nn_uint(index), ctypes.byref(handle)))
return Symbol(handle=handle)
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNSymbolGetAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size = nn_uint()
pairs = ctypes.POINTER(ctypes.c_char_p)()
option = ctypes.c_int(0) if recursive else ctypes.c_int(1)
check_call(_LIB.NNSymbolListAttrs(
self.handle, option, ctypes.byref(size), ctypes.byref(pairs)))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size.value)}
def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
......@@ -168,116 +98,20 @@ class Symbol(object):
**kwargs
The attributes to set
"""
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()])
num_args = nn_uint(len(kwargs))
check_call(_LIB.NNSymbolSetAttrs(
keys = _base.c_array(_ctypes.c_char_p,
[_base.c_str(key) for key in kwargs.keys()])
vals = _base.c_array(_ctypes.c_char_p,
[_base.c_str(str(val)) for val in kwargs.values()])
num_args = _base.nn_uint(len(kwargs))
_check_call(_LIB.NNSymbolSetAttrs(
self.handle, num_args, keys, vals))
def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle = SymbolHandle()
check_call(_LIB.NNSymbolGetInternals(
self.handle, ctypes.byref(handle)))
return Symbol(handle=handle)
def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListArguments(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self):
"""List all outputs in the symbol.
_symbol_cls = SymbolBase
Returns
-------
returns : list of string
List of all the outputs.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.NNSymbolListOutputs(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]
def debug_str(self):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str = ctypes.c_char_p()
check_call(_LIB.NNSymbolPrint(
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, string_types):
raise TypeError('Expect a string for variable `name`')
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
ret = Symbol(handle)
attr = AttrScope.current.get(kwargs)
if attr:
ret._set_attr(**attr)
return ret
def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateGroup(
nn_uint(len(ihandles)),
c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
return Symbol(handle)
def _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls
def _make_atomic_symbol_function(handle):
......@@ -332,7 +166,7 @@ def _make_atomic_symbol_function(handle):
attr = kwargs.pop('attr', None)
for k, v in kwargs.items():
if isinstance(v, Symbol):
if isinstance(v, SymbolBase):
symbol_kwargs[k] = v
else:
param_keys.append(c_str(k))
......@@ -351,7 +185,7 @@ def _make_atomic_symbol_function(handle):
raise TypeError(
'%s can only accept input'
'Symbols either as positional or keyword arguments, not both' % func_name)
s = Symbol(sym_handle)
s = _symbol_cls(sym_handle)
attr = AttrScope.current.get(attr)
if attr:
s._set_attr(**attr)
......@@ -373,11 +207,12 @@ def _init_symbol_module():
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
module_obj = sys.modules["nnvm.symbol"]
module_internal = sys.modules["nnvm._symbol_internal"]
for i in range(size.value):
hdl = SymbolHandle(plist[i])
function = _make_atomic_symbol_function(hdl)
if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function))
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
......
......@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
import sys as _sys
import ctypes as _ctypes
from numbers import Number as _Number
from .._base import NNVMError
from ..name import NameManager
from ..attribute import AttrScope
......@@ -64,8 +65,7 @@ cdef extern from "nnvm/c_api.h":
const char** keys,
SymbolHandle* args);
cdef class Symbol:
cdef class SymbolBase:
"""Symbol is symbolic graph."""
# handle for symbolic operator.
cdef SymbolHandle handle
......@@ -85,76 +85,6 @@ cdef class Symbol:
def handle(self):
return _ctypes.cast(<unsigned long>self.handle, _ctypes.c_void_p)
def __copy__(self):
return self.__deepcopy__()
def __deepcopy__(self, _ = None):
cdef SymbolHandle handle
CALL(NNSymbolCopy(self.handle, &handle))
return NewSymbol(handle)
def __getitem__(self, index):
if isinstance(index, str):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
cdef SymbolHandle handle
cdef nn_uint c_index = index
CALL(NNSymbolGetOutput(self.handle, c_index, &handle))
return NewSymbol(handle)
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
cdef const char* ret
cdef int success
key = c_str(key)
CALL(NNSymbolGetAttr(
self.handle, key, &ret, &success))
if success != 0:
return py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
cdef nn_uint size
cdef const char** pairs
cdef int option
option = 0 if recursive else 1
CALL(NNSymbolListAttrs(
self.handle, option, &size, &pairs))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size)}
def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
......@@ -165,49 +95,6 @@ cdef class Symbol:
"""
SymbolSetAttr(self.handle, kwargs)
def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
cdef SymbolHandle handle
CALL(NNSymbolGetInternals(self.handle, &handle))
return NewSymbol(handle)
def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
cdef nn_uint size
cdef const char ** sarr
CALL(NNSymbolListArguments(self.handle, &size, &sarr))
return [py_str(sarr[i]) for i in range(size)]
def list_outputs(self):
"""List all outputs in the symbol.
Returns
-------
returns : list of string
List of all the outputs.
"""
cdef nn_uint size
cdef const char ** sarr
CALL(NNSymbolListOutputs(self.handle, &size, &sarr))
return [py_str(sarr[i]) for i in range(size)]
def debug_str(self):
cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str))
return py_str(out_str)
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
cdef vector[string] sparam_keys
......@@ -224,34 +111,18 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
_symbol_cls = SymbolBase
def _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls
cdef NewSymbol(SymbolHandle handle):
"""Create a new symbol given handle"""
sym = Symbol(None)
sym.handle = handle
sym = _symbol_cls(None)
(<SymbolBase>sym).handle = handle
return sym
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
cdef SymbolHandle handle
name = c_str(name)
CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle)
cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
"""Create an atomic symbol function by handle and funciton name."""
cdef const char *name
......@@ -292,9 +163,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
if len(kwargs) != 0:
for k, v in kwargs.items():
if isinstance(v, Symbol):
if isinstance(v, SymbolBase):
ssymbol_keys.push_back(c_str(k))
symbol_args.push_back((<Symbol>v).handle)
symbol_args.push_back((<SymbolBase>v).handle)
else:
sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v)))
......@@ -304,9 +175,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
raise TypeError("compose only accept input Symbols\
either as positional or keyword arguments, not both")
for v in args:
if not isinstance(v, Symbol):
if not isinstance(v, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<Symbol>v).handle)
symbol_args.push_back((<SymbolBase>v).handle)
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
......@@ -344,46 +215,20 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
return creator
def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
cdef vector[SymbolHandle] ihandles
cdef SymbolHandle handle
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError("Expect Symbols in the list input")
ihandles.push_back((<Symbol>sym).handle)
if ihandles.size() == 0:
raise ValueError("expect at least one element in the input")
CALL(NNSymbolCreateGroup(<nn_uint>ihandles.size(),
&ihandles[0], &handle))
return NewSymbol(handle)
def _init_symbol_module():
"""List and add all the atomic symbol functions to current module."""
cdef AtomicSymbolCreator* plist
cdef nn_uint size
CALL(NNSymbolListAtomicSymbolCreators(&size, &plist))
module_obj = _sys.modules["nnvm.symbol"]
module_internal = _sys.modules["nnvm._symbol_internal"]
for i in range(size):
function = _make_atomic_symbol_function(plist[i])
if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function))
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
......@@ -2,13 +2,214 @@
from __future__ import absolute_import as _abs
import sys as _sys
import os as _os
import ctypes as _ctypes
from numbers import Number as _Number
from . import _base
from ._base import _LIB, check_call as _check_call
from . import _symbol_internal as _internal
from .attribute import AttrScope
# Use different verison of SymbolBase
# When possible, use cython to speedup part of computation.
try:
if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0:
from .ctypes.symbol import Symbol, Variable
from .ctypes.symbol import SymbolBase, _set_symbol_class
elif _sys.version_info >= (3, 0):
from ._cy3.symbol import Symbol, Variable, Group
from ._cy3.symbol import SymbolBase, _set_symbol_class
else:
from ._cy2.symbol import Symbol, Variable, Group
from ._cy2.symbol import SymbolBase, _set_symbol_class
except:
from .ctypes.symbol import Symbol, Variable, Group
from .ctypes.symbol import SymbolBase, _set_symbol_class
class Symbol(SymbolBase):
"""Symbol is basic operation unit for symbolic graph compostion."""
# disable dictionary storage, also do not have parent type.
__slots__ = []
def __add__(self, other):
if isinstance(other, Symbol):
return _internal.__add__symbol__(self, other)
elif isinstance(other, _Number):
return _internal.__add__scalar__(self, scalar=other)
else:
raise TypeError("type %s not supported" % str(type(other)))
def __copy__(self):
return self.__deepcopy__()
def __deepcopy__(self, _=None):
handle = _base.SymbolHandle()
_base.check_call(_LIB.NNSymbolCopy(self.handle,
_ctypes.byref(handle)))
return Symbol(handle)
def __getitem__(self, index):
if isinstance(index, _base.string_types):
idx = None
for i, name in enumerate(self.list_outputs()):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolGetOutput(
self.handle, _base.nn_uint(index), _ctypes.byref(handle)))
return Symbol(handle=handle)
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = _ctypes.c_char_p()
success = _ctypes.c_int()
_check_call(_LIB.NNSymbolGetAttr(
self.handle, c_str(key), _ctypes.byref(ret), _ctypes.byref(success)))
if success.value != 0:
return _base.py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size = _base.nn_uint()
pairs = _ctypes.POINTER(_ctypes.c_char_p)()
option = _ctypes.c_int(0) if recursive else _ctypes.c_int(1)
_check_call(_LIB.NNSymbolListAttrs(
self.handle, option, _ctypes.byref(size), _ctypes.byref(pairs)))
return {_base.py_str(pairs[i*2]): _base.py_str(pairs[i*2+1]) for i in range(size.value)}
def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolGetInternals(
self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle)
def list_arguments(self):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
_check_call(_LIB.NNSymbolListArguments(
self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
def list_outputs(self):
"""List all outputs in the symbol.
Returns
-------
returns : list of string
List of all the outputs.
"""
size = _ctypes.c_uint()
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
_check_call(_LIB.NNSymbolListOutputs(
self.handle, _ctypes.byref(size), _ctypes.byref(sarr)))
return [_base.py_str(sarr[i]) for i in range(size.value)]
def debug_str(self):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str = _ctypes.c_char_p()
_check_call(_LIB.NNSymbolPrint(
self.handle, _ctypes.byref(debug_str)))
return _base.py_str(debug_str.value)
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, _base.string_types):
raise TypeError('Expect a string for variable `name`')
handle = _base.SymbolHandle()
_base.check_call(_LIB.NNSymbolCreateVariable(
_base.c_str(name), _ctypes.byref(handle)))
ret = Symbol(handle)
attr = AttrScope.current.get(kwargs)
if attr:
ret._set_attr(**attr)
return ret
def Group(symbols):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolCreateGroup(
_base.nn_uint(len(ihandles)),
_base.c_array(_base.SymbolHandle, ihandles),
_ctypes.byref(handle)))
return Symbol(handle)
# Set the real symbol class to Symbol
_set_symbol_class(Symbol)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment