# base.pxi
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t

import ctypes

from ..base import get_last_ffi_error

# Type codes that accompany TVMValue entries passed through the C API
# (e.g. the ``type_codes`` array given to TVMFuncCall).
# NOTE(review): this is a Cython-side copy of the runtime's type-code
# enum, presumably mirroring tvm/runtime/c_runtime_api.h -- keep the
# numeric values in sync with the C header. Code 14 is not listed here.
cdef enum TVMTypeCode:
    kInt = 0
    kUInt = 1
    kFloat = 2
    kHandle = 3
    kNull = 4
    kTVMType = 5
    kTVMContext = 6
    kArrayHandle = 7
    kObjectHandle = 8
    kModuleHandle = 9
    kFuncHandle = 10
    kStr = 11
    kBytes = 12
    kNDArrayContainer = 13
    kExtBegin = 15


# Struct layouts imported from the TVM C runtime header. Only the fields
# declared here are visible to Cython code; the real definitions come from
# the header at C compile time.
cdef extern from "tvm/runtime/c_runtime_api.h":
    # Element type: type code, bit width and number of vector lanes.
    ctypedef struct DLDataType:
        uint8_t code
        uint8_t bits
        uint16_t lanes

    # Device identification: device type plus device index.
    ctypedef struct DLContext:
        int device_type
        int device_id

    # N-dimensional array descriptor: data pointer, device, rank, element
    # type, shape/strides arrays and a byte offset into ``data``.
    ctypedef struct DLTensor:
        void* data
        DLContext ctx
        int ndim
        DLDataType dtype
        int64_t* shape
        int64_t* strides
        uint64_t byte_offset

    # A DLTensor bundled with a manager context and the deleter callback
    # that receives this struct.
    ctypedef struct DLManagedTensor:
        DLTensor dl_tensor
        void* manager_ctx
        void (*deleter)(DLManagedTensor* self)

    # Value slot used for FFI arguments and return values; which member is
    # meaningful is indicated by the accompanying TVMTypeCode.
    # NOTE(review): declared as a struct for Cython; the C definition is
    # presumably a union -- confirm before relying on field layout.
    ctypedef struct TVMValue:
        int64_t v_int64
        double v_float64
        void* v_handle
        const char* v_str
        DLDataType v_type
        DLContext v_ctx

# Index and opaque-handle aliases used by the C API declarations.
ctypedef int64_t tvm_index_t
ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle


# Cython view of an NDArray container. Its leading fields repeat the
# DLManagedTensor layout (dl_tensor, manager_ctx, deleter) and it adds
# ``array_type_info``.
# NOTE(review): field order must match the runtime's container layout --
# confirm against the C++ definition before changing anything here.
ctypedef struct TVMNDArrayContainer:
    DLTensor dl_tensor
    void* manager_ctx
    void (*deleter)(DLManagedTensor* self)
    int32_t array_type_info

ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle

# C callback signature accepted by TVMFuncCreateFromCFunc: receives the
# argument values with their type codes, a handle for setting the return
# value, and the resource handle registered with the function.
ctypedef int (*TVMPackedCFunc)(
    TVMValue* args,
    int* type_codes,
    int num_args,
    TVMRetValueHandle ret,
    void* resource_handle)

# Finalizer paired with a TVMPackedCFunc; presumably invoked with the
# resource handle when the created function is released -- verify.
ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)

# C API entry points. The int-returning functions are checked with CALL,
# which treats any non-zero return as a failure.
cdef extern from "tvm/runtime/c_runtime_api.h":
    # Last-error reporting.
    void TVMAPISetLastError(const char* msg)
    const char *TVMGetLastError()

    # Packed-function invocation, return values and C callbacks.
    int TVMFuncCall(TVMFunctionHandle func,
                    TVMValue* arg_values,
                    int* type_codes,
                    int num_args,
                    TVMValue* ret_val,
                    int* ret_type_code)
    int TVMFuncFree(TVMFunctionHandle func)
    int TVMCFuncSetReturn(TVMRetValueHandle ret,
                          TVMValue* value,
                          int* type_code,
                          int num_ret)
    int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                               void* resource_handle,
                               TVMPackedCFuncFinalizer fin,
                               TVMFunctionHandle *out)
    int TVMCbArgToReturn(TVMValue* value, int code)

    # Array allocation, copying and DLPack interop.
    int TVMArrayAlloc(tvm_index_t* shape,
                      tvm_index_t ndim,
                      DLDataType dtype,
                      DLContext ctx,
                      DLTensorHandle* out)
    int TVMArrayFree(DLTensorHandle handle)
    int TVMArrayCopyFromTo(DLTensorHandle src,
                           DLTensorHandle to,
                           TVMStreamHandle stream)
    int TVMArrayFromDLPack(DLManagedTensor* arr_from,
                           DLTensorHandle* out)
    int TVMArrayToDLPack(DLTensorHandle arr_from,
                         DLManagedTensor** out)
    void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)

    # Object handles.
    int TVMObjectFree(ObjectHandle obj)
    int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)


cdef inline py_str(const char* x):
    """Turn a NUL-terminated C string into a native Python string.

    On Python 3 the bytes are decoded as UTF-8 and a ``str`` is returned.
    On Python 2 the raw bytes object is returned unchanged.
    """
    if PY_MAJOR_VERSION >= 3:
        return x.decode("utf-8")
    return x


cdef inline c_str(pystr):
    """Encode a Python string as UTF-8 bytes for use with the C API.

    Parameters
    ----------
    pystr : str
        Python string to encode.

    Returns
    -------
    bytes
        UTF-8 encoded bytes. Cython can coerce this object to
        ``const char*`` when it is passed to a C function.

    Notes
    -----
    A ``char*`` obtained from the returned object is only valid while a
    reference to that bytes object is held, so keep it in a variable for
    as long as the C side uses the pointer.
    """
    return pystr.encode("utf-8")


cdef inline CALL(int ret):
    """Check the status code returned by a TVM C API call.

    Parameters
    ----------
    ret : int
        Return value of a C API function; any non-zero value is a failure.

    Raises
    ------
    Exception
        The exception object produced by ``get_last_ffi_error()`` when
        ``ret`` is non-zero.
    """
    if ret != 0:
        raise get_last_ffi_error()


cdef inline object ctypes_handle(void* chandle):
    """Wrap a raw C pointer in a ``ctypes.c_void_p``.

    The pointer is cast to an integer address, which ``ctypes.cast`` then
    reinterprets as a ``c_void_p``. A NULL pointer produces a handle whose
    ``.value`` is ``None``.
    """
    # NOTE(review): assumes pointers fit in unsigned long long (64 bits).
    return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)


cdef inline void* c_handle(object handle):
    """Extract the raw C pointer held by a ctypes handle.

    Parameters
    ----------
    handle : ctypes.c_void_p
        ctypes pointer object whose ``.value`` is the integer address.

    Returns
    -------
    void*
        The address as a C pointer. A NULL handle (``.value is None``)
        yields NULL.
    """
    cdef unsigned long long v_ptr
    cdef object addr = handle.value
    # ctypes reports a NULL c_void_p as ``value is None``. Coercing None to
    # an unsigned long long would raise TypeError inside this non-object
    # cdef function, so map it explicitly to address 0 (NULL).
    if addr is None:
        v_ptr = 0
    else:
        v_ptr = addr
    return <void*>(v_ptr)