base.py 2.24 KB
Newer Older
tqchen committed
1
# coding: utf-8
# pylint: disable=invalid-name
"""ctypes library of TVM and helper functions."""
from __future__ import absolute_import

import sys
7
import os
tqchen committed
8 9 10 11 12 13 14 15
import ctypes
import numpy as np
from . import libinfo

#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
    string_types = (str,)
    numeric_types = (float, int, np.float32, np.int32)

    def py_str(x):
        """Decode UTF-8 bytes (e.g. a ``c_char_p.value``) into a Python str."""
        # Needed on Python 3: ctypes hands back ``bytes``, not ``str``.
        return x.decode('utf-8')
else:
    string_types = (basestring,)
    numeric_types = (float, int, long, np.float32, np.int32)

    def py_str(x):
        """Return ``x`` unchanged; Python 2 byte strings already are ``str``."""
        return x


class TVMError(Exception):
    """Error thrown by TVM function.

    Raised by ``check_call`` with the message reported by the TVM C API
    when a call returns a non-zero status.
    """

31

tqchen committed
32 33 34 35 36 37
def _load_lib():
    """Load the TVM shared library by searching possible paths.

    Returns
    -------
    lib : ctypes.CDLL
        Handle to the loaded library.

    lib_name : str
        Base file name of the library file that was loaded.
    """
    lib_path = libinfo.find_lib_path()
    # Only the first candidate returned by find_lib_path is loaded.
    # NOTE(review): assumes find_lib_path returns a non-empty list (or raises
    # itself when nothing is found) -- confirm in libinfo.
    path = lib_path[0]
    # RTLD_GLOBAL makes this library's symbols available to shared libraries
    # that are loaded after it.
    lib = ctypes.CDLL(path, ctypes.RTLD_GLOBAL)
    # Declare the error getter as returning char* so ctypes converts its
    # result into a byte string rather than an int.
    lib.TVMGetLastError.restype = ctypes.c_char_p
    return lib, os.path.basename(path)
tqchen committed
39 40 41 42

# version number
__version__ = libinfo.__version__
# library instance of TVM, loaded once when this module is imported
_LIB, _LIB_NAME = _load_lib()
# The FFI mode of TVM, taken from the TVM_FFI environment variable
# (defaults to "auto" when the variable is unset).
_FFI_MODE = os.environ.get("TVM_FFI", "auto")
tqchen committed
46 47

#----------------------------
# helper functions in ctypes.
#----------------------------
def check_call(ret):
    """Turn a failing C API return code into a Python exception.

    Every C API call should have its return value passed through this
    helper so errors are never silently ignored.

    Parameters
    ----------
    ret : int
        Return code of the C API call; zero means success.

    Raises
    ------
    TVMError
        When ``ret`` is non-zero. The message is the text reported by
        ``TVMGetLastError``.
    """
    if ret == 0:
        return
    error_message = py_str(_LIB.TVMGetLastError())
    raise TVMError(error_message)

64

tqchen committed
65 66 67 68 69 70 71 72 73 74 75 76 77 78
def c_str(string):
    """Convert a Python string into a ctypes ``char *``.

    Parameters
    ----------
    string : string type
        Python string to convert; it is encoded as UTF-8.

    Returns
    -------
    str : c_char_p
        A char pointer that can be passed to the C API.
    """
    utf8_bytes = string.encode(encoding='utf-8')
    return ctypes.c_char_p(utf8_bytes)

79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96

def c_array(ctype, values):
    """Create ctypes array from a python array

    Parameters
    ----------
    ctype : ctypes data type
        data type of the array we want to convert to

    values : iterable
        data content; any iterable (tuple, list, generator, ...) is accepted

    Returns
    -------
    out : ctypes array
        Created ctypes array holding the elements of ``values`` in order
    """
    # Materialize once so iterables without len() (e.g. generators) work.
    # tuple() of a tuple returns the same object, so tuples are not copied.
    values = tuple(values)
    return (ctype * len(values))(*values)