# base.pxi: Cython declarations for the TVM C runtime API bindings.
from ..base import TVMError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
import ctypes

cdef enum TVMTypeCode:
    # Integer type codes for values exchanged with the TVM C API. This is a
    # plain Cython enum (not an extern declaration), so the numbers below are
    # defined here; they presumably must match the C-side codes in
    # tvm/runtime/c_runtime_api.h -- keep them in sync with that header.
    kInt = 0
    kUInt = 1
    kFloat = 2
    kHandle = 3
    kNull = 4
    kTVMType = 5
    kTVMContext = 6
    kArrayHandle = 7
    kNodeHandle = 8
    kModuleHandle = 9
    kFuncHandle = 10
    kStr = 11
    kBytes = 12
    kNDArrayContainer = 13
    # NOTE(review): no code 14 is declared here; confirm against the C header
    # whether one exists and should be mirrored.
    # Codes from kExtBegin upward are presumably reserved for extension types.
    kExtBegin = 15

cdef extern from "tvm/runtime/c_runtime_api.h":
    # Struct layouts come from the C header at C-compile time. Cython only
    # needs the members it actually accesses, so these declarations describe
    # the visible fields rather than defining the layout.

    # Element type descriptor (presumably DLPack type code / bit width /
    # vector lanes -- confirm against the header).
    ctypedef struct DLDataType:
        uint8_t code
        uint8_t bits
        uint16_t lanes

    # Device descriptor: device kind plus device index.
    ctypedef struct DLContext:
        int device_type
        int device_id

    # Tensor view: data pointer, device, rank, element type, shape/strides
    # arrays of int64_t and a byte offset into `data`.
    ctypedef struct DLTensor:
        void* data
        DLContext ctx
        int ndim
        DLDataType dtype
        int64_t* shape
        int64_t* strides
        uint64_t byte_offset

    # DLTensor bundled with an opaque owner context and a deleter callback.
    ctypedef struct DLManagedTensor:
        DLTensor dl_tensor
        void* manager_ctx
        # Deleter callback; receives the managed tensor itself as argument.
        void (*deleter)(DLManagedTensor* self)

    # Value slot passed to/returned from packed functions; which member is
    # meaningful is indicated by an accompanying TVMTypeCode.
    # NOTE(review): declared as a struct here; the C header presumably defines
    # it as a union -- harmless for an extern ctypedef as long as only member
    # access is used, but confirm before constructing it by value.
    ctypedef struct TVMValue:
        int64_t v_int64
        double v_float64
        void* v_handle
        const char* v_str
        DLDataType v_type
        DLContext v_ctx
# Scalar and opaque-handle aliases used by the C API declarations.
ctypedef int64_t tvm_index_t
ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle

# C callback signature accepted by TVMFuncCreateFromCFunc: receives
# `num_args` values with their parallel `type_codes`, a handle through which
# to set the return value, and the user `resource_handle`. Returns an int
# status (presumably 0 on success -- confirm against the C header).
ctypedef int (*TVMPackedCFunc)(
    TVMValue* args,
    int* type_codes,
    int num_args,
    TVMRetValueHandle ret,
    void* resource_handle)

# Finalizer paired with a TVMPackedCFunc; receives the same resource_handle
# (presumably invoked when the created function is freed).
ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)

cdef extern from "tvm/runtime/c_runtime_api.h":
    # The int-returning functions presumably report failure with a non-zero
    # return value; callers are expected to wrap them in CALL(), which then
    # raises with the message from TVMGetLastError().

    # Error reporting.
    void TVMAPISetLastError(const char* msg)
    const char *TVMGetLastError()

    # Packed-function invocation, return-value plumbing and lifetime.
    int TVMFuncCall(TVMFunctionHandle func,
                    TVMValue* arg_values,
                    int* type_codes,
                    int num_args,
                    TVMValue* ret_val,
                    int* ret_type_code)
    int TVMFuncFree(TVMFunctionHandle func)
    int TVMCFuncSetReturn(TVMRetValueHandle ret,
                          TVMValue* value,
                          int* type_code,
                          int num_ret)
    int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                               void* resource_handle,
                               TVMPackedCFuncFinalizer fin,
                               TVMFunctionHandle *out)
    int TVMCbArgToReturn(TVMValue* value, int code)

    # Array (DLTensor) allocation, release and copy.
    int TVMArrayAlloc(tvm_index_t* shape,
                      tvm_index_t ndim,
                      DLDataType dtype,
                      DLContext ctx,
                      DLTensorHandle* out)
    int TVMArrayFree(DLTensorHandle handle)
    int TVMArrayCopyFromTo(DLTensorHandle src,
                           DLTensorHandle to,
                           TVMStreamHandle stream)

    # DLPack interoperability.
    int TVMArrayFromDLPack(DLManagedTensor* arr_from,
                           DLTensorHandle* out)
    int TVMArrayToDLPack(DLTensorHandle arr_from,
                         DLManagedTensor** out)
    void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef extern from "tvm/c_dsl_api.h":
    # Node-object accessors. Results are written through the `out_*`
    # pointers; the int return is presumably a status code checked via CALL().
    int TVMNodeFree(NodeHandle handle)
    int TVMNodeTypeKey2Index(const char* type_key,
                             int* out_index)
    int TVMNodeGetTypeIndex(NodeHandle handle,
                            int* out_index)
    # `out_success` presumably reports whether attribute `key` was found --
    # confirm against the C header.
    int TVMNodeGetAttr(NodeHandle handle,
                       const char* key,
                       TVMValue* out_value,
                       int* out_type_code,
                       int* out_success)

cdef inline py_str(const char* x):
    """Convert a NUL-terminated C string into the native Python ``str``.

    Python 3: the bytes are decoded as UTF-8 and a text ``str`` is returned.
    Python 2: the raw byte string is returned unchanged (``str`` is bytes).
    """
    if PY_MAJOR_VERSION >= 3:
        return x.decode("utf-8")
    return x


cdef inline c_str(pystr):
    """Encode a Python string as UTF-8 bytes suitable for a ``const char*``.

    Parameters
    ----------
    pystr : str or bytes
        Text to encode. A ``bytes`` object is assumed to be already encoded
        and is returned unchanged.

    Returns
    -------
    bytes
        The UTF-8 encoded string. Cython can coerce it to ``const char*``;
        the caller must keep the returned object alive while the pointer is
        in use.
    """
    # Accept pre-encoded input instead of failing on bytes.encode (Python 3).
    if isinstance(pystr, bytes):
        return pystr
    return pystr.encode("utf-8")


cdef inline CALL(int ret):
    """Check the return code of a TVM C API call.

    Parameters
    ----------
    ret : int
        Value returned by the C function; any non-zero value is a failure.

    Raises
    ------
    TVMError
        Carrying the message reported by ``TVMGetLastError()``.
    """
    if ret != 0:
        raise TVMError(py_str(TVMGetLastError()))


cdef inline object ctypes_handle(void* chandle):
    """Wrap a raw C pointer in a ``ctypes.c_void_p`` object.

    The pointer is first widened to an integer address and then cast to
    ``ctypes.c_void_p``.
    """
    cdef unsigned long long address = <unsigned long long>chandle
    return ctypes.cast(address, ctypes.c_void_p)


cdef inline void* c_handle(object handle):
    """Extract the raw C pointer from a ctypes handle.

    Parameters
    ----------
    handle : ctypes.c_void_p
        Object exposing the pointer address through ``.value``.

    Returns
    -------
    void*
        The wrapped pointer, or NULL when ``handle.value`` is None, which is
        how ctypes represents a null ``c_void_p`` (for example the result of
        ``ctypes_handle(NULL)``).
    """
    cdef unsigned long long v_ptr
    cdef object value = handle.value
    # A null c_void_p reports value None; converting None to an integer would
    # raise TypeError, so map it back to NULL explicitly.
    if value is None:
        return NULL
    v_ptr = value
    return <void*>(v_ptr)