Unverified Commit f9b46c43 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] tvm._ffi (#4813)

* [REFACTOR][PY] tvm._ffi

- Remove from __future__ import absolute_import in the related files as they are no longer needed if the code only runs in python3
- Remove reverse dependency of _ctypes _cython to object_generic.
- function.py -> packed_func.py
- Function -> PackedFunc
- all registry related logics goes to tvm._ffi.registry
- Use absolute references for FFI related calls.
  - tvm._ffi.register_object
  - tvm._ffi.register_func
  - tvm._ffi.get_global_func

* Move get global func to the ffi side
parent 4a39e521
...@@ -16,13 +16,17 @@ ...@@ -16,13 +16,17 @@
# under the License. # under the License.
# pylint: disable=redefined-builtin, wildcard-import # pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation.""" """TVM: Low level DSL/IR stack for tensor computation."""
from __future__ import absolute_import as _abs
import multiprocessing import multiprocessing
import sys import sys
import traceback import traceback
from . import _pyversion # import ffi related features
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object
from . import tensor from . import tensor
from . import arith from . import arith
...@@ -34,7 +38,6 @@ from . import codegen ...@@ -34,7 +38,6 @@ from . import codegen
from . import container from . import container
from . import schedule from . import schedule
from . import module from . import module
from . import object
from . import attrs from . import attrs
from . import ir_builder from . import ir_builder
from . import target from . import target
...@@ -48,15 +51,9 @@ from . import ndarray as nd ...@@ -48,15 +51,9 @@ from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import * from .api import *
from .intrin import * from .intrin import *
from .tensor_intrin import decl_tensor_intrin from .tensor_intrin import decl_tensor_intrin
from .object import register_object
from .ndarray import register_extension
from .schedule import create_schedule from .schedule import create_schedule
from .build_module import build, lower, build_config from .build_module import build, lower, build_config
from .tag import tag_scope from .tag import tag_scope
......
...@@ -24,3 +24,7 @@ be used via ctypes function calls. ...@@ -24,3 +24,7 @@ be used via ctypes function calls.
Some performance critical functions are implemented by cython Some performance critical functions are implemented by cython
and have a ctypes fallback implementation. and have a ctypes fallback implementation.
""" """
from . import _pyversion
from .base import register_error
from .registry import register_object, register_func, register_extension
from .registry import _init_api, get_global_func
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
# under the License. # under the License.
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Runtime NDArray api""" """Runtime NDArray api"""
from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call, c_str from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle from ..runtime_ctypes import TVMArrayHandle
......
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
# under the License. # under the License.
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Runtime Object api""" """Runtime Object api"""
from __future__ import absolute_import
import ctypes import ctypes
from ..base import _LIB, check_call from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
......
...@@ -17,15 +17,12 @@ ...@@ -17,15 +17,12 @@
# coding: utf-8 # coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import # pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import
"""Function configuration API.""" """Function configuration API."""
from __future__ import absolute_import
import ctypes import ctypes
import traceback import traceback
from numbers import Number, Integral from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
...@@ -35,7 +32,7 @@ from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_in ...@@ -35,7 +32,7 @@ from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_in
from .object import ObjectBase, _set_class_object from .object import ObjectBase, _set_class_object
from . import object as _object from . import object as _object
FunctionHandle = ctypes.c_void_p PackedFuncHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p
ObjectHandle = ctypes.c_void_p ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p TVMRetValueHandle = ctypes.c_void_p
...@@ -49,6 +46,15 @@ def _ctypes_free_resource(rhandle): ...@@ -49,6 +46,15 @@ def _ctypes_free_resource(rhandle):
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource) TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ)) ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def _make_packed_func(handle, is_global):
"""Make a packed function class"""
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
obj.is_global = is_global
obj.handle = handle
return obj
def convert_to_tvm_func(pyfunc): def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function """Convert a python function to TVM function
...@@ -89,7 +95,7 @@ def convert_to_tvm_func(pyfunc): ...@@ -89,7 +95,7 @@ def convert_to_tvm_func(pyfunc):
_ = rv _ = rv
return 0 return 0
handle = FunctionHandle() handle = PackedFuncHandle()
f = TVMPackedCFunc(cfun) f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f # NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed. # TVM_FREE_PYOBJ will be called after it is no longer needed.
...@@ -98,7 +104,7 @@ def convert_to_tvm_func(pyfunc): ...@@ -98,7 +104,7 @@ def convert_to_tvm_func(pyfunc):
if _LIB.TVMFuncCreateFromCFunc( if _LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0: f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
raise get_last_ffi_error() raise get_last_ffi_error()
return _CLASS_FUNCTION(handle, False) return _make_packed_func(handle, False)
def _make_tvm_args(args, temp_args): def _make_tvm_args(args, temp_args):
...@@ -144,15 +150,15 @@ def _make_tvm_args(args, temp_args): ...@@ -144,15 +150,15 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types): elif isinstance(arg, string_types):
values[i].v_str = c_str(arg) values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
arg = convert_to_object(arg) arg = _FUNC_CONVERT_TO_OBJECT(arg)
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE): elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase): elif isinstance(arg, PackedFuncBase):
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.PACKED_FUNC_HANDLE type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
elif isinstance(arg, ctypes.c_void_p): elif isinstance(arg, ctypes.c_void_p):
...@@ -168,7 +174,7 @@ def _make_tvm_args(args, temp_args): ...@@ -168,7 +174,7 @@ def _make_tvm_args(args, temp_args):
return values, type_codes, num_args return values, type_codes, num_args
class FunctionBase(object): class PackedFuncBase(object):
"""Function base.""" """Function base."""
__slots__ = ["handle", "is_global"] __slots__ = ["handle", "is_global"]
# pylint: disable=no-member # pylint: disable=no-member
...@@ -177,7 +183,7 @@ class FunctionBase(object): ...@@ -177,7 +183,7 @@ class FunctionBase(object):
Parameters Parameters
---------- ----------
handle : FunctionHandle handle : PackedFuncHandle
the handle to the underlying function. the handle to the underlying function.
is_global : bool is_global : bool
...@@ -238,9 +244,22 @@ def _return_module(x): ...@@ -238,9 +244,22 @@ def _return_module(x):
def _handle_return_func(x): def _handle_return_func(x):
"""Return function""" """Return function"""
handle = x.v_handle handle = x.v_handle
if not isinstance(handle, FunctionHandle): if not isinstance(handle, PackedFuncHandle):
handle = FunctionHandle(handle) handle = PackedFuncHandle(handle)
return _CLASS_FUNCTION(handle, False) return _CLASS_PACKED_FUNC(handle, False)
def _get_global_func(name, allow_missing=False):
handle = PackedFuncHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value:
return _make_packed_func(handle, False)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
# setup return handle for function type # setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__ _object.__init_by_constructor__ = __init_handle_by_constructor__
...@@ -255,13 +274,22 @@ C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, ...@@ -255,13 +274,22 @@ C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle,
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_PACKED_FUNC = None
_CLASS_OBJECT_GENERIC = None
_FUNC_CONVERT_TO_OBJECT = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
global _CLASS_MODULE global _CLASS_MODULE
_CLASS_MODULE = module_class _CLASS_MODULE = module_class
def _set_class_function(func_class): def _set_class_packed_func(packed_func_class):
global _CLASS_FUNCTION global _CLASS_PACKED_FUNC
_CLASS_FUNCTION = func_class _CLASS_PACKED_FUNC = packed_func_class
def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _CLASS_OBJECT_GENERIC
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
# under the License. # under the License.
"""The C Types used in API.""" """The C Types used in API."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import ctypes import ctypes
import struct import struct
from ..base import py_str, check_call, _LIB from ..base import py_str, check_call, _LIB
......
...@@ -75,7 +75,7 @@ ctypedef int64_t tvm_index_t ...@@ -75,7 +75,7 @@ ctypedef int64_t tvm_index_t
ctypedef DLTensor* DLTensorHandle ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle ctypedef void* TVMPackedFuncHandle
ctypedef void* ObjectHandle ctypedef void* ObjectHandle
ctypedef struct TVMObject: ctypedef struct TVMObject:
...@@ -96,13 +96,15 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle) ...@@ -96,13 +96,15 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg) void TVMAPISetLastError(const char* msg)
const char *TVMGetLastError() const char *TVMGetLastError()
int TVMFuncCall(TVMFunctionHandle func, int TVMFuncGetGlobal(const char* name,
TVMPackedFuncHandle* out);
int TVMFuncCall(TVMPackedFuncHandle func,
TVMValue* arg_values, TVMValue* arg_values,
int* type_codes, int* type_codes,
int num_args, int num_args,
TVMValue* ret_val, TVMValue* ret_val,
int* ret_type_code) int* ret_type_code)
int TVMFuncFree(TVMFunctionHandle func) int TVMFuncFree(TVMPackedFuncHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret, int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value, TVMValue* value,
int* type_code, int* type_code,
...@@ -110,7 +112,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -110,7 +112,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMFuncCreateFromCFunc(TVMPackedCFunc func, int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle, void* resource_handle,
TVMPackedCFuncFinalizer fin, TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) TVMPackedFuncHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code) int TVMCbArgToReturn(TVMValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape, int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim, tvm_index_t ndim,
......
...@@ -17,7 +17,5 @@ ...@@ -17,7 +17,5 @@
include "./base.pxi" include "./base.pxi"
include "./object.pxi" include "./object.pxi"
# include "./node.pxi" include "./packed_func.pxi"
include "./function.pxi"
include "./ndarray.pxi" include "./ndarray.pxi"
...@@ -96,6 +96,6 @@ cdef class ObjectBase: ...@@ -96,6 +96,6 @@ cdef class ObjectBase:
self.chandle = NULL self.chandle = NULL
cdef void* chandle cdef void* chandle
ConstructorCall( ConstructorCall(
(<FunctionBase>fconstructor).chandle, (<PackedFuncBase>fconstructor).chandle,
kTVMObjectHandle, args, &chandle) kTVMObjectHandle, args, &chandle)
self.chandle = chandle self.chandle = chandle
...@@ -20,7 +20,6 @@ import traceback ...@@ -20,7 +20,6 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types, py2cerror from ..base import string_types, py2cerror
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
...@@ -67,6 +66,13 @@ cdef int tvm_callback(TVMValue* args, ...@@ -67,6 +66,13 @@ cdef int tvm_callback(TVMValue* args,
return 0 return 0
cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global):
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
(<PackedFuncBase>obj).chandle = chandle
(<PackedFuncBase>obj).is_global = is_global
return obj
def convert_to_tvm_func(object pyfunc): def convert_to_tvm_func(object pyfunc):
"""Convert a python function to TVM function """Convert a python function to TVM function
...@@ -80,15 +86,13 @@ def convert_to_tvm_func(object pyfunc): ...@@ -80,15 +86,13 @@ def convert_to_tvm_func(object pyfunc):
tvmfunc: tvm.Function tvmfunc: tvm.Function
The converted tvm function. The converted tvm function.
""" """
cdef TVMFunctionHandle chandle cdef TVMPackedFuncHandle chandle
Py_INCREF(pyfunc) Py_INCREF(pyfunc)
CALL(TVMFuncCreateFromCFunc(tvm_callback, CALL(TVMFuncCreateFromCFunc(tvm_callback,
<void*>(pyfunc), <void*>(pyfunc),
tvm_callback_finalize, tvm_callback_finalize,
&chandle)) &chandle))
ret = _CLASS_FUNCTION(None, False) return make_packed_func(chandle, False)
(<FunctionBase>ret).chandle = chandle
return ret
cdef inline int make_arg(object arg, cdef inline int make_arg(object arg,
...@@ -149,29 +153,30 @@ cdef inline int make_arg(object arg, ...@@ -149,29 +153,30 @@ cdef inline int make_arg(object arg,
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kTVMStr tcode[0] = kTVMStr
temp_args.append(tstr) temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
arg = convert_to_object(arg) arg = _FUNC_CONVERT_TO_OBJECT(arg)
value[0].v_handle = (<ObjectBase>arg).chandle value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kTVMObjectHandle tcode[0] = kTVMObjectHandle
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE): elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle) value[0].v_handle = c_handle(arg.handle)
tcode[0] = kTVMModuleHandle tcode[0] = kTVMModuleHandle
elif isinstance(arg, FunctionBase): elif isinstance(arg, PackedFuncBase):
value[0].v_handle = (<FunctionBase>arg).chandle value[0].v_handle = (<PackedFuncBase>arg).chandle
tcode[0] = kTVMPackedFuncHandle tcode[0] = kTVMPackedFuncHandle
elif isinstance(arg, ctypes.c_void_p): elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg) value[0].v_handle = c_handle(arg)
tcode[0] = kTVMOpaqueHandle tcode[0] = kTVMOpaqueHandle
elif callable(arg): elif callable(arg):
arg = convert_to_tvm_func(arg) arg = convert_to_tvm_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle value[0].v_handle = (<PackedFuncBase>arg).chandle
tcode[0] = kTVMPackedFuncHandle tcode[0] = kTVMPackedFuncHandle
temp_args.append(arg) temp_args.append(arg)
else: else:
raise TypeError("Don't know how to handle type %s" % type(arg)) raise TypeError("Don't know how to handle type %s" % type(arg))
return 0 return 0
cdef inline bytearray make_ret_bytes(void* chandle): cdef inline bytearray make_ret_bytes(void* chandle):
handle = ctypes_handle(chandle) handle = ctypes_handle(chandle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0] arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
...@@ -182,6 +187,7 @@ cdef inline bytearray make_ret_bytes(void* chandle): ...@@ -182,6 +187,7 @@ cdef inline bytearray make_ret_bytes(void* chandle):
raise RuntimeError('memmove failed') raise RuntimeError('memmove failed')
return res return res
cdef inline object make_ret(TVMValue value, int tcode): cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value.""" """convert result to return value."""
if tcode == kTVMObjectHandle: if tcode == kTVMObjectHandle:
...@@ -205,9 +211,7 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -205,9 +211,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
elif tcode == kTVMModuleHandle: elif tcode == kTVMModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle)) return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kTVMPackedFuncHandle: elif tcode == kTVMPackedFuncHandle:
fobj = _CLASS_FUNCTION(None, False) return make_packed_func(value.v_handle, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode in _TVM_EXT_RET: elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
...@@ -264,8 +268,8 @@ cdef inline int ConstructorCall(void* constructor_handle, ...@@ -264,8 +268,8 @@ cdef inline int ConstructorCall(void* constructor_handle,
return 0 return 0
cdef class FunctionBase: cdef class PackedFuncBase:
cdef TVMFunctionHandle chandle cdef TVMPackedFuncHandle chandle
cdef int is_global cdef int is_global
cdef inline _set_handle(self, handle): cdef inline _set_handle(self, handle):
...@@ -305,19 +309,39 @@ cdef class FunctionBase: ...@@ -305,19 +309,39 @@ cdef class FunctionBase:
return make_ret(ret_val, ret_tcode) return make_ret(ret_val, ret_tcode)
_CLASS_FUNCTION = None def _get_global_func(name, allow_missing):
cdef TVMPackedFuncHandle chandle
CALL(TVMFuncGetGlobal(c_str(name), &chandle))
if chandle != NULL:
return make_packed_func(chandle, True)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
_CLASS_PACKED_FUNC = None
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_OBJECT = None _CLASS_OBJECT = None
_CLASS_OBJECT_GENERIC = None
_FUNC_CONVERT_TO_OBJECT = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
global _CLASS_MODULE global _CLASS_MODULE
_CLASS_MODULE = module_class _CLASS_MODULE = module_class
def _set_class_function(func_class): def _set_class_packed_func(func_class):
global _CLASS_FUNCTION global _CLASS_PACKED_FUNC
_CLASS_FUNCTION = func_class _CLASS_PACKED_FUNC = func_class
def _set_class_object(obj_class): def _set_class_object(obj_class):
global _CLASS_OBJECT global _CLASS_OBJECT
_CLASS_OBJECT = obj_class _CLASS_OBJECT = obj_class
def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _CLASS_OBJECT_GENERIC
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
""" """
import sys import sys
#----------------------------
# Python3 version.
#----------------------------
if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 5): if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 5):
PY3STATEMENT = """TVM project proudly dropped support of Python2. PY3STATEMENT = """TVM project proudly dropped support of Python2.
The minimal Python requirement is Python 3.5 The minimal Python requirement is Python 3.5
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
# coding: utf-8 # coding: utf-8
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Base library for TVM FFI.""" """Base library for TVM FFI."""
from __future__ import absolute_import
import sys import sys
import os import os
import ctypes import ctypes
...@@ -28,27 +26,22 @@ from . import libinfo ...@@ -28,27 +26,22 @@ from . import libinfo
#---------------------------- #----------------------------
# library loading # library loading
#---------------------------- #----------------------------
if sys.version_info[0] == 3: string_types = (str,)
string_types = (str,) integer_types = (int, np.int32)
integer_types = (int, np.int32) numeric_types = integer_types + (float, np.float32)
numeric_types = integer_types + (float, np.float32)
# this function is needed for python3 # this function is needed for python3
# to convert ctypes.char_p .value back to python str # to convert ctypes.char_p .value back to python str
if sys.platform == "win32": if sys.platform == "win32":
def _py_str(x): def _py_str(x):
try: try:
return x.decode('utf-8') return x.decode('utf-8')
except UnicodeDecodeError: except UnicodeDecodeError:
encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP()) encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
return x.decode(encoding) return x.decode(encoding)
py_str = _py_str py_str = _py_str
else:
py_str = lambda x: x.decode('utf-8')
else: else:
string_types = (basestring,) py_str = lambda x: x.decode('utf-8')
integer_types = (int, long, np.int32)
numeric_types = integer_types + (float, np.float32)
py_str = lambda x: x
def _load_lib(): def _load_lib():
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Library information.""" """Library information."""
from __future__ import absolute_import
import sys import sys
import os import os
...@@ -39,6 +38,7 @@ def split_env_var(env_var, split): ...@@ -39,6 +38,7 @@ def split_env_var(env_var, split):
return [p.strip() for p in os.environ[env_var].split(split)] return [p.strip() for p in os.environ[env_var].split(split)]
return [] return []
def find_lib_path(name=None, search_path=None, optional=False): def find_lib_path(name=None, search_path=None, optional=False):
"""Find dynamic library files. """Find dynamic library files.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime Module namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = PackedFuncHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return PackedFunc(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
if self._entry:
return self._entry(*args)
f = self.entry_func
return f(*args)
...@@ -16,35 +16,22 @@ ...@@ -16,35 +16,22 @@
# under the License. # under the License.
# pylint: disable=invalid-name, unused-import # pylint: disable=invalid-name, unused-import
"""Runtime NDArray api""" """Runtime NDArray api"""
from __future__ import absolute_import
import sys
import ctypes import ctypes
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t from .runtime_ctypes import TypeCode, tvm_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try: try:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import NDArrayBase as _NDArrayBase except (RuntimeError, ImportError):
from ._cy3.core import _reg_extension
else:
from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
...@@ -297,59 +284,3 @@ class NDArrayBase(_NDArrayBase): ...@@ -297,59 +284,3 @@ class NDArrayBase(_NDArrayBase):
res = empty(self.shape, self.dtype, target) res = empty(self.shape, self.dtype, target)
return self._copyto(res) return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
After the class is registered, the class will be able
to directly pass as Function argument generated by TVM.
Parameters
----------
cls : class
The class object to be registered as extension.
fcreate : function, optional
The creation function to create a class object given handle value.
Note
----
The registered class is requires one property: _tvm_handle.
If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_code.
Otherwise, it is required to have a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
code of the class.
Returns
-------
cls : class
The class being registered.
Example
-------
The following code registers user defined class
MyTensor to be DLTensor compatible.
.. code-block:: python
@tvm.register_extension
class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self):
self.handle = _LIB.NewDLTensor()
@property
def _tvm_handle(self):
return self.handle.value
"""
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
...@@ -16,33 +16,20 @@ ...@@ -16,33 +16,20 @@
# under the License. # under the License.
# pylint: disable=invalid-name, unused-import # pylint: disable=invalid-name, unused-import
"""Runtime Object API""" """Runtime Object API"""
from __future__ import absolute_import
import sys
import ctypes import ctypes
from .. import _api_internal from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .object_generic import ObjectGeneric, convert_to_object, const
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try: try:
# pylint: disable=wrong-import-position,unused-import # pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): from ._cy3.core import _set_class_object, _set_class_object_generic
from ._cy3.core import _set_class_object from ._cy3.core import ObjectBase
from ._cy3.core import ObjectBase as _ObjectBase except (RuntimeError, ImportError):
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position,unused-import # pylint: disable=wrong-import-position,unused-import
from ._ctypes.function import _set_class_object from ._ctypes.packed_func import _set_class_object, _set_class_object_generic
from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import ObjectBase
from ._ctypes.object import _register_object
def _new_object(cls): def _new_object(cls):
...@@ -50,7 +37,7 @@ def _new_object(cls): ...@@ -50,7 +37,7 @@ def _new_object(cls):
return cls.__new__(cls) return cls.__new__(cls)
class Object(_ObjectBase): class Object(ObjectBase):
"""Base class for all tvm's runtime objects.""" """Base class for all tvm's runtime objects."""
def __repr__(self): def __repr__(self):
return _api_internal._format_str(self) return _api_internal._format_str(self)
...@@ -104,52 +91,6 @@ class Object(_ObjectBase): ...@@ -104,52 +91,6 @@ class Object(_ObjectBase):
return self.__hash__() == other.__hash__() return self.__hash__() == other.__hash__()
def register_object(type_key=None):
"""register object type.
Parameters
----------
type_key : str or cls
The type key of the node
Examples
--------
The following code registers MyObject
using type key "test.MyObject"
.. code-block:: python
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
if hasattr(cls, "_type_index"):
tindex = cls._type_index
else:
tidx = ctypes.c_uint()
if not _RUNTIME_ONLY:
check_call(_LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx)))
else:
# directly skip unknown objects during runtime.
ret = _LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx))
if ret != 0:
return cls
tindex = tidx.value
_register_object(tindex, cls)
return cls
if isinstance(type_key, str):
return register
return register(type_key)
def getitem_helper(obj, elem_getter, length, idx): def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function. """Helper function to implement a pythonic getitem function.
......
...@@ -16,35 +16,14 @@ ...@@ -16,35 +16,14 @@
# under the License. # under the License.
"""Common implementation of object generic related logic""" """Common implementation of object generic related logic"""
# pylint: disable=unused-import # pylint: disable=unused-import
from __future__ import absolute_import
from numbers import Number, Integral from numbers import Number, Integral
from .. import _api_internal from .. import _api_internal
from .base import string_types
# Object base class
_CLASS_OBJECTS = None
def _set_class_objects(cls):
global _CLASS_OBJECTS
_CLASS_OBJECTS = cls
from .base import string_types
def _scalar_type_inference(value): from .object import ObjectBase, _set_class_object_generic
if hasattr(value, 'dtype'): from .ndarray import NDArrayBase
dtype = str(value.dtype) from .packed_func import PackedFuncBase, convert_to_tvm_func
elif isinstance(value, bool): from .module import ModuleBase
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically inference the type.'
' value={}'.format(value))
return dtype
class ObjectGeneric(object): class ObjectGeneric(object):
...@@ -54,6 +33,9 @@ class ObjectGeneric(object): ...@@ -54,6 +33,9 @@ class ObjectGeneric(object):
raise NotImplementedError() raise NotImplementedError()
_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase)
def convert_to_object(value): def convert_to_object(value):
"""Convert a python value to corresponding object type. """Convert a python value to corresponding object type.
...@@ -95,22 +77,65 @@ def convert_to_object(value): ...@@ -95,22 +77,65 @@ def convert_to_object(value):
raise ValueError("don't know how to convert type %s to object" % type(value)) raise ValueError("don't know how to convert type %s to object" % type(value))
def convert(value):
"""Convert value to TVM object or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (PackedFuncBase, ObjectBase)):
return value
if callable(value):
return convert_to_tvm_func(value)
return convert_to_object(value)
def _scalar_type_inference(value):
if hasattr(value, 'dtype'):
dtype = str(value.dtype)
elif isinstance(value, bool):
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically inference the type.'
' value={}'.format(value))
return dtype
def const(value, dtype=None): def const(value, dtype=None):
"""Construct a constant value for a given type. """construct a constant
Parameters Parameters
---------- ----------
value : int or float value : number
The input value The content of the constant number.
dtype : str or None, optional dtype : str or None, optional
The data type. The data type.
Returns Returns
------- -------
expr : Expr const_val: tvm.Expr
Constant expression corresponds to the value. The result expression.
""" """
if dtype is None: if dtype is None:
dtype = _scalar_type_inference(value) dtype = _scalar_type_inference(value)
if dtype == "uint64" and value >= (1 << 63):
return _api_internal._LargeUIntImm(
dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype) return _api_internal._const(value, dtype)
_set_class_object_generic(ObjectGeneric, convert_to_object)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Packed Function namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types, _FFI_MODE
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _set_class_packed_func, _set_class_module
from ._cy3.core import PackedFuncBase
from ._cy3.core import convert_to_tvm_func
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position
from ._ctypes.packed_func import _set_class_packed_func, _set_class_module
from ._ctypes.packed_func import PackedFuncBase
from ._ctypes.packed_func import convert_to_tvm_func
PackedFuncHandle = ctypes.c_void_p
class PackedFunc(PackedFuncBase):
"""The PackedFunc object used in TVM.
Function plays an key role to bridge front and backend in TVM.
Function provide a type-erased interface, you can call function with positional arguments.
The compiled module returns Function.
TVM backend also registers and exposes its API as Functions.
For example, the developer function exposed in tvm.ir_pass are actually
C++ functions that are registered as PackedFunc
The following are list of common usage scenario of tvm.Function.
- Automatic exposure of C++ API into python
- To call PackedFunc from python side
- To call python callbacks to inspect results in generated code
- Bring python hook into C++ backend
See Also
--------
tvm.register_func: How to register global function.
tvm.get_global_func: How to get global function.
"""
_set_class_packed_func(PackedFunc)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""FFI registry to register function and objects."""
import sys
import ctypes
from .. import _api_internal
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
try:
# pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _register_object
from ._cy3.core import _reg_extension
from ._cy3.core import convert_to_tvm_func, _get_global_func, PackedFuncBase
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from ._ctypes.object import _register_object
from ._ctypes.ndarray import _reg_extension
from ._ctypes.packed_func import convert_to_tvm_func, _get_global_func, PackedFuncBase
def register_object(type_key=None):
"""register object type.
Parameters
----------
type_key : str or cls
The type key of the node
Examples
--------
The following code registers MyObject
using type key "test.MyObject"
.. code-block:: python
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
if hasattr(cls, "_type_index"):
tindex = cls._type_index
else:
tidx = ctypes.c_uint()
if not _RUNTIME_ONLY:
check_call(_LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx)))
else:
# directly skip unknown objects during runtime.
ret = _LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx))
if ret != 0:
return cls
tindex = tidx.value
_register_object(tindex, cls)
return cls
if isinstance(type_key, str):
return register
return register(type_key)
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
After the class is registered, the class will be able
to directly pass as Function argument generated by TVM.
Parameters
----------
cls : class
The class object to be registered as extension.
fcreate : function, optional
The creation function to create a class object given handle value.
Note
----
The registered class is requires one property: _tvm_handle.
If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_code.
Otherwise, it is required to have a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
code of the class.
Returns
-------
cls : class
The class being registered.
Example
-------
The following code registers user defined class
MyTensor to be DLTensor compatible.
.. code-block:: python
@tvm.register_extension
class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self):
self.handle = _LIB.NewDLTensor()
@property
def _tvm_handle(self):
return self.handle.value
"""
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
def register_func(func_name, f=None, override=False):
"""Register global function
Parameters
----------
func_name : str or function
The function name
f : function, optional
The function to be registered.
override: boolean optional
Whether override existing entry.
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
The following code registers my_packed_func as global function.
Note that we simply get it back from global function table to invoke
it from python side. However, we can also invoke the same function
from C++ backend, or in the compiled TVM code.
.. code-block:: python
targs = (10, 10.0, "hello")
@tvm.register_func
def my_packed_func(*args):
assert(tuple(args) == targs)
return 10
# Get it out from global function table
f = tvm.get_global_func("my_packed_func")
assert isinstance(f, tvm.PackedFunc)
y = f(*targs)
assert y == 10
"""
if callable(func_name):
f = func_name
func_name = f.__name__
if not isinstance(func_name, str):
raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
if not isinstance(myf, PackedFuncBase):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(
c_str(func_name), myf.handle, ioverride))
return myf
if f:
return register(f)
return register
def get_global_func(name, allow_missing=False):
"""Get a global function by name
Parameters
----------
name : str
The name of the global function
allow_missing : bool
Whether allow missing function or raise an error.
Returns
-------
func : PackedFunc
The function to be returned, None if function is missing.
"""
return _get_global_func(name, allow_missing)
def list_global_func_names():
"""Get list of global functions registered.
Returns
-------
names : list
List of global functions names.
"""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
ctypes.byref(plist)))
fnames = []
for i in range(size.value):
fnames.append(py_str(plist[i]))
return fnames
def extract_ext_funcs(finit):
"""
Extract the extension PackedFuncs from a C module.
Parameters
----------
finit : ctypes function
a ctypes that takes signature of TVMExtensionDeclarer
Returns
-------
fdict : dict of str to Function
The extracted functions
"""
fdict = {}
def _list(name, func):
fdict[name] = func
myf = convert_to_tvm_func(_list)
ret = finit(myf.handle)
_ = myf
if ret != 0:
raise RuntimeError("cannot initialize with %s" % finit)
return fdict
def _get_api(f):
flocal = f
flocal.is_global = True
return flocal
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
namespace : str
The namespace of the source registry
target_module_name : str
The target module name if different from namespace
"""
target_module_name = (
target_module_name if target_module_name else namespace)
if namespace.startswith("tvm."):
_init_api_prefix(target_module_name, namespace[4:])
else:
_init_api_prefix(target_module_name, namespace)
def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
if prefix == "api":
fname = name
if name.startswith("_"):
target_module = sys.modules["tvm._api_internal"]
else:
target_module = module
else:
if not name.startswith(prefix):
continue
fname = name[len(prefix)+1:]
target_module = module
if fname.find(".") != -1:
continue
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("TVM PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff)
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
# under the License. # under the License.
"""Common runtime ctypes.""" """Common runtime ctypes."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import
import ctypes import ctypes
import json import json
import numpy as np import numpy as np
......
...@@ -16,17 +16,13 @@ ...@@ -16,17 +16,13 @@
# under the License. # under the License.
"""Functions defined in TVM.""" """Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin # pylint: disable=invalid-name,unused-import,redefined-builtin
from __future__ import absolute_import as _abs
from numbers import Integral as _Integral from numbers import Integral as _Integral
import tvm._ffi
from ._ffi.base import string_types, TVMError from ._ffi.base import string_types, TVMError
from ._ffi.object import register_object, Object from ._ffi.object_generic import convert, const
from ._ffi.object import convert_to_object as _convert_to_object from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.object_generic import _scalar_type_inference
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.runtime_ctypes import TVMType from ._ffi.runtime_ctypes import TVMType
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
...@@ -75,30 +71,6 @@ def max_value(dtype): ...@@ -75,30 +71,6 @@ def max_value(dtype):
return _api_internal._max_value(dtype) return _api_internal._max_value(dtype)
def const(value, dtype=None):
"""construct a constant
Parameters
----------
value : number
The content of the constant number.
dtype : str or None, optional
The data type.
Returns
-------
const_val: tvm.Expr
The result expression.
"""
if dtype is None:
dtype = _scalar_type_inference(value)
if dtype == "uint64" and value >= (1 << 63):
return _api_internal._LargeUIntImm(
dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype)
def get_env_func(name): def get_env_func(name):
"""Get an EnvFunc by a global name. """Get an EnvFunc by a global name.
...@@ -121,27 +93,6 @@ def get_env_func(name): ...@@ -121,27 +93,6 @@ def get_env_func(name):
return _api_internal._EnvFuncGet(name) return _api_internal._EnvFuncGet(name)
def convert(value):
"""Convert value to TVM node or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (Function, Object)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_object(value)
def load_json(json_str): def load_json(json_str):
"""Load tvm object from json_str. """Load tvm object from json_str.
...@@ -1073,10 +1024,9 @@ def floormod(a, b): ...@@ -1073,10 +1024,9 @@ def floormod(a, b):
""" """
return _make._OpFloorMod(a, b) return _make._OpFloorMod(a, b)
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda #pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min') min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max') max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
tvm._ffi._init_api("tvm.api")
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
# under the License. # under the License.
"""Arithmetic data structure and utility""" """Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm._ffi
from ._ffi.object import Object, register_object from ._ffi.object import Object
from ._ffi.function import _init_api
from . import _api_internal from . import _api_internal
class IntSet(Object): class IntSet(Object):
...@@ -32,7 +32,7 @@ class IntSet(Object): ...@@ -32,7 +32,7 @@ class IntSet(Object):
return _api_internal._IntSetIsEverything(self) return _api_internal._IntSetIsEverything(self)
@register_object("arith.IntervalSet") @tvm._ffi.register_object("arith.IntervalSet")
class IntervalSet(IntSet): class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value] """Represent set of continuous interval [min_value, max_value]
...@@ -49,7 +49,7 @@ class IntervalSet(IntSet): ...@@ -49,7 +49,7 @@ class IntervalSet(IntSet):
_make_IntervalSet, min_value, max_value) _make_IntervalSet, min_value, max_value)
@register_object("arith.ModularSet") @tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object): class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z """ """Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base): def __init__(self, coeff, base):
...@@ -57,7 +57,7 @@ class ModularSet(Object): ...@@ -57,7 +57,7 @@ class ModularSet(Object):
_make_ModularSet, coeff, base) _make_ModularSet, coeff, base)
@register_object("arith.ConstIntBound") @tvm._ffi.register_object("arith.ConstIntBound")
class ConstIntBound(Object): class ConstIntBound(Object):
"""Represent constant integer bound """Represent constant integer bound
...@@ -258,4 +258,4 @@ class Analyzer: ...@@ -258,4 +258,4 @@ class Analyzer:
"Do not know how to handle type {}".format(type(info))) "Do not know how to handle type {}".format(type(info)))
_init_api("tvm.arith") tvm._ffi._init_api("tvm.arith")
...@@ -15,12 +15,13 @@ ...@@ -15,12 +15,13 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" TVM Attribute module, which is mainly used for defining attributes of operators""" """ TVM Attribute module, which is mainly used for defining attributes of operators"""
from ._ffi.object import Object, register_object import tvm._ffi
from ._ffi.function import _init_api
from ._ffi.object import Object
from . import _api_internal from . import _api_internal
@register_object @tvm._ffi.register_object
class Attrs(Object): class Attrs(Object):
"""Attribute node, which is mainly use for defining attributes of relay operators. """Attribute node, which is mainly use for defining attributes of relay operators.
...@@ -92,4 +93,4 @@ class Attrs(Object): ...@@ -92,4 +93,4 @@ class Attrs(Object):
return self.__getattr__(item) return self.__getattr__(item)
_init_api("tvm.attrs") tvm._ffi._init_api("tvm.attrs")
...@@ -19,11 +19,10 @@ ...@@ -19,11 +19,10 @@
This module provides the functions to transform schedule to This module provides the functions to transform schedule to
LoweredFunc and compiled Module. LoweredFunc and compiled Module.
""" """
from __future__ import absolute_import as _abs
import warnings import warnings
import tvm._ffi
from ._ffi.function import Function from ._ffi.object import Object
from ._ffi.object import Object, register_object
from . import api from . import api
from . import _api_internal from . import _api_internal
from . import tensor from . import tensor
...@@ -115,7 +114,7 @@ class DumpIR(object): ...@@ -115,7 +114,7 @@ class DumpIR(object):
DumpIR.scope_level -= 1 DumpIR.scope_level -= 1
@register_object @tvm._ffi.register_object
class BuildConfig(Object): class BuildConfig(Object):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Code generation related functions.""" """Code generation related functions."""
from ._ffi.function import _init_api import tvm._ffi
def build_module(lowered_func, target): def build_module(lowered_func, target):
"""Build lowered_func into Module. """Build lowered_func into Module.
...@@ -35,4 +35,4 @@ def build_module(lowered_func, target): ...@@ -35,4 +35,4 @@ def build_module(lowered_func, target):
""" """
return _Build(lowered_func, target) return _Build(lowered_func, target)
_init_api("tvm.codegen") tvm._ffi._init_api("tvm.codegen")
...@@ -15,13 +15,14 @@ ...@@ -15,13 +15,14 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Container data structures used in TVM DSL.""" """Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs import tvm._ffi
from tvm import ndarray as _nd from tvm import ndarray as _nd
from . import _api_internal from . import _api_internal
from ._ffi.object import Object, register_object, getitem_helper from ._ffi.object import Object, getitem_helper
from ._ffi.function import _init_api
@register_object @tvm._ffi.register_object
class Array(Object): class Array(Object):
"""Array container of TVM. """Array container of TVM.
...@@ -52,7 +53,7 @@ class Array(Object): ...@@ -52,7 +53,7 @@ class Array(Object):
return _api_internal._ArraySize(self) return _api_internal._ArraySize(self)
@register_object @tvm._ffi.register_object
class EnvFunc(Object): class EnvFunc(Object):
"""Environment function. """Environment function.
...@@ -66,7 +67,7 @@ class EnvFunc(Object): ...@@ -66,7 +67,7 @@ class EnvFunc(Object):
return _api_internal._EnvFuncGetPackedFunc(self) return _api_internal._EnvFuncGetPackedFunc(self)
@register_object @tvm._ffi.register_object
class Map(Object): class Map(Object):
"""Map container of TVM. """Map container of TVM.
...@@ -89,7 +90,7 @@ class Map(Object): ...@@ -89,7 +90,7 @@ class Map(Object):
return _api_internal._MapSize(self) return _api_internal._MapSize(self)
@register_object @tvm._ffi.register_object
class StrMap(Map): class StrMap(Map):
"""A special map container that has str as key. """A special map container that has str as key.
...@@ -101,7 +102,7 @@ class StrMap(Map): ...@@ -101,7 +102,7 @@ class StrMap(Map):
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_object @tvm._ffi.register_object
class Range(Object): class Range(Object):
"""Represent a range in TVM. """Represent a range in TVM.
...@@ -110,7 +111,7 @@ class Range(Object): ...@@ -110,7 +111,7 @@ class Range(Object):
""" """
@register_object @tvm._ffi.register_object
class LoweredFunc(Object): class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM.""" """Represent a LoweredFunc in TVM."""
MixedFunc = 0 MixedFunc = 0
...@@ -118,7 +119,7 @@ class LoweredFunc(Object): ...@@ -118,7 +119,7 @@ class LoweredFunc(Object):
DeviceFunc = 2 DeviceFunc = 2
@register_object("vm.ADT") @tvm._ffi.register_object("vm.ADT")
class ADT(Object): class ADT(Object):
"""Algebatic data type(ADT) object. """Algebatic data type(ADT) object.
...@@ -168,4 +169,4 @@ def tuple_object(fields=None): ...@@ -168,4 +169,4 @@ def tuple_object(fields=None):
return _Tuple(*fields) return _Tuple(*fields)
_init_api("tvm.container") tvm._ffi._init_api("tvm.container")
...@@ -19,8 +19,9 @@ ...@@ -19,8 +19,9 @@
import os import os
import tempfile import tempfile
import shutil import shutil
import tvm._ffi
from tvm._ffi.base import string_types from tvm._ffi.base import string_types
from tvm._ffi.function import get_global_func
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.ndarray import array from tvm.ndarray import array
from . import debug_result from . import debug_result
...@@ -64,7 +65,7 @@ def create(graph_json_str, libmod, ctx, dump_root=None): ...@@ -64,7 +65,7 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
fcreate = ctx[0]._rpc_sess.get_function( fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.create") "tvm.graph_runtime_debug.create")
else: else:
fcreate = get_global_func("tvm.graph_runtime_debug.create") fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_debug.create")
except ValueError: except ValueError:
raise ValueError( raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
......
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
# under the License. # under the License.
"""Minimum graph runtime that executes graph containing TVM PackedFunc.""" """Minimum graph runtime that executes graph containing TVM PackedFunc."""
import numpy as np import numpy as np
import tvm._ffi
from .._ffi.base import string_types from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .._ffi.runtime_ctypes import TVMContext from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base from ..rpc import base as rpc_base
...@@ -54,7 +54,7 @@ def create(graph_json_str, libmod, ctx): ...@@ -54,7 +54,7 @@ def create(graph_json_str, libmod, ctx):
if num_rpc_ctx == len(ctx): if num_rpc_ctx == len(ctx):
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create") fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
else: else:
fcreate = get_global_func("tvm.graph_runtime.create") fcreate = tvm._ffi.get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
......
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""External function interface to NNPACK libraries.""" """External function interface to NNPACK libraries."""
from __future__ import absolute_import as _abs import tvm._ffi
from .. import api as _api from .. import api as _api
from .. import intrin as _intrin from .. import intrin as _intrin
from .._ffi.function import _init_api
def is_available(): def is_available():
"""Check whether NNPACK is available, that is, `nnp_initialize()` """Check whether NNPACK is available, that is, `nnp_initialize()`
...@@ -202,4 +202,4 @@ def convolution_inference_weight_transform( ...@@ -202,4 +202,4 @@ def convolution_inference_weight_transform(
"tvm.contrib.nnpack.convolution_inference_weight_transform", "tvm.contrib.nnpack.convolution_inference_weight_transform",
ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype) ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
_init_api("tvm.contrib.nnpack") tvm._ffi._init_api("tvm.contrib.nnpack")
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""External function interface to random library.""" """External function interface to random library."""
from __future__ import absolute_import as _abs import tvm._ffi
from .. import api as _api from .. import api as _api
from .. import intrin as _intrin from .. import intrin as _intrin
from .._ffi.function import _init_api
def randint(low, high, size, dtype='int32'): def randint(low, high, size, dtype='int32'):
...@@ -96,4 +95,4 @@ def normal(loc, scale, size): ...@@ -96,4 +95,4 @@ def normal(loc, scale, size):
"tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32') "tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
_init_api("tvm.contrib.random") tvm._ffi._init_api("tvm.contrib.random")
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""TFLite runtime that load and run tflite models.""" """TFLite runtime that load and run tflite models."""
from .._ffi.function import get_global_func import tvm._ffi
from ..rpc import base as rpc_base from ..rpc import base as rpc_base
def create(tflite_model_bytes, ctx, runtime_target='cpu'): def create(tflite_model_bytes, ctx, runtime_target='cpu'):
...@@ -44,7 +44,7 @@ def create(tflite_model_bytes, ctx, runtime_target='cpu'): ...@@ -44,7 +44,7 @@ def create(tflite_model_bytes, ctx, runtime_target='cpu'):
if device_type >= rpc_base.RPC_SESS_MASK: if device_type >= rpc_base.RPC_SESS_MASK:
fcreate = ctx._rpc_sess.get_function(runtime_func) fcreate = ctx._rpc_sess.get_function(runtime_func)
else: else:
fcreate = get_global_func(runtime_func) fcreate = tvm._ffi.get_global_func(runtime_func)
return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx)) return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
......
...@@ -15,9 +15,8 @@ ...@@ -15,9 +15,8 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Custom datatype functionality""" """Custom datatype functionality"""
from __future__ import absolute_import as _abs import tvm._ffi
from ._ffi.function import register_func as _register_func
from . import make as _make from . import make as _make
from .api import convert from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
...@@ -111,7 +110,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None): ...@@ -111,7 +110,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
else: else:
lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \ lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
+ type_name + type_name
_register_func(lower_func_name, lower_func) tvm._ffi.register_func(lower_func_name, lower_func)
def create_lower_func(extern_func_name): def create_lower_func(extern_func_name):
......
...@@ -32,7 +32,10 @@ For example, you can use addexp.a to get the left operand of an Add node. ...@@ -32,7 +32,10 @@ For example, you can use addexp.a to get the left operand of an Add node.
""" """
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object, ObjectGeneric import tvm._ffi
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make from . import make as _make
from . import generic as _generic from . import generic as _generic
...@@ -261,7 +264,7 @@ class CmpExpr(PrimExpr): ...@@ -261,7 +264,7 @@ class CmpExpr(PrimExpr):
class LogicalExpr(PrimExpr): class LogicalExpr(PrimExpr):
pass pass
@register_object("Variable") @tvm._ffi.register_object("Variable")
class Var(PrimExpr): class Var(PrimExpr):
"""Symbolic variable. """Symbolic variable.
...@@ -278,7 +281,7 @@ class Var(PrimExpr): ...@@ -278,7 +281,7 @@ class Var(PrimExpr):
_api_internal._Var, name, dtype) _api_internal._Var, name, dtype)
@register_object @tvm._ffi.register_object
class SizeVar(Var): class SizeVar(Var):
"""Symbolic variable to represent a tensor index size """Symbolic variable to represent a tensor index size
which is greater or equal to zero which is greater or equal to zero
...@@ -297,7 +300,7 @@ class SizeVar(Var): ...@@ -297,7 +300,7 @@ class SizeVar(Var):
_api_internal._SizeVar, name, dtype) _api_internal._SizeVar, name, dtype)
@register_object @tvm._ffi.register_object
class Reduce(PrimExpr): class Reduce(PrimExpr):
"""Reduce node. """Reduce node.
...@@ -324,7 +327,7 @@ class Reduce(PrimExpr): ...@@ -324,7 +327,7 @@ class Reduce(PrimExpr):
condition, value_index) condition, value_index)
@register_object @tvm._ffi.register_object
class FloatImm(ConstExpr): class FloatImm(ConstExpr):
"""Float constant. """Float constant.
...@@ -340,7 +343,7 @@ class FloatImm(ConstExpr): ...@@ -340,7 +343,7 @@ class FloatImm(ConstExpr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.FloatImm, dtype, value) _make.FloatImm, dtype, value)
@register_object @tvm._ffi.register_object
class IntImm(ConstExpr): class IntImm(ConstExpr):
"""Int constant. """Int constant.
...@@ -360,7 +363,7 @@ class IntImm(ConstExpr): ...@@ -360,7 +363,7 @@ class IntImm(ConstExpr):
return self.value return self.value
@register_object @tvm._ffi.register_object
class StringImm(ConstExpr): class StringImm(ConstExpr):
"""String constant. """String constant.
...@@ -384,7 +387,7 @@ class StringImm(ConstExpr): ...@@ -384,7 +387,7 @@ class StringImm(ConstExpr):
return self.value != other return self.value != other
@register_object @tvm._ffi.register_object
class Cast(PrimExpr): class Cast(PrimExpr):
"""Cast expression. """Cast expression.
...@@ -401,7 +404,7 @@ class Cast(PrimExpr): ...@@ -401,7 +404,7 @@ class Cast(PrimExpr):
_make.Cast, dtype, value) _make.Cast, dtype, value)
@register_object @tvm._ffi.register_object
class Add(BinaryOpExpr): class Add(BinaryOpExpr):
"""Add node. """Add node.
...@@ -418,7 +421,7 @@ class Add(BinaryOpExpr): ...@@ -418,7 +421,7 @@ class Add(BinaryOpExpr):
_make.Add, a, b) _make.Add, a, b)
@register_object @tvm._ffi.register_object
class Sub(BinaryOpExpr): class Sub(BinaryOpExpr):
"""Sub node. """Sub node.
...@@ -435,7 +438,7 @@ class Sub(BinaryOpExpr): ...@@ -435,7 +438,7 @@ class Sub(BinaryOpExpr):
_make.Sub, a, b) _make.Sub, a, b)
@register_object @tvm._ffi.register_object
class Mul(BinaryOpExpr): class Mul(BinaryOpExpr):
"""Mul node. """Mul node.
...@@ -452,7 +455,7 @@ class Mul(BinaryOpExpr): ...@@ -452,7 +455,7 @@ class Mul(BinaryOpExpr):
_make.Mul, a, b) _make.Mul, a, b)
@register_object @tvm._ffi.register_object
class Div(BinaryOpExpr): class Div(BinaryOpExpr):
"""Div node. """Div node.
...@@ -469,7 +472,7 @@ class Div(BinaryOpExpr): ...@@ -469,7 +472,7 @@ class Div(BinaryOpExpr):
_make.Div, a, b) _make.Div, a, b)
@register_object @tvm._ffi.register_object
class Mod(BinaryOpExpr): class Mod(BinaryOpExpr):
"""Mod node. """Mod node.
...@@ -486,7 +489,7 @@ class Mod(BinaryOpExpr): ...@@ -486,7 +489,7 @@ class Mod(BinaryOpExpr):
_make.Mod, a, b) _make.Mod, a, b)
@register_object @tvm._ffi.register_object
class FloorDiv(BinaryOpExpr): class FloorDiv(BinaryOpExpr):
"""FloorDiv node. """FloorDiv node.
...@@ -503,7 +506,7 @@ class FloorDiv(BinaryOpExpr): ...@@ -503,7 +506,7 @@ class FloorDiv(BinaryOpExpr):
_make.FloorDiv, a, b) _make.FloorDiv, a, b)
@register_object @tvm._ffi.register_object
class FloorMod(BinaryOpExpr): class FloorMod(BinaryOpExpr):
"""FloorMod node. """FloorMod node.
...@@ -520,7 +523,7 @@ class FloorMod(BinaryOpExpr): ...@@ -520,7 +523,7 @@ class FloorMod(BinaryOpExpr):
_make.FloorMod, a, b) _make.FloorMod, a, b)
@register_object @tvm._ffi.register_object
class Min(BinaryOpExpr): class Min(BinaryOpExpr):
"""Min node. """Min node.
...@@ -537,7 +540,7 @@ class Min(BinaryOpExpr): ...@@ -537,7 +540,7 @@ class Min(BinaryOpExpr):
_make.Min, a, b) _make.Min, a, b)
@register_object @tvm._ffi.register_object
class Max(BinaryOpExpr): class Max(BinaryOpExpr):
"""Max node. """Max node.
...@@ -554,7 +557,7 @@ class Max(BinaryOpExpr): ...@@ -554,7 +557,7 @@ class Max(BinaryOpExpr):
_make.Max, a, b) _make.Max, a, b)
@register_object @tvm._ffi.register_object
class EQ(CmpExpr): class EQ(CmpExpr):
"""EQ node. """EQ node.
...@@ -571,7 +574,7 @@ class EQ(CmpExpr): ...@@ -571,7 +574,7 @@ class EQ(CmpExpr):
_make.EQ, a, b) _make.EQ, a, b)
@register_object @tvm._ffi.register_object
class NE(CmpExpr): class NE(CmpExpr):
"""NE node. """NE node.
...@@ -588,7 +591,7 @@ class NE(CmpExpr): ...@@ -588,7 +591,7 @@ class NE(CmpExpr):
_make.NE, a, b) _make.NE, a, b)
@register_object @tvm._ffi.register_object
class LT(CmpExpr): class LT(CmpExpr):
"""LT node. """LT node.
...@@ -605,7 +608,7 @@ class LT(CmpExpr): ...@@ -605,7 +608,7 @@ class LT(CmpExpr):
_make.LT, a, b) _make.LT, a, b)
@register_object @tvm._ffi.register_object
class LE(CmpExpr): class LE(CmpExpr):
"""LE node. """LE node.
...@@ -622,7 +625,7 @@ class LE(CmpExpr): ...@@ -622,7 +625,7 @@ class LE(CmpExpr):
_make.LE, a, b) _make.LE, a, b)
@register_object @tvm._ffi.register_object
class GT(CmpExpr): class GT(CmpExpr):
"""GT node. """GT node.
...@@ -639,7 +642,7 @@ class GT(CmpExpr): ...@@ -639,7 +642,7 @@ class GT(CmpExpr):
_make.GT, a, b) _make.GT, a, b)
@register_object @tvm._ffi.register_object
class GE(CmpExpr): class GE(CmpExpr):
"""GE node. """GE node.
...@@ -656,7 +659,7 @@ class GE(CmpExpr): ...@@ -656,7 +659,7 @@ class GE(CmpExpr):
_make.GE, a, b) _make.GE, a, b)
@register_object @tvm._ffi.register_object
class And(LogicalExpr): class And(LogicalExpr):
"""And node. """And node.
...@@ -673,7 +676,7 @@ class And(LogicalExpr): ...@@ -673,7 +676,7 @@ class And(LogicalExpr):
_make.And, a, b) _make.And, a, b)
@register_object @tvm._ffi.register_object
class Or(LogicalExpr): class Or(LogicalExpr):
"""Or node. """Or node.
...@@ -690,7 +693,7 @@ class Or(LogicalExpr): ...@@ -690,7 +693,7 @@ class Or(LogicalExpr):
_make.Or, a, b) _make.Or, a, b)
@register_object @tvm._ffi.register_object
class Not(LogicalExpr): class Not(LogicalExpr):
"""Not node. """Not node.
...@@ -704,7 +707,7 @@ class Not(LogicalExpr): ...@@ -704,7 +707,7 @@ class Not(LogicalExpr):
_make.Not, a) _make.Not, a)
@register_object @tvm._ffi.register_object
class Select(PrimExpr): class Select(PrimExpr):
"""Select node. """Select node.
...@@ -732,7 +735,7 @@ class Select(PrimExpr): ...@@ -732,7 +735,7 @@ class Select(PrimExpr):
_make.Select, condition, true_value, false_value) _make.Select, condition, true_value, false_value)
@register_object @tvm._ffi.register_object
class Load(PrimExpr): class Load(PrimExpr):
"""Load node. """Load node.
...@@ -755,7 +758,7 @@ class Load(PrimExpr): ...@@ -755,7 +758,7 @@ class Load(PrimExpr):
_make.Load, dtype, buffer_var, index, predicate) _make.Load, dtype, buffer_var, index, predicate)
@register_object @tvm._ffi.register_object
class Ramp(PrimExpr): class Ramp(PrimExpr):
"""Ramp node. """Ramp node.
...@@ -775,7 +778,7 @@ class Ramp(PrimExpr): ...@@ -775,7 +778,7 @@ class Ramp(PrimExpr):
_make.Ramp, base, stride, lanes) _make.Ramp, base, stride, lanes)
@register_object @tvm._ffi.register_object
class Broadcast(PrimExpr): class Broadcast(PrimExpr):
"""Broadcast node. """Broadcast node.
...@@ -792,7 +795,7 @@ class Broadcast(PrimExpr): ...@@ -792,7 +795,7 @@ class Broadcast(PrimExpr):
_make.Broadcast, value, lanes) _make.Broadcast, value, lanes)
@register_object @tvm._ffi.register_object
class Shuffle(PrimExpr): class Shuffle(PrimExpr):
"""Shuffle node. """Shuffle node.
...@@ -809,7 +812,7 @@ class Shuffle(PrimExpr): ...@@ -809,7 +812,7 @@ class Shuffle(PrimExpr):
_make.Shuffle, vectors, indices) _make.Shuffle, vectors, indices)
@register_object @tvm._ffi.register_object
class Call(PrimExpr): class Call(PrimExpr):
"""Call node. """Call node.
...@@ -844,7 +847,7 @@ class Call(PrimExpr): ...@@ -844,7 +847,7 @@ class Call(PrimExpr):
_make.Call, dtype, name, args, call_type, func, value_index) _make.Call, dtype, name, args, call_type, func, value_index)
@register_object @tvm._ffi.register_object
class Let(PrimExpr): class Let(PrimExpr):
"""Let node. """Let node.
......
...@@ -28,13 +28,10 @@ HalideIR. ...@@ -28,13 +28,10 @@ HalideIR.
# TODO(@were): Make this module more complete. # TODO(@were): Make this module more complete.
# 1. Support HalideIR dumping to Hybrid Script # 1. Support HalideIR dumping to Hybrid Script
# 2. Support multi-level HalideIR # 2. Support multi-level HalideIR
from __future__ import absolute_import as _abs
import inspect import inspect
import tvm._ffi
from .._ffi.base import decorate from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body from ..build_module import form_body
from .module import HybridModule from .module import HybridModule
...@@ -97,4 +94,4 @@ def build(sch, inputs, outputs, name="hybrid_func"): ...@@ -97,4 +94,4 @@ def build(sch, inputs, outputs, name="hybrid_func"):
return HybridModule(src, name) return HybridModule(src, name)
_init_api("tvm.hybrid") tvm._ffi._init_api("tvm.hybrid")
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
# under the License. # under the License.
"""Expression Intrinsics and math functions in TVM.""" """Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin # pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs import tvm._ffi
import tvm.codegen
from ._ffi.function import register_func as _register_func
from . import make as _make from . import make as _make
from .api import convert, const from .api import convert, const
from .expr import Call as _Call from .expr import Call as _Call
...@@ -189,7 +189,6 @@ def call_llvm_intrin(dtype, name, *args): ...@@ -189,7 +189,6 @@ def call_llvm_intrin(dtype, name, *args):
call : Expr call : Expr
The call expression. The call expression.
""" """
import tvm
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args) return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
...@@ -596,7 +595,7 @@ def register_intrin_rule(target, intrin, f=None, override=False): ...@@ -596,7 +595,7 @@ def register_intrin_rule(target, intrin, f=None, override=False):
register_intrin_rule("opencl", "exp", my_exp_rule, override=True) register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
""" """
return _register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override) return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
def _rule_float_suffix(op): def _rule_float_suffix(op):
...@@ -650,7 +649,7 @@ def _rule_float_direct(op): ...@@ -650,7 +649,7 @@ def _rule_float_direct(op):
return call_pure_extern(op.dtype, op.name, *op.args) return call_pure_extern(op.dtype, op.name, *op.args)
return None return None
@_register_func("tvm.default_trace_action") @tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args): def _tvm_default_trace_action(*args):
print(list(args)) print(list(args))
......
...@@ -24,7 +24,7 @@ from . import make as _make ...@@ -24,7 +24,7 @@ from . import make as _make
from . import ir_pass as _pass from . import ir_pass as _pass
from . import container as _container from . import container as _container
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.object import ObjectGeneric from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call from .expr import Call as _Call
......
...@@ -23,6 +23,6 @@ Each api is a PackedFunc that can be called in a positional argument manner. ...@@ -23,6 +23,6 @@ Each api is a PackedFunc that can be called in a positional argument manner.
You can read "include/tvm/tir/ir_pass.h" for the function signature and You can read "include/tvm/tir/ir_pass.h" for the function signature and
"src/api/api_pass.cc" for the PackedFunc's body of these functions. "src/api/api_pass.cc" for the PackedFunc's body of these functions.
""" """
from ._ffi.function import _init_api import tvm._ffi
_init_api("tvm.ir_pass") tvm._ffi._init_api("tvm.ir_pass")
...@@ -22,8 +22,7 @@ The functions are automatically exported from C++ side via PackedFunc. ...@@ -22,8 +22,7 @@ The functions are automatically exported from C++ side via PackedFunc.
Each api is a PackedFunc that can be called in a positional argument manner. Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node. You can use make function to build the IR node.
""" """
from __future__ import absolute_import as _abs import tvm._ffi
from ._ffi.function import _init_api
def range_by_min_extent(min_value, extent): def range_by_min_extent(min_value, extent):
...@@ -85,4 +84,4 @@ def node(type_key, **kwargs): ...@@ -85,4 +84,4 @@ def node(type_key, **kwargs):
return _Node(*args) return _Node(*args)
_init_api("tvm.make") tvm._ffi._init_api("tvm.make")
...@@ -23,9 +23,11 @@ import sys ...@@ -23,9 +23,11 @@ import sys
from enum import Enum from enum import Enum
import tvm import tvm
import tvm._ffi
from tvm.contrib import util as _util from tvm.contrib import util as _util
from tvm.contrib import cc as _cc from tvm.contrib import cc as _cc
from .._ffi.function import _init_api
class LibType(Enum): class LibType(Enum):
"""Enumeration of library types that can be compiled and loaded onto a device""" """Enumeration of library types that can be compiled and loaded onto a device"""
...@@ -222,4 +224,4 @@ def get_micro_device_dir(): ...@@ -222,4 +224,4 @@ def get_micro_device_dir():
return micro_device_dir return micro_device_dir
_init_api("tvm.micro", "tvm.micro.base") tvm._ffi._init_api("tvm.micro", "tvm.micro.base")
...@@ -19,9 +19,9 @@ from __future__ import absolute_import as _abs ...@@ -19,9 +19,9 @@ from __future__ import absolute_import as _abs
import struct import struct
from collections import namedtuple from collections import namedtuple
import tvm._ffi
from ._ffi.function import ModuleBase, _set_class_module from ._ffi.module import ModuleBase, _set_class_module
from ._ffi.function import _init_api
from ._ffi.libinfo import find_include_path from ._ffi.libinfo import find_include_path
from .contrib import cc as _cc, tar as _tar, util as _util from .contrib import cc as _cc, tar as _tar, util as _util
...@@ -333,5 +333,5 @@ def enabled(target): ...@@ -333,5 +333,5 @@ def enabled(target):
return _Enabled(target) return _Enabled(target)
_init_api("tvm.module") tvm._ffi._init_api("tvm.module")
_set_class_module(Module) _set_class_module(Module)
...@@ -20,17 +20,15 @@ tvm.ndarray provides a minimum runtime array API to test ...@@ -20,17 +20,15 @@ tvm.ndarray provides a minimum runtime array API to test
the correctness of the program. the correctness of the program.
""" """
# pylint: disable=invalid-name,unused-import # pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs import tvm._ffi
import numpy as _np import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension
from ._ffi.object import register_object
@register_object @tvm._ffi.register_object
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Node is the base class of all TVM AST.
Normally user do not need to touch this api.
"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI exposing the passes for Relay program analysis.""" """FFI exposing the passes for Relay program analysis."""
import tvm._ffi
from tvm._ffi.function import _init_api tvm._ffi._init_api("relay._analysis", __name__)
_init_api("relay._analysis", __name__)
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++.""" """The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api import tvm._ffi
_init_api("relay._base", __name__) tvm._ffi._init_api("relay._base", __name__)
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++.""" """The interface for building Relay functions exposed from C++."""
from tvm._ffi.function import _init_api import tvm._ffi
_init_api("relay.build_module", __name__) tvm._ffi._init_api("relay.build_module", __name__)
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++.""" """The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api import tvm._ffi
_init_api("relay._expr", __name__) tvm._ffi._init_api("relay._expr", __name__)
...@@ -20,6 +20,6 @@ The constructors for all Relay AST nodes exposed from C++. ...@@ -20,6 +20,6 @@ The constructors for all Relay AST nodes exposed from C++.
This module includes MyPy type signatures for all of the This module includes MyPy type signatures for all of the
exposed modules. exposed modules.
""" """
from .._ffi.function import _init_api import tvm._ffi
_init_api("relay._make", __name__) tvm._ffi._init_api("relay._make", __name__)
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Module exposed from C++.""" """The interface to the Module exposed from C++."""
from tvm._ffi.function import _init_api import tvm._ffi
_init_api("relay._module", __name__) tvm._ffi._init_api("relay._module", __name__)
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI exposing the Relay type inference and checking.""" """FFI exposing the Relay type inference and checking."""
import tvm._ffi
from tvm._ffi.function import _init_api tvm._ffi._init_api("relay._transform", __name__)
_init_api("relay._transform", __name__)
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""The interface of expr function exposed from C++.""" """The interface of expr function exposed from C++."""
from __future__ import absolute_import import tvm._ffi
from ... import build_module as _build from ... import build_module as _build
from ... import container as _container from ... import container as _container
from ..._ffi.function import _init_api, register_func
@register_func("relay.backend.lower") @tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func): def lower(sch, inputs, func_name, source_func):
"""Backend function for lowering. """Backend function for lowering.
...@@ -61,7 +60,7 @@ def lower(sch, inputs, func_name, source_func): ...@@ -61,7 +60,7 @@ def lower(sch, inputs, func_name, source_func):
f, (_container.Array, tuple, list)) else [f] f, (_container.Array, tuple, list)) else [f]
@register_func("relay.backend.build") @tvm._ffi.register_func("relay.backend.build")
def build(funcs, target, target_host=None): def build(funcs, target, target_host=None):
"""Backend build function. """Backend build function.
...@@ -88,14 +87,14 @@ def build(funcs, target, target_host=None): ...@@ -88,14 +87,14 @@ def build(funcs, target, target_host=None):
return _build.build(funcs, target=target, target_host=target_host) return _build.build(funcs, target=target, target_host=target_host)
@register_func("relay._tensor_value_repr") @tvm._ffi.register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue): def _tensor_value_repr(tvalue):
return str(tvalue.data.asnumpy()) return str(tvalue.data.asnumpy())
@register_func("relay._constant_repr") @tvm._ffi.register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue): def _tensor_constant_repr(tvalue):
return str(tvalue.data.asnumpy()) return str(tvalue.data.asnumpy())
_init_api("relay.backend", __name__) tvm._ffi._init_api("relay.backend", __name__)
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
# under the License. # under the License.
"""The Relay virtual machine FFI namespace. """The Relay virtual machine FFI namespace.
""" """
from tvm._ffi.function import _init_api import tvm._ffi
_init_api("relay._vm", __name__) tvm._ffi._init_api("relay._vm", __name__)
...@@ -25,7 +25,7 @@ import numpy as np ...@@ -25,7 +25,7 @@ import numpy as np
import tvm import tvm
import tvm.ndarray as _nd import tvm.ndarray as _nd
from tvm import autotvm, container from tvm import autotvm, container
from tvm.object import Object from tvm._ffi.object import Object
from tvm.relay import expr as _expr from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base from tvm._ffi import base as _base
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck # pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language.""" """The base node types for the Relay language."""
from __future__ import absolute_import as _abs import tvm._ffi
from .._ffi.object import register_object as _register_tvm_node
from .._ffi.object import Object from .._ffi.object import Object
from . import _make from . import _make
from . import _expr from . import _expr
...@@ -34,9 +34,9 @@ def register_relay_node(type_key=None): ...@@ -34,9 +34,9 @@ def register_relay_node(type_key=None):
The type key of the node. The type key of the node.
""" """
if not isinstance(type_key, str): if not isinstance(type_key, str):
return _register_tvm_node( return tvm._ffi.register_object(
"relay." + type_key.__name__)(type_key) "relay." + type_key.__name__)(type_key)
return _register_tvm_node(type_key) return tvm._ffi.register_object(type_key)
def register_relay_attr_node(type_key=None): def register_relay_attr_node(type_key=None):
...@@ -48,9 +48,9 @@ def register_relay_attr_node(type_key=None): ...@@ -48,9 +48,9 @@ def register_relay_attr_node(type_key=None):
The type key of the node. The type key of the node.
""" """
if not isinstance(type_key, str): if not isinstance(type_key, str):
return _register_tvm_node( return tvm._ffi.register_object(
"relay.attrs." + type_key.__name__)(type_key) "relay.attrs." + type_key.__name__)(type_key)
return _register_tvm_node(type_key) return tvm._ffi.register_object(type_key)
class RelayNode(Object): class RelayNode(Object):
......
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ..._ffi.function import _init_api import tvm._ffi
_init_api("relay.op._make", __name__) tvm._ffi._init_api("relay.op._make", __name__)
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ...._ffi.function import _init_api import tvm._ffi
_init_api("relay.op.annotation._make", __name__) tvm._ffi._init_api("relay.op.annotation._make", __name__)
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ...._ffi.function import _init_api import tvm._ffi
_init_api("relay.op.contrib._make", __name__) tvm._ffi._init_api("relay.op.contrib._make", __name__)
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ...._ffi.function import _init_api import tvm._ffi
_init_api("relay.op.image._make", __name__) tvm._ffi._init_api("relay.op.image._make", __name__)
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ...._ffi.function import _init_api import tvm._ffi
_init_api("relay.op.memory._make", __name__) tvm._ffi._init_api("relay.op.memory._make", __name__)
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ...._ffi.function import _init_api import tvm._ffi
_init_api("relay.op.nn._make", __name__) tvm._ffi._init_api("relay.op.nn._make", __name__)
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
#pylint: disable=unused-argument #pylint: disable=unused-argument
"""The base node types for the Relay language.""" """The base node types for the Relay language."""
import topi import topi
import tvm._ffi
from ..._ffi.function import _init_api
from ..base import register_relay_node from ..base import register_relay_node
from ..expr import Expr from ..expr import Expr
...@@ -283,8 +282,6 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): ...@@ -283,8 +282,6 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
get(op_name).set_attr("TShapeDataDependant", data_dependant, level) get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level) return register(op_name, "FShapeFunc", shape_func, level)
_init_api("relay.op", __name__)
@register_func("relay.op.compiler._lower") @register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs): def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name) return lower(schedule, list(inputs) + list(outputs), name=name)
...@@ -320,3 +317,5 @@ def debug(expr, debug_func=None): ...@@ -320,3 +317,5 @@ def debug(expr, debug_func=None):
name = '' name = ''
return _make.debug(expr, name) return _make.debug(expr, name)
tvm._ffi._init_api("relay.op", __name__)
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ...._ffi.function import _init_api import tvm._ffi
_init_api("relay.op.vision._make", __name__) tvm._ffi._init_api("relay.op.vision._make", __name__)
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Constructor APIs""" """Constructor APIs"""
from ...._ffi.function import _init_api import tvm._ffi
_init_api("relay.qnn.op._make", __name__) tvm._ffi._init_api("relay.qnn.op._make", __name__)
...@@ -16,11 +16,10 @@ ...@@ -16,11 +16,10 @@
# under the License. # under the License.
#pylint: disable=unused-argument,inconsistent-return-statements #pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation.""" """Internal module for registering attribute for annotation."""
from __future__ import absolute_import
import warnings import warnings
import topi import topi
from ..._ffi.function import register_func import tvm._ffi
from .. import expr as _expr from .. import expr as _expr
from .. import analysis as _analysis from .. import analysis as _analysis
from .. import op as _op from .. import op as _op
...@@ -144,7 +143,8 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): ...@@ -144,7 +143,8 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
qctx.qnode_map[key] = qnode qctx.qnode_map[key] = qnode
return qnode return qnode
register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) tvm._ffi.register_func(
"relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
@register_annotate_function("nn.contrib_conv2d_NCHWc") @register_annotate_function("nn.contrib_conv2d_NCHWc")
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
# under the License. # under the License.
#pylint: disable=unused-argument #pylint: disable=unused-argument
"""Internal module for quantization.""" """Internal module for quantization."""
from __future__ import absolute_import import tvm._ffi
from tvm._ffi.function import _init_api
_init_api("relay._quantize", __name__) tvm._ffi._init_api("relay._quantize", __name__)
...@@ -26,8 +26,8 @@ import errno ...@@ -26,8 +26,8 @@ import errno
import struct import struct
import random import random
import logging import logging
import tvm._ffi
from .._ffi.function import _init_api
from .._ffi.base import py_str from .._ffi.base import py_str
# Magic header for RPC data plane # Magic header for RPC data plane
...@@ -179,4 +179,4 @@ def connect_with_retry(addr, timeout=60, retry_period=5): ...@@ -179,4 +179,4 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
# Still use tvm.rpc for the foreign functions # Still use tvm.rpc for the foreign functions
_init_api("tvm.rpc", "tvm.rpc.base") tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base")
...@@ -21,11 +21,11 @@ import os ...@@ -21,11 +21,11 @@ import os
import socket import socket
import struct import struct
import time import time
import tvm._ffi
from . import base from . import base
from ..contrib import util from ..contrib import util
from .._ffi.base import TVMError from .._ffi.base import TVMError
from .._ffi import function
from .._ffi import ndarray as nd from .._ffi import ndarray as nd
from ..module import load as _load_module from ..module import load as _load_module
...@@ -185,7 +185,7 @@ class LocalSession(RPCSession): ...@@ -185,7 +185,7 @@ class LocalSession(RPCSession):
def __init__(self): def __init__(self):
# pylint: disable=super-init-not-called # pylint: disable=super-init-not-called
self.context = nd.context self.context = nd.context
self.get_function = function.get_global_func self.get_function = tvm._ffi.get_global_func
self._temp = util.tempdir() self._temp = util.tempdir()
def upload(self, data, target=None): def upload(self, data, target=None):
......
...@@ -25,9 +25,6 @@ Server is TCP based with the following protocol: ...@@ -25,9 +25,6 @@ Server is TCP based with the following protocol:
- {server|client}:device-type[:random-key] [-timeout=timeout] - {server|client}:device-type[:random-key] [-timeout=timeout]
""" """
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import
import os import os
import ctypes import ctypes
import socket import socket
...@@ -39,8 +36,8 @@ import subprocess ...@@ -39,8 +36,8 @@ import subprocess
import time import time
import sys import sys
import signal import signal
import tvm._ffi
from .._ffi.function import register_func
from .._ffi.base import py_str from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module from ..module import load as _load_module
...@@ -58,11 +55,11 @@ def _server_env(load_library, work_path=None): ...@@ -58,11 +55,11 @@ def _server_env(load_library, work_path=None):
temp = util.tempdir() temp = util.tempdir()
# pylint: disable=unused-variable # pylint: disable=unused-variable
@register_func("tvm.rpc.server.workpath") @tvm._ffi.register_func("tvm.rpc.server.workpath")
def get_workpath(path): def get_workpath(path):
return temp.relpath(path) return temp.relpath(path)
@register_func("tvm.rpc.server.load_module", override=True) @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
def load_module(file_name): def load_module(file_name):
"""Load module from remote side.""" """Load module from remote side."""
path = temp.relpath(file_name) path = temp.relpath(file_name)
......
...@@ -15,38 +15,19 @@ ...@@ -15,38 +15,19 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""The computation schedule api of TVM.""" """The computation schedule api of TVM."""
from __future__ import absolute_import as _abs import tvm._ffi
from ._ffi.base import string_types from ._ffi.base import string_types
from ._ffi.object import Object, register_object from ._ffi.object import Object
from ._ffi.object import convert_to_object as _convert_to_object from ._ffi.object_generic import convert
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
from . import expr as _expr from . import expr as _expr
from . import container as _container from . import container as _container
def convert(value):
"""Convert value to TVM object or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (Function, Object)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_object(value)
@register_object @tvm._ffi.register_object
class Buffer(Object): class Buffer(Object):
"""Symbolic data buffer in TVM. """Symbolic data buffer in TVM.
...@@ -156,22 +137,22 @@ class Buffer(Object): ...@@ -156,22 +137,22 @@ class Buffer(Object):
return _api_internal._BufferVStore(self, begin, value) return _api_internal._BufferVStore(self, begin, value)
@register_object @tvm._ffi.register_object
class Split(Object): class Split(Object):
"""Split operation on axis.""" """Split operation on axis."""
@register_object @tvm._ffi.register_object
class Fuse(Object): class Fuse(Object):
"""Fuse operation on axis.""" """Fuse operation on axis."""
@register_object @tvm._ffi.register_object
class Singleton(Object): class Singleton(Object):
"""Singleton axis.""" """Singleton axis."""
@register_object @tvm._ffi.register_object
class IterVar(Object, _expr.ExprOp): class IterVar(Object, _expr.ExprOp):
"""Represent iteration variable. """Represent iteration variable.
...@@ -214,7 +195,7 @@ def create_schedule(ops): ...@@ -214,7 +195,7 @@ def create_schedule(ops):
return _api_internal._CreateSchedule(ops) return _api_internal._CreateSchedule(ops)
@register_object @tvm._ffi.register_object
class Schedule(Object): class Schedule(Object):
"""Schedule for all the stages.""" """Schedule for all the stages."""
def __getitem__(self, k): def __getitem__(self, k):
...@@ -348,7 +329,7 @@ class Schedule(Object): ...@@ -348,7 +329,7 @@ class Schedule(Object):
return factored[0] if len(factored) == 1 else factored return factored[0] if len(factored) == 1 else factored
@register_object @tvm._ffi.register_object
class Stage(Object): class Stage(Object):
"""A Stage represents schedule for one operation.""" """A Stage represents schedule for one operation."""
def split(self, parent, factor=None, nparts=None): def split(self, parent, factor=None, nparts=None):
...@@ -670,4 +651,4 @@ class Stage(Object): ...@@ -670,4 +651,4 @@ class Stage(Object):
""" """
_api_internal._StageOpenGL(self) _api_internal._StageOpenGL(self)
_init_api("tvm.schedule") tvm._ffi._init_api("tvm.schedule")
...@@ -29,15 +29,15 @@ Each statement node have subfields that can be visited from python side. ...@@ -29,15 +29,15 @@ Each statement node have subfields that can be visited from python side.
assert isinstance(st, tvm.stmt.Store) assert isinstance(st, tvm.stmt.Store)
assert(st.buffer_var == a) assert(st.buffer_var == a)
""" """
from __future__ import absolute_import as _abs import tvm._ffi
from ._ffi.object import Object, register_object from ._ffi.object import Object
from . import make as _make from . import make as _make
class Stmt(Object): class Stmt(Object):
pass pass
@register_object @tvm._ffi.register_object
class LetStmt(Stmt): class LetStmt(Stmt):
"""LetStmt node. """LetStmt node.
...@@ -57,7 +57,7 @@ class LetStmt(Stmt): ...@@ -57,7 +57,7 @@ class LetStmt(Stmt):
_make.LetStmt, var, value, body) _make.LetStmt, var, value, body)
@register_object @tvm._ffi.register_object
class AssertStmt(Stmt): class AssertStmt(Stmt):
"""AssertStmt node. """AssertStmt node.
...@@ -77,7 +77,7 @@ class AssertStmt(Stmt): ...@@ -77,7 +77,7 @@ class AssertStmt(Stmt):
_make.AssertStmt, condition, message, body) _make.AssertStmt, condition, message, body)
@register_object @tvm._ffi.register_object
class ProducerConsumer(Stmt): class ProducerConsumer(Stmt):
"""ProducerConsumer node. """ProducerConsumer node.
...@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt): ...@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt):
_make.ProducerConsumer, func, is_producer, body) _make.ProducerConsumer, func, is_producer, body)
@register_object @tvm._ffi.register_object
class For(Stmt): class For(Stmt):
"""For node. """For node.
...@@ -137,7 +137,7 @@ class For(Stmt): ...@@ -137,7 +137,7 @@ class For(Stmt):
for_type, device_api, body) for_type, device_api, body)
@register_object @tvm._ffi.register_object
class Store(Stmt): class Store(Stmt):
"""Store node. """Store node.
...@@ -160,7 +160,7 @@ class Store(Stmt): ...@@ -160,7 +160,7 @@ class Store(Stmt):
_make.Store, buffer_var, value, index, predicate) _make.Store, buffer_var, value, index, predicate)
@register_object @tvm._ffi.register_object
class Provide(Stmt): class Provide(Stmt):
"""Provide node. """Provide node.
...@@ -183,7 +183,7 @@ class Provide(Stmt): ...@@ -183,7 +183,7 @@ class Provide(Stmt):
_make.Provide, func, value_index, value, args) _make.Provide, func, value_index, value, args)
@register_object @tvm._ffi.register_object
class Allocate(Stmt): class Allocate(Stmt):
"""Allocate node. """Allocate node.
...@@ -215,7 +215,7 @@ class Allocate(Stmt): ...@@ -215,7 +215,7 @@ class Allocate(Stmt):
extents, condition, body) extents, condition, body)
@register_object @tvm._ffi.register_object
class AttrStmt(Stmt): class AttrStmt(Stmt):
"""AttrStmt node. """AttrStmt node.
...@@ -238,7 +238,7 @@ class AttrStmt(Stmt): ...@@ -238,7 +238,7 @@ class AttrStmt(Stmt):
_make.AttrStmt, node, attr_key, value, body) _make.AttrStmt, node, attr_key, value, body)
@register_object @tvm._ffi.register_object
class Free(Stmt): class Free(Stmt):
"""Free node. """Free node.
...@@ -252,7 +252,7 @@ class Free(Stmt): ...@@ -252,7 +252,7 @@ class Free(Stmt):
_make.Free, buffer_var) _make.Free, buffer_var)
@register_object @tvm._ffi.register_object
class Realize(Stmt): class Realize(Stmt):
"""Realize node. """Realize node.
...@@ -288,7 +288,7 @@ class Realize(Stmt): ...@@ -288,7 +288,7 @@ class Realize(Stmt):
bounds, condition, body) bounds, condition, body)
@register_object @tvm._ffi.register_object
class SeqStmt(Stmt): class SeqStmt(Stmt):
"""Sequence of statements. """Sequence of statements.
...@@ -308,7 +308,7 @@ class SeqStmt(Stmt): ...@@ -308,7 +308,7 @@ class SeqStmt(Stmt):
return len(self.seq) return len(self.seq)
@register_object @tvm._ffi.register_object
class IfThenElse(Stmt): class IfThenElse(Stmt):
"""IfThenElse node. """IfThenElse node.
...@@ -328,7 +328,7 @@ class IfThenElse(Stmt): ...@@ -328,7 +328,7 @@ class IfThenElse(Stmt):
_make.IfThenElse, condition, then_case, else_case) _make.IfThenElse, condition, then_case, else_case)
@register_object @tvm._ffi.register_object
class Evaluate(Stmt): class Evaluate(Stmt):
"""Evaluate node. """Evaluate node.
...@@ -342,7 +342,7 @@ class Evaluate(Stmt): ...@@ -342,7 +342,7 @@ class Evaluate(Stmt):
_make.Evaluate, value) _make.Evaluate, value)
@register_object @tvm._ffi.register_object
class Prefetch(Stmt): class Prefetch(Stmt):
"""Prefetch node. """Prefetch node.
......
...@@ -54,12 +54,11 @@ The list of options include: ...@@ -54,12 +54,11 @@ The list of options include:
We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string. We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific function in this module to create specific targets. We can also use other specific function in this module to create specific targets.
""" """
from __future__ import absolute_import
import warnings import warnings
import tvm._ffi
from ._ffi.base import _LIB_NAME from ._ffi.base import _LIB_NAME
from ._ffi.object import Object, register_object from ._ffi.object import Object
from . import _api_internal from . import _api_internal
try: try:
...@@ -80,7 +79,7 @@ def _merge_opts(opts, new_opts): ...@@ -80,7 +79,7 @@ def _merge_opts(opts, new_opts):
return opts return opts
@register_object @tvm._ffi.register_object
class Target(Object): class Target(Object):
"""Target device information, use through TVM API. """Target device information, use through TVM API.
...@@ -146,7 +145,7 @@ class Target(Object): ...@@ -146,7 +145,7 @@ class Target(Object):
_api_internal._ExitTargetScope(self) _api_internal._ExitTargetScope(self)
@register_object @tvm._ffi.register_object
class GenericFunc(Object): class GenericFunc(Object):
"""GenericFunc node reference. This represents a generic function """GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is that may be specialized for different targets. When this object is
......
...@@ -16,9 +16,11 @@ ...@@ -16,9 +16,11 @@
# under the License. # under the License.
"""Tensor and Operation class for computation declaration.""" """Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import as _abs import tvm._ffi
from ._ffi.object import Object, register_object, ObjectGeneric, \
convert_to_object from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric, convert_to_object
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
...@@ -47,7 +49,7 @@ class TensorSlice(ObjectGeneric, _expr.ExprOp): ...@@ -47,7 +49,7 @@ class TensorSlice(ObjectGeneric, _expr.ExprOp):
"""Data content of the tensor.""" """Data content of the tensor."""
return self.tensor.dtype return self.tensor.dtype
@register_object @tvm._ffi.register_object
class TensorIntrinCall(Object): class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic.""" """Intermediate structure for calling a tensor intrinsic."""
...@@ -55,7 +57,7 @@ class TensorIntrinCall(Object): ...@@ -55,7 +57,7 @@ class TensorIntrinCall(Object):
itervar_cls = None itervar_cls = None
@register_object @tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp): class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor""" """Tensor object, to construct, see function.Tensor"""
...@@ -157,12 +159,12 @@ class Operation(Object): ...@@ -157,12 +159,12 @@ class Operation(Object):
return _api_internal._OpInputTensors(self) return _api_internal._OpInputTensors(self)
@register_object @tvm._ffi.register_object
class PlaceholderOp(Operation): class PlaceholderOp(Operation):
"""Placeholder operation.""" """Placeholder operation."""
@register_object @tvm._ffi.register_object
class BaseComputeOp(Operation): class BaseComputeOp(Operation):
"""Compute operation.""" """Compute operation."""
@property @property
...@@ -176,18 +178,18 @@ class BaseComputeOp(Operation): ...@@ -176,18 +178,18 @@ class BaseComputeOp(Operation):
return self.__getattr__("reduce_axis") return self.__getattr__("reduce_axis")
@register_object @tvm._ffi.register_object
class ComputeOp(BaseComputeOp): class ComputeOp(BaseComputeOp):
"""Scalar operation.""" """Scalar operation."""
pass pass
@register_object @tvm._ffi.register_object
class TensorComputeOp(BaseComputeOp): class TensorComputeOp(BaseComputeOp):
"""Tensor operation.""" """Tensor operation."""
@register_object @tvm._ffi.register_object
class ScanOp(Operation): class ScanOp(Operation):
"""Scan operation.""" """Scan operation."""
@property @property
...@@ -196,12 +198,12 @@ class ScanOp(Operation): ...@@ -196,12 +198,12 @@ class ScanOp(Operation):
return self.__getattr__("scan_axis") return self.__getattr__("scan_axis")
@register_object @tvm._ffi.register_object
class ExternOp(Operation): class ExternOp(Operation):
"""External operation.""" """External operation."""
@register_object @tvm._ffi.register_object
class HybridOp(Operation): class HybridOp(Operation):
"""Hybrid operation.""" """Hybrid operation."""
@property @property
...@@ -210,7 +212,7 @@ class HybridOp(Operation): ...@@ -210,7 +212,7 @@ class HybridOp(Operation):
return self.__getattr__("axis") return self.__getattr__("axis")
@register_object @tvm._ffi.register_object
class Layout(Object): class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers, """Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and where upper case indicates a primal axis and
...@@ -270,7 +272,7 @@ class Layout(Object): ...@@ -270,7 +272,7 @@ class Layout(Object):
return _api_internal._LayoutFactorOf(self, axis) return _api_internal._LayoutFactorOf(self, axis)
@register_object @tvm._ffi.register_object
class BijectiveLayout(Object): class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout). """Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other. It provides shape and index conversion between each other.
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Tensor intrinsics""" """Tensor intrinsics"""
from __future__ import absolute_import as _abs import tvm._ffi
from . import _api_internal from . import _api_internal
from . import api as _api from . import api as _api
from . import expr as _expr from . import expr as _expr
...@@ -24,7 +25,7 @@ from . import make as _make ...@@ -24,7 +25,7 @@ from . import make as _make
from . import tensor as _tensor from . import tensor as _tensor
from . import schedule as _schedule from . import schedule as _schedule
from .build_module import current_build_config from .build_module import current_build_config
from ._ffi.object import Object, register_object from ._ffi.object import Object
def _get_region(tslice): def _get_region(tslice):
...@@ -41,7 +42,7 @@ def _get_region(tslice): ...@@ -41,7 +42,7 @@ def _get_region(tslice):
region.append(_make.range_by_min_extent(begin, 1)) region.append(_make.range_by_min_extent(begin, 1))
return region return region
@register_object @tvm._ffi.register_object
class TensorIntrin(Object): class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation. """Tensor intrinsic functions for certain computation.
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI for CUDA TOPI ops and schedules""" """FFI for CUDA TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix tvm._ffi._init_api("topi.cuda", "topi.cpp.cuda")
_init_api_prefix("topi.cpp.cuda", "topi.cuda")
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI for generic TOPI ops and schedules""" """FFI for generic TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix tvm._ffi._init_api("topi.generic", "topi.cpp.generic")
_init_api_prefix("topi.cpp.generic", "topi.generic")
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
import sys import sys
import os import os
import ctypes import ctypes
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
from tvm._ffi import libinfo from tvm._ffi import libinfo
def _get_lib_names(): def _get_lib_names():
...@@ -41,4 +41,4 @@ def _load_lib(): ...@@ -41,4 +41,4 @@ def _load_lib():
_LIB, _LIB_NAME = _load_lib() _LIB, _LIB_NAME = _load_lib()
_init_api_prefix("topi.cpp", "topi") tvm._ffi._init_api("topi", "topi.cpp")
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI for NN TOPI ops and schedules""" """FFI for NN TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix tvm._ffi._init_api("topi.nn", "topi.cpp.nn")
_init_api_prefix("topi.cpp.nn", "topi.nn")
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI for Rocm TOPI ops and schedules""" """FFI for Rocm TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix tvm._ffi._init_api("topi.rocm", "topi.cpp.rocm")
_init_api_prefix("topi.cpp.rocm", "topi.rocm")
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI for TOPI utility functions""" """FFI for TOPI utility functions"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix tvm._ffi._init_api("topi.util", "topi.cpp.util")
_init_api_prefix("topi.cpp.util", "topi.util")
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
# under the License. # under the License.
"""FFI for vision TOPI ops and schedules""" """FFI for vision TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
from . import yolo from . import yolo
_init_api_prefix("topi.cpp.vision", "topi.vision") tvm._ffi._init_api("topi.vision", "topi.cpp.vision")
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI for Yolo TOPI ops and schedules""" """FFI for Yolo TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix tvm._ffi._init_api("topi.vision.yolo", "topi.cpp.vision.yolo")
_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo")
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI for x86 TOPI ops and schedules""" """FFI for x86 TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix tvm._ffi._init_api("topi.x86", "topi.cpp.x86")
_init_api_prefix("topi.cpp.x86", "topi.x86")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment