"""Example extension package of TVM."""
from __future__ import absolute_import
import os
import ctypes
# Import TVM first so that its library symbols are available.
import tvm
def load_lib():
    """Load the extension library; its functions will be registered into TVM.

    The library is looked up as ``../../lib/libtvm_ext.so`` relative to the
    directory that contains this file.

    Returns
    -------
    lib : ctypes.CDLL
        Handle of the loaded shared library.

    Raises
    ------
    OSError
        If the shared library cannot be loaded.
    """
    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    # load in as global so the global extern symbol is visible to other dll.
    lib = ctypes.CDLL(
        os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL)
    return lib

# Load the shared library at import time and keep a reference to it.
_LIB = load_lib()

# Look up the functions registered by the library and expose them in python
bind_add = tvm.get_global_func("tvm_ext.bind_add")
sym_add = tvm.get_global_func("tvm_ext.sym_add")
ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get")

class IntVec(object):
    """Example of using an extension class implemented in c++.

    An instance stores a handle and releases it through
    ``tvm.nd.free_extension_handle`` when the instance is destroyed.
    """
    # Type code handed to TVM together with the handle when it is freed.
    # NOTE(review): presumably has to match the code used on the c++ side -- confirm.
    _tvm_tcode = 17

    def __init__(self, handle):
        """Keep ``handle``; it must expose a ``.value`` attribute (see ``_tvm_handle``)."""
        self.handle = handle

    def __del__(self):
        # You can also call your own customized
        # deleter if you can free it via your own FFI.
        tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)

    @property
    def _tvm_handle(self):
        """Raw value of the stored handle (``self.handle.value``)."""
        return self.handle.value

    def __getitem__(self, idx):
        """Return the result of the ``tvm_ext.ivec_get`` function for ``idx``."""
        return ivec_get(self, idx)

# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)


# Functions used by the NDSubClass example, looked up by their registered names.
# The "addtional" spelling is part of the names used at runtime; keep it as is.
nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")

class NDSubClass(tvm.nd.NDArrayBase):
    """Example of deriving from TVM's NDArray infrastructure.

    Because the class inherits from TVM's NDArray, external libraries can
    make use of TVM's FFI without modifying it.
    """
    # Keep in sync with the type-trait configured in the backend
    _array_type_code = 1

    @staticmethod
    def create(addtional_info):
        """Return what ``tvm_ext.nd_create`` produces for ``addtional_info``."""
        created = nd_create(addtional_info)
        return created

    @property
    def addtional_info(self):
        """Value reported by ``tvm_ext.nd_get_addtional_info`` for this array."""
        info = nd_get_addtional_info(self)
        return info

    def __add__(self, other):
        """Return the result of ``tvm_ext.nd_add_two`` applied to ``self`` and ``other``."""
        total = nd_add_two(self, other)
        return total

# Register NDSubClass extension on python side; the class itself is passed
# as both arguments, the same way IntVec is registered.
tvm.register_extension(NDSubClass, NDSubClass)