# base.pxi
from ..base import TVMError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
4
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
5 6 7 8 9 10 11 12
import ctypes

# Type codes tagging TVMValue entries. Presumably these are the values carried
# in the `type_codes` arrays of the packed-function API — confirm against
# tvm/runtime/c_runtime_api.h.
cdef enum TVMTypeCode:
    kInt = 0
    kUInt = 1
    kFloat = 2
    kHandle = 3
    kNull = 4
    kTVMType = 5
    kTVMContext = 6
    kArrayHandle = 7
    kNodeHandle = 8
    kModuleHandle = 9
    kFuncHandle = 10
    kStr = 11
    kBytes = 12
    # Codes 13-14 are not declared here; kExtBegin presumably marks the start
    # of extension type codes (by its name) — confirm against the C header.
    kExtBegin = 15

cdef extern from "tvm/runtime/c_runtime_api.h":
    # Element data type descriptor: type code, bit width and vector lanes.
    ctypedef struct DLDataType:
        uint8_t code
        uint8_t bits
        uint16_t lanes

    # Device descriptor: device type plus device id.
    ctypedef struct DLContext:
        int device_type
        int device_id

    # Tensor descriptor: data pointer, device, rank, dtype, per-dimension
    # shape/strides arrays and a byte offset into `data`.
    ctypedef struct DLTensor:
        void* data
        DLContext ctx
        int ndim
        DLDataType dtype
        int64_t* shape
        int64_t* strides
        uint64_t byte_offset

    # Value slot used by the packed-function calls declared in this file.
    # Presumably only the field selected by the accompanying type code is
    # meaningful — confirm against the C header.
    ctypedef struct TVMValue:
        int64_t v_int64
        double v_float64
        void* v_handle
        const char* v_str
        DLDataType v_type
        DLContext v_ctx

# Index and opaque handle types used by the C API declarations in this file.
ctypedef int64_t tvm_index_t
ctypedef void* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle

# C callback signature accepted by TVMFuncCreateFromCFunc.
ctypedef int (*TVMPackedCFunc)(
    TVMValue* args,
    int* type_codes,
    int num_args,
    TVMRetValueHandle ret,
    void* resource_handle)

# Finalizer passed to TVMFuncCreateFromCFunc together with a TVMPackedCFunc;
# presumably invoked with the same resource_handle — confirm in the C header.
ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)

cdef extern from "tvm/runtime/c_runtime_api.h":
    # Functions declared `int` return a status code; CALL in this file treats
    # any non-zero status as failure and reads the message via TVMGetLastError.

    # Error reporting.
    void TVMAPISetLastError(const char* msg)
    const char* TVMGetLastError()

    # Packed-function invocation, return-value setting and lifetime.
    int TVMFuncCall(TVMFunctionHandle func,
                    TVMValue* arg_values,
                    int* type_codes,
                    int num_args,
                    TVMValue* ret_val,
                    int* ret_type_code)
    int TVMFuncFree(TVMFunctionHandle func)
    int TVMCFuncSetReturn(TVMRetValueHandle ret,
                          TVMValue* value,
                          int* type_code,
                          int num_ret)
    int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                               void* resource_handle,
                               TVMPackedCFuncFinalizer fin,
                               TVMFunctionHandle *out)
    int TVMCbArgToReturn(TVMValue* value, int code)

    # Array allocation, release and copy.
    int TVMArrayAlloc(tvm_index_t* shape,
                      tvm_index_t ndim,
                      DLDataType dtype,
                      DLContext ctx,
                      DLTensorHandle* out)
    int TVMArrayFree(DLTensorHandle handle)
    int TVMArrayCopyFromTo(DLTensorHandle src,
                           DLTensorHandle to,
                           TVMStreamHandle stream)

cdef extern from "tvm/c_dsl_api.h":
    # Node handle management and attribute/type reflection.
    int TVMNodeFree(NodeHandle handle)
    # Declared `int` like every other function here; an omitted return type
    # would make Cython treat the result as a Python object.
    int TVMNodeTypeKey2Index(const char* type_key,
                             int* out_index)
    int TVMNodeGetTypeIndex(NodeHandle handle,
                            int* out_index)
    int TVMNodeGetAttr(NodeHandle handle,
                       const char* key,
                       TVMValue* out_value,
                       int* out_type_code,
                       int* out_success)

cdef inline py_str(const char* x):
    """Turn a NUL-terminated C string into the native Python ``str`` type.

    On Python 3 the bytes are decoded as UTF-8. On Python 2 the raw byte
    string is returned unchanged, because ``str`` is already bytes there.
    """
    if PY_MAJOR_VERSION >= 3:
        return x.decode("utf-8")
    return x


cdef inline c_str(pystr):
    """Encode a Python string as UTF-8 bytes suitable for the C API.

    Parameters
    ----------
    pystr : str or bytes
        Python string to encode. ``bytes`` input is returned unchanged.

    Returns
    -------
    bytes
        UTF-8 encoded bytes that can be passed where the C API expects a
        ``const char*``.
    """
    # Already-encoded input: bytes has no .encode on Python 3, so pass it through.
    if isinstance(pystr, bytes):
        return pystr
    return pystr.encode("utf-8")


cdef inline CALL(int ret):
    """Raise TVMError when a TVM C API call reports failure.

    Parameters
    ----------
    ret : int
        Status code returned by a TVM C API function; non-zero means failure.

    Raises
    ------
    TVMError
        Carrying the message from TVMGetLastError().
    """
    if ret != 0:
        # TVMGetLastError returns const char*, which Cython turns into bytes;
        # decode it so the exception message is text on Python 3.
        raise TVMError(py_str(TVMGetLastError()))


cdef inline object ctypes_handle(void* chandle):
    """Wrap a raw C pointer in a ``ctypes.c_void_p``.

    The pointer is handed to ctypes as an integer address.
    """
    cdef unsigned long long address = <unsigned long long>chandle
    return ctypes.cast(address, ctypes.c_void_p)


cdef inline void* c_handle(object handle):
    """Convert a ctypes handle into a raw C pointer.

    Parameters
    ----------
    handle : ctypes.c_void_p
        Handle whose ``.value`` holds the integer address.

    Returns
    -------
    void*
        The address as a C pointer; a NULL handle (``.value is None``)
        yields ``NULL``.
    """
    cdef unsigned long long v_ptr
    value = handle.value
    if value is None:
        # ctypes reports a NULL c_void_p as value None; converting None to an
        # integer would raise TypeError.
        return NULL
    v_ptr = value
    return <void*>(v_ptr)