Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c1e48e1a
Commit
c1e48e1a
authored
Jul 17, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[CYTHON] Make speedup component minimum (#13)
parent
31f9fc0a
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
241 additions
and
359 deletions
+241
-359
nnvm/python/nnvm/_symbol_internal.py
+1
-0
nnvm/python/nnvm/ctypes/symbol.py
+18
-183
nnvm/python/nnvm/cython/symbol.pyx
+17
-172
nnvm/python/nnvm/symbol.py
+205
-4
No files found.
nnvm/python/nnvm/_symbol_internal.py
0 → 100644
View file @
c1e48e1a
"""Module space to register internal functions. Leave empty"""
nnvm/python/nnvm/ctypes/symbol.py
View file @
c1e48e1a
...
...
@@ -13,11 +13,9 @@ from .._base import check_call, ctypes2docstring
from
..name
import
NameManager
from
..attribute
import
AttrScope
__all__
=
[
"Symbol"
,
"Variable"
]
class
Symbol
(
object
):
class
SymbolBase
(
object
):
"""Symbol is symbolic graph."""
__slots__
=
[
"handle"
]
# pylint: disable=no-member
def
__init__
(
self
,
handle
):
"""Initialize the function with handle
...
...
@@ -32,15 +30,6 @@ class Symbol(object):
def
__del__
(
self
):
check_call
(
_LIB
.
NNSymbolFree
(
self
.
handle
))
def
__copy__
(
self
):
return
copy
.
deepcopy
(
self
)
def
__deepcopy__
(
self
,
_
):
handle
=
SymbolHandle
()
check_call
(
_LIB
.
NNSymbolCopy
(
self
.
handle
,
ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""Invoke symbol as function on inputs.
...
...
@@ -85,10 +74,10 @@ class Symbol(object):
either as positional or keyword arguments, not both'
)
for
arg
in
args
:
if
not
isinstance
(
arg
,
Symbol
):
if
not
isinstance
(
arg
,
Symbol
Base
):
raise
TypeError
(
'Compose expect `Symbol` as arguments'
)
for
val
in
kwargs
.
values
():
if
not
isinstance
(
val
,
Symbol
):
if
not
isinstance
(
val
,
Symbol
Base
):
raise
TypeError
(
'Compose expect `Symbol` as arguments'
)
num_args
=
len
(
args
)
+
len
(
kwargs
)
...
...
@@ -101,65 +90,6 @@ class Symbol(object):
check_call
(
_LIB
.
NNSymbolCompose
(
self
.
handle
,
name
,
num_args
,
keys
,
args
))
def
__getitem__
(
self
,
index
):
if
isinstance
(
index
,
string_types
):
idx
=
None
for
i
,
name
in
enumerate
(
self
.
list_outputs
()):
if
name
==
index
:
if
idx
is
not
None
:
raise
ValueError
(
'There are multiple outputs with name
\"
%
s
\"
'
%
index
)
idx
=
i
if
idx
is
None
:
raise
ValueError
(
'Cannot find output that matches name
\"
%
s
\"
'
%
index
)
index
=
idx
if
not
isinstance
(
index
,
int
):
raise
TypeError
(
'Symbol only support integer index to fetch i-th output'
)
handle
=
SymbolHandle
()
check_call
(
_LIB
.
NNSymbolGetOutput
(
self
.
handle
,
nn_uint
(
index
),
ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
=
handle
)
def
attr
(
self
,
key
):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret
=
ctypes
.
c_char_p
()
success
=
ctypes
.
c_int
()
check_call
(
_LIB
.
NNSymbolGetAttr
(
self
.
handle
,
c_str
(
key
),
ctypes
.
byref
(
ret
),
ctypes
.
byref
(
success
)))
if
success
.
value
!=
0
:
return
py_str
(
ret
.
value
)
else
:
return
None
def
list_attr
(
self
,
recursive
=
False
):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size
=
nn_uint
()
pairs
=
ctypes
.
POINTER
(
ctypes
.
c_char_p
)()
option
=
ctypes
.
c_int
(
0
)
if
recursive
else
ctypes
.
c_int
(
1
)
check_call
(
_LIB
.
NNSymbolListAttrs
(
self
.
handle
,
option
,
ctypes
.
byref
(
size
),
ctypes
.
byref
(
pairs
)))
return
{
py_str
(
pairs
[
i
*
2
]):
py_str
(
pairs
[
i
*
2
+
1
])
for
i
in
range
(
size
.
value
)}
def
_set_attr
(
self
,
**
kwargs
):
"""Set the attribute of the symbol.
...
...
@@ -168,116 +98,20 @@ class Symbol(object):
**kwargs
The attributes to set
"""
keys
=
c_array
(
ctypes
.
c_char_p
,
[
c_str
(
key
)
for
key
in
kwargs
.
keys
()])
vals
=
c_array
(
ctypes
.
c_char_p
,
[
c_str
(
str
(
val
))
for
val
in
kwargs
.
values
()])
num_args
=
nn_uint
(
len
(
kwargs
))
check_call
(
_LIB
.
NNSymbolSetAttrs
(
keys
=
_base
.
c_array
(
_ctypes
.
c_char_p
,
[
_base
.
c_str
(
key
)
for
key
in
kwargs
.
keys
()])
vals
=
_base
.
c_array
(
_ctypes
.
c_char_p
,
[
_base
.
c_str
(
str
(
val
))
for
val
in
kwargs
.
values
()])
num_args
=
_base
.
nn_uint
(
len
(
kwargs
))
_check_call
(
_LIB
.
NNSymbolSetAttrs
(
self
.
handle
,
num_args
,
keys
,
vals
))
def
get_internals
(
self
):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle
=
SymbolHandle
()
check_call
(
_LIB
.
NNSymbolGetInternals
(
self
.
handle
,
ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
=
handle
)
def
list_arguments
(
self
):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size
=
ctypes
.
c_uint
()
sarr
=
ctypes
.
POINTER
(
ctypes
.
c_char_p
)()
check_call
(
_LIB
.
NNSymbolListArguments
(
self
.
handle
,
ctypes
.
byref
(
size
),
ctypes
.
byref
(
sarr
)))
return
[
py_str
(
sarr
[
i
])
for
i
in
range
(
size
.
value
)]
def
list_outputs
(
self
):
"""List all outputs in the symbol.
_symbol_cls
=
SymbolBase
Returns
-------
returns : list of string
List of all the outputs.
"""
size
=
ctypes
.
c_uint
()
sarr
=
ctypes
.
POINTER
(
ctypes
.
c_char_p
)()
check_call
(
_LIB
.
NNSymbolListOutputs
(
self
.
handle
,
ctypes
.
byref
(
size
),
ctypes
.
byref
(
sarr
)))
return
[
py_str
(
sarr
[
i
])
for
i
in
range
(
size
.
value
)]
def
debug_str
(
self
):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str
=
ctypes
.
c_char_p
()
check_call
(
_LIB
.
NNSymbolPrint
(
self
.
handle
,
ctypes
.
byref
(
debug_str
)))
return
py_str
(
debug_str
.
value
)
def
Variable
(
name
,
**
kwargs
):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if
not
isinstance
(
name
,
string_types
):
raise
TypeError
(
'Expect a string for variable `name`'
)
handle
=
SymbolHandle
()
check_call
(
_LIB
.
NNSymbolCreateVariable
(
c_str
(
name
),
ctypes
.
byref
(
handle
)))
ret
=
Symbol
(
handle
)
attr
=
AttrScope
.
current
.
get
(
kwargs
)
if
attr
:
ret
.
_set_attr
(
**
attr
)
return
ret
def
Group
(
symbols
):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles
=
[]
for
sym
in
symbols
:
if
not
isinstance
(
sym
,
Symbol
):
raise
TypeError
(
'Expect Symbols in the list input'
)
ihandles
.
append
(
sym
.
handle
)
handle
=
SymbolHandle
()
check_call
(
_LIB
.
NNSymbolCreateGroup
(
nn_uint
(
len
(
ihandles
)),
c_array
(
SymbolHandle
,
ihandles
),
ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
)
def
_set_symbol_class
(
cls
):
global
_symbol_cls
_symbol_cls
=
cls
def
_make_atomic_symbol_function
(
handle
):
...
...
@@ -332,7 +166,7 @@ def _make_atomic_symbol_function(handle):
attr
=
kwargs
.
pop
(
'attr'
,
None
)
for
k
,
v
in
kwargs
.
items
():
if
isinstance
(
v
,
Symbol
):
if
isinstance
(
v
,
Symbol
Base
):
symbol_kwargs
[
k
]
=
v
else
:
param_keys
.
append
(
c_str
(
k
))
...
...
@@ -351,7 +185,7 @@ def _make_atomic_symbol_function(handle):
raise
TypeError
(
'
%
s can only accept input'
'Symbols either as positional or keyword arguments, not both'
%
func_name
)
s
=
Symbol
(
sym_handle
)
s
=
_symbol_cls
(
sym_handle
)
attr
=
AttrScope
.
current
.
get
(
attr
)
if
attr
:
s
.
_set_attr
(
**
attr
)
...
...
@@ -373,11 +207,12 @@ def _init_symbol_module():
check_call
(
_LIB
.
NNSymbolListAtomicSymbolCreators
(
ctypes
.
byref
(
size
),
ctypes
.
byref
(
plist
)))
module_obj
=
sys
.
modules
[
"nnvm.symbol"
]
module_internal
=
sys
.
modules
[
"nnvm._symbol_internal"
]
for
i
in
range
(
size
.
value
):
hdl
=
SymbolHandle
(
plist
[
i
])
function
=
_make_atomic_symbol_function
(
hdl
)
if
function
.
__name__
.
startswith
(
'_'
):
setattr
(
Symbol
,
function
.
__name__
,
staticmethod
(
function
)
)
setattr
(
module_internal
,
function
.
__name__
,
function
)
else
:
setattr
(
module_obj
,
function
.
__name__
,
function
)
...
...
nnvm/python/nnvm/cython/symbol.pyx
View file @
c1e48e1a
...
...
@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
import
sys
as
_sys
import
ctypes
as
_ctypes
from
numbers
import
Number
as
_Number
from
..
_base
import
NNVMError
from
..
name
import
NameManager
from
..
attribute
import
AttrScope
...
...
@@ -64,8 +65,7 @@ cdef extern from "nnvm/c_api.h":
const
char
**
keys
,
SymbolHandle
*
args
);
cdef
class
Symbol
:
cdef
class
SymbolBase
:
"""Symbol is symbolic graph."""
# handle for symbolic operator.
cdef
SymbolHandle
handle
...
...
@@ -85,76 +85,6 @@ cdef class Symbol:
def
handle
(
self
)
:
return
_ctypes
.
cast
(
<
unsigned
long
>
self
.
handle
,
_ctypes
.
c_void_p
)
def
__copy__
(
self
)
:
return
self
.
__deepcopy__
()
def
__deepcopy__
(
self
,
_
=
None
)
:
cdef
SymbolHandle
handle
CALL
(
NNSymbolCopy
(
self
.
handle
,
&
handle
))
return
NewSymbol
(
handle
)
def
__getitem__
(
self
,
index
)
:
if
isinstance
(
index
,
str
)
:
idx
=
None
for
i
,
name
in
enumerate
(
self
.
list_outputs
())
:
if
name
==
index
:
if
idx
is
not
None
:
raise
ValueError
(
'
There
are
multiple
outputs
with
name
\
"%s
\"
' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name
\"
%s
\"
' % index)
index = idx
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
cdef SymbolHandle handle
cdef nn_uint c_index = index
CALL(NNSymbolGetOutput(self.handle, c_index, &handle))
return NewSymbol(handle)
def attr(self, key):
"""
Get
attribute
string
from
the
symbol
,
this
function
only
works
for
non
-
grouped
symbol
.
Parameters
----------
key
:
str
The
key
to
get
attribute
from
.
Returns
-------
value
:
str
The
attribute
value
of
the
key
,
returns
None
if
attribute
do
not
exist
.
"""
cdef const char* ret
cdef int success
key = c_str(key)
CALL(NNSymbolGetAttr(
self.handle, key, &ret, &success))
if success != 0:
return py_str(ret.value)
else:
return None
def list_attr(self, recursive=False):
"""
Get
all
attributes
from
the
symbol
.
Parameters
----------
recursive
:
bool
Default
`
False
`
.
When
`
recursive
`
is
`
True
`
,
list
recursively
all
the
attributes
in
the
descendents
.
The
attribute
names
are
pre
-
pended
with
the
symbol
names
to
avoid
conflicts
.
If
`
False
`
,
then
only
attributes
that
belongs
to
this
symbol
is
returned
,
and
the
attribute
names
will
**
not
**
be
pre
-
pended
with
the
symbol
name
.
"""
cdef nn_uint size
cdef const char** pairs
cdef int option
option = 0 if recursive else 1
CALL(NNSymbolListAttrs(
self.handle, option, &size, &pairs))
return {py_str(pairs[i*2]): py_str(pairs[i*2+1]) for i in range(size)}
def
_set_attr
(
self
,
**
kwargs
)
:
"""Set the attribute of the symbol.
...
...
@@ -165,49 +95,6 @@ cdef class Symbol:
"""
SymbolSetAttr
(
self
.
handle
,
kwargs
)
def get_internals(self):
"""
Get
a
new
grouped
symbol
whose
output
contains
all
the
internal
outputs
of
this
symbol
.
Returns
-------
sgroup
:
Symbol
The
internal
of
the
symbol
.
"""
cdef SymbolHandle handle
CALL(NNSymbolGetInternals(self.handle, &handle))
return NewSymbol(handle)
def list_arguments(self):
"""
List
all
the
arguments
in
the
symbol
.
Returns
-------
args
:
list
of
string
List
of
all
the
arguments
.
"""
cdef nn_uint size
cdef const char ** sarr
CALL(NNSymbolListArguments(self.handle, &size, &sarr))
return [py_str(sarr[i]) for i in range(size)]
def list_outputs(self):
"""
List
all
outputs
in
the
symbol
.
Returns
-------
returns
:
list
of
string
List
of
all
the
outputs
.
"""
cdef nn_uint size
cdef const char ** sarr
CALL(NNSymbolListOutputs(self.handle, &size, &sarr))
return [py_str(sarr[i]) for i in range(size)]
def debug_str(self):
cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str))
return py_str(out_str)
cdef
SymbolSetAttr
(
SymbolHandle
handle
,
dict
kwargs
)
:
cdef
vector
[
string
]
sparam_keys
...
...
@@ -224,34 +111,18 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
handle
,
num_args
,
CBeginPtr
(
param_keys
),
CBeginPtr
(
param_vals
)))
_symbol_cls
=
SymbolBase
def
_set_symbol_class
(
cls
)
:
global
_symbol_cls
_symbol_cls
=
cls
cdef
NewSymbol
(
SymbolHandle
handle
)
:
"""Create a new symbol given handle"""
sym =
Symbol
(None)
sym
.handle = handle
sym
=
_symbol_cls
(
None
)
(
<
SymbolBase
>
sym
)
.
handle
=
handle
return
sym
def Variable(name, **kwargs):
"""
Create
a
symbolic
variable
with
specified
name
.
Parameters
----------
name
:
str
Name
of
the
variable
.
kwargs
:
dict
of
string
->
string
Additional
attributes
to
set
on
the
variable
.
Returns
-------
variable
:
Symbol
The
created
variable
symbol
.
"""
cdef SymbolHandle handle
name = c_str(name)
CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle)
cdef
_make_atomic_symbol_function
(
AtomicSymbolCreator
handle
)
:
"""Create an atomic symbol function by handle and funciton name."""
cdef
const
char
*
name
...
...
@@ -292,9 +163,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
if
len
(
kwargs
)
!=
0
:
for
k
,
v
in
kwargs
.
items
()
:
if isinstance(v, Symbol):
if
isinstance
(
v
,
Symbol
Base
)
:
ssymbol_keys
.
push_back
(
c_str
(
k
))
symbol_args.push_back((<Symbol>v).handle)
symbol_args
.
push_back
((
<
Symbol
Base
>
v
).
handle
)
else
:
sparam_keys
.
push_back
(
c_str
(
k
))
sparam_vals
.
push_back
(
c_str
(
str
(
v
)))
...
...
@@ -304,9 +175,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
raise
TypeError
(
"compose only accept input Symbols\
either as positional or keyword arguments, not both"
)
for
v
in
args
:
if not isinstance(v, Symbol):
if
not
isinstance
(
v
,
Symbol
Base
)
:
raise
TypeError
(
'
Compose
expect
`
Symbol
`
as
arguments
'
)
symbol_args.push_back((<Symbol>v).handle)
symbol_args
.
push_back
((
<
Symbol
Base
>
v
).
handle
)
cdef
vector
[
const
char
*
]
param_keys
=
SVec2Ptr
(
sparam_keys
)
cdef
vector
[
const
char
*
]
param_vals
=
SVec2Ptr
(
sparam_vals
)
...
...
@@ -344,46 +215,20 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
return
creator
def Group(symbols):
"""
Create
a
symbol
that
groups
symbols
together
.
Parameters
----------
symbols
:
list
List
of
symbols
to
be
grouped
.
Returns
-------
sym
:
Symbol
The
created
group
symbol
.
"""
cdef vector[SymbolHandle] ihandles
cdef SymbolHandle handle
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError("
Expect
Symbols
in
the
list
input
")
ihandles.push_back((<Symbol>sym).handle)
if ihandles.size() == 0:
raise ValueError("
expect
at
least
one
element
in
the
input
")
CALL(NNSymbolCreateGroup(<nn_uint>ihandles.size(),
&ihandles[0], &handle))
return NewSymbol(handle)
def
_init_symbol_module
()
:
"""List and add all the atomic symbol functions to current module."""
cdef
AtomicSymbolCreator
*
plist
cdef
nn_uint
size
CALL
(
NNSymbolListAtomicSymbolCreators
(
&
size
,
&
plist
))
module_obj
=
_sys
.
modules
[
"nnvm.symbol"
]
module_internal
=
_sys
.
modules
[
"nnvm._symbol_internal"
]
for
i
in
range
(
size
)
:
function
=
_make_atomic_symbol_function
(
plist
[
i
])
if
function
.
__name__
.
startswith
(
'_'
)
:
setattr(
Symbol, function.__name__, staticmethod(function)
)
setattr
(
module_internal
,
function
.
__name__
,
function
)
else
:
setattr
(
module_obj
,
function
.
__name__
,
function
)
# Initialize the atomic symbol in startups
_init_symbol_module
()
nnvm/python/nnvm/symbol.py
View file @
c1e48e1a
...
...
@@ -2,13 +2,214 @@
from
__future__
import
absolute_import
as
_abs
import
sys
as
_sys
import
os
as
_os
import
ctypes
as
_ctypes
from
numbers
import
Number
as
_Number
from
.
import
_base
from
._base
import
_LIB
,
check_call
as
_check_call
from
.
import
_symbol_internal
as
_internal
from
.attribute
import
AttrScope
# Use different verison of SymbolBase
# When possible, use cython to speedup part of computation.
try
:
if
int
(
_os
.
environ
.
get
(
"NNVM_ENABLE_CYTHON"
,
True
))
==
0
:
from
.ctypes.symbol
import
Symbol
,
Variable
from
.ctypes.symbol
import
Symbol
Base
,
_set_symbol_class
elif
_sys
.
version_info
>=
(
3
,
0
):
from
._cy3.symbol
import
Symbol
,
Variable
,
Group
from
._cy3.symbol
import
Symbol
Base
,
_set_symbol_class
else
:
from
._cy2.symbol
import
Symbol
,
Variable
,
Group
from
._cy2.symbol
import
Symbol
Base
,
_set_symbol_class
except
:
from
.ctypes.symbol
import
Symbol
,
Variable
,
Group
from
.ctypes.symbol
import
SymbolBase
,
_set_symbol_class
class
Symbol
(
SymbolBase
):
"""Symbol is basic operation unit for symbolic graph compostion."""
# disable dictionary storage, also do not have parent type.
__slots__
=
[]
def
__add__
(
self
,
other
):
if
isinstance
(
other
,
Symbol
):
return
_internal
.
__add__symbol__
(
self
,
other
)
elif
isinstance
(
other
,
_Number
):
return
_internal
.
__add__scalar__
(
self
,
scalar
=
other
)
else
:
raise
TypeError
(
"type
%
s not supported"
%
str
(
type
(
other
)))
def
__copy__
(
self
):
return
self
.
__deepcopy__
()
def
__deepcopy__
(
self
,
_
=
None
):
handle
=
_base
.
SymbolHandle
()
_base
.
check_call
(
_LIB
.
NNSymbolCopy
(
self
.
handle
,
_ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
)
def
__getitem__
(
self
,
index
):
if
isinstance
(
index
,
_base
.
string_types
):
idx
=
None
for
i
,
name
in
enumerate
(
self
.
list_outputs
()):
if
name
==
index
:
if
idx
is
not
None
:
raise
ValueError
(
'There are multiple outputs with name
\"
%
s
\"
'
%
index
)
idx
=
i
if
idx
is
None
:
raise
ValueError
(
'Cannot find output that matches name
\"
%
s
\"
'
%
index
)
index
=
idx
if
not
isinstance
(
index
,
int
):
raise
TypeError
(
'Symbol only support integer index to fetch i-th output'
)
handle
=
_base
.
SymbolHandle
()
_check_call
(
_LIB
.
NNSymbolGetOutput
(
self
.
handle
,
_base
.
nn_uint
(
index
),
_ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
=
handle
)
def
attr
(
self
,
key
):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret
=
_ctypes
.
c_char_p
()
success
=
_ctypes
.
c_int
()
_check_call
(
_LIB
.
NNSymbolGetAttr
(
self
.
handle
,
c_str
(
key
),
_ctypes
.
byref
(
ret
),
_ctypes
.
byref
(
success
)))
if
success
.
value
!=
0
:
return
_base
.
py_str
(
ret
.
value
)
else
:
return
None
def
list_attr
(
self
,
recursive
=
False
):
"""Get all attributes from the symbol.
Parameters
----------
recursive : bool
Default `False`. When `recursive` is `True`, list recursively all the
attributes in the descendents. The attribute names are pre-pended with
the symbol names to avoid conflicts. If `False`, then only attributes
that belongs to this symbol is returned, and the attribute names will
**not** be pre-pended with the symbol name.
"""
size
=
_base
.
nn_uint
()
pairs
=
_ctypes
.
POINTER
(
_ctypes
.
c_char_p
)()
option
=
_ctypes
.
c_int
(
0
)
if
recursive
else
_ctypes
.
c_int
(
1
)
_check_call
(
_LIB
.
NNSymbolListAttrs
(
self
.
handle
,
option
,
_ctypes
.
byref
(
size
),
_ctypes
.
byref
(
pairs
)))
return
{
_base
.
py_str
(
pairs
[
i
*
2
]):
_base
.
py_str
(
pairs
[
i
*
2
+
1
])
for
i
in
range
(
size
.
value
)}
def
get_internals
(
self
):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : Symbol
The internal of the symbol.
"""
handle
=
_base
.
SymbolHandle
()
_check_call
(
_LIB
.
NNSymbolGetInternals
(
self
.
handle
,
_ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
=
handle
)
def
list_arguments
(
self
):
"""List all the arguments in the symbol.
Returns
-------
args : list of string
List of all the arguments.
"""
size
=
_ctypes
.
c_uint
()
sarr
=
_ctypes
.
POINTER
(
_ctypes
.
c_char_p
)()
_check_call
(
_LIB
.
NNSymbolListArguments
(
self
.
handle
,
_ctypes
.
byref
(
size
),
_ctypes
.
byref
(
sarr
)))
return
[
_base
.
py_str
(
sarr
[
i
])
for
i
in
range
(
size
.
value
)]
def
list_outputs
(
self
):
"""List all outputs in the symbol.
Returns
-------
returns : list of string
List of all the outputs.
"""
size
=
_ctypes
.
c_uint
()
sarr
=
_ctypes
.
POINTER
(
_ctypes
.
c_char_p
)()
_check_call
(
_LIB
.
NNSymbolListOutputs
(
self
.
handle
,
_ctypes
.
byref
(
size
),
_ctypes
.
byref
(
sarr
)))
return
[
_base
.
py_str
(
sarr
[
i
])
for
i
in
range
(
size
.
value
)]
def
debug_str
(
self
):
"""Get a debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
"""
debug_str
=
_ctypes
.
c_char_p
()
_check_call
(
_LIB
.
NNSymbolPrint
(
self
.
handle
,
_ctypes
.
byref
(
debug_str
)))
return
_base
.
py_str
(
debug_str
.
value
)
def
Variable
(
name
,
**
kwargs
):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if
not
isinstance
(
name
,
_base
.
string_types
):
raise
TypeError
(
'Expect a string for variable `name`'
)
handle
=
_base
.
SymbolHandle
()
_base
.
check_call
(
_LIB
.
NNSymbolCreateVariable
(
_base
.
c_str
(
name
),
_ctypes
.
byref
(
handle
)))
ret
=
Symbol
(
handle
)
attr
=
AttrScope
.
current
.
get
(
kwargs
)
if
attr
:
ret
.
_set_attr
(
**
attr
)
return
ret
def
Group
(
symbols
):
"""Create a symbol that groups symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles
=
[]
for
sym
in
symbols
:
if
not
isinstance
(
sym
,
Symbol
):
raise
TypeError
(
'Expect Symbols in the list input'
)
ihandles
.
append
(
sym
.
handle
)
handle
=
_base
.
SymbolHandle
()
_check_call
(
_LIB
.
NNSymbolCreateGroup
(
_base
.
nn_uint
(
len
(
ihandles
)),
_base
.
c_array
(
_base
.
SymbolHandle
,
ihandles
),
_ctypes
.
byref
(
handle
)))
return
Symbol
(
handle
)
# Set the real symbol class to Symbol
_set_symbol_class
(
Symbol
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment