Unverified Commit f9b46c43 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] tvm._ffi (#4813)

* [REFACTOR][PY] tvm._ffi

- Remove from __future__ import absolute_import in the related files as they are no longer needed if the code only runs in python3
- Remove reverse dependency of _ctypes _cython to object_generic.
- function.py -> packed_func.py
- Function -> PackedFunc
- all registry related logics goes to tvm._ffi.registry
- Use absolute references for FFI related calls.
  - tvm._ffi.register_object
  - tvm._ffi.register_func
  - tvm._ffi.get_global_func

* Move get global func to the ffi side
parent 4a39e521
......@@ -16,13 +16,17 @@
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation."""
from __future__ import absolute_import as _abs
import multiprocessing
import sys
import traceback
from . import _pyversion
# import ffi related features
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object
from . import tensor
from . import arith
......@@ -34,7 +38,6 @@ from . import codegen
from . import container
from . import schedule
from . import module
from . import object
from . import attrs
from . import ir_builder
from . import target
......@@ -48,15 +51,9 @@ from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .object import register_object
from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
......
......@@ -24,3 +24,7 @@ be used via ctypes function calls.
Some performance critical functions are implemented by cython
and have a ctypes fallback implementation.
"""
from . import _pyversion
from .base import register_error
from .registry import register_object, register_func, register_extension
from .registry import _init_api, get_global_func
......@@ -16,8 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Runtime NDArray api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
......
......@@ -16,8 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
......
......@@ -17,15 +17,12 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import
"""Function configuration API."""
from __future__ import absolute_import
import ctypes
import traceback
from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
......@@ -35,7 +32,7 @@ from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_in
from .object import ObjectBase, _set_class_object
from . import object as _object
FunctionHandle = ctypes.c_void_p
PackedFuncHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
......@@ -49,6 +46,15 @@ def _ctypes_free_resource(rhandle):
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def _make_packed_func(handle, is_global):
"""Make a packed function class"""
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
obj.is_global = is_global
obj.handle = handle
return obj
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
......@@ -89,7 +95,7 @@ def convert_to_tvm_func(pyfunc):
_ = rv
return 0
handle = FunctionHandle()
handle = PackedFuncHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
......@@ -98,7 +104,7 @@ def convert_to_tvm_func(pyfunc):
if _LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
raise get_last_ffi_error()
return _CLASS_FUNCTION(handle, False)
return _make_packed_func(handle, False)
def _make_tvm_args(args, temp_args):
......@@ -144,15 +150,15 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
arg = _FUNC_CONVERT_TO_OBJECT(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase):
elif isinstance(arg, PackedFuncBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
elif isinstance(arg, ctypes.c_void_p):
......@@ -168,7 +174,7 @@ def _make_tvm_args(args, temp_args):
return values, type_codes, num_args
class FunctionBase(object):
class PackedFuncBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
......@@ -177,7 +183,7 @@ class FunctionBase(object):
Parameters
----------
handle : FunctionHandle
handle : PackedFuncHandle
the handle to the underlying function.
is_global : bool
......@@ -238,9 +244,22 @@ def _return_module(x):
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
if not isinstance(handle, PackedFuncHandle):
handle = PackedFuncHandle(handle)
return _CLASS_PACKED_FUNC(handle, False)
def _get_global_func(name, allow_missing=False):
handle = PackedFuncHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value:
return _make_packed_func(handle, False)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
# setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__
......@@ -255,13 +274,22 @@ C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle,
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_PACKED_FUNC = None
_CLASS_OBJECT_GENERIC = None
_FUNC_CONVERT_TO_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_packed_func(packed_func_class):
global _CLASS_PACKED_FUNC
_CLASS_PACKED_FUNC = packed_func_class
def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _CLASS_OBJECT_GENERIC
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object
......@@ -16,8 +16,6 @@
# under the License.
"""The C Types used in API."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import ctypes
import struct
from ..base import py_str, check_call, _LIB
......
......@@ -75,7 +75,7 @@ ctypedef int64_t tvm_index_t
ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* TVMPackedFuncHandle
ctypedef void* ObjectHandle
ctypedef struct TVMObject:
......@@ -96,13 +96,15 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg)
const char *TVMGetLastError()
int TVMFuncCall(TVMFunctionHandle func,
int TVMFuncGetGlobal(const char* name,
TVMPackedFuncHandle* out);
int TVMFuncCall(TVMPackedFuncHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code)
int TVMFuncFree(TVMFunctionHandle func)
int TVMFuncFree(TVMPackedFuncHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
......@@ -110,7 +112,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
TVMPackedFuncHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
......
......@@ -17,7 +17,5 @@
include "./base.pxi"
include "./object.pxi"
# include "./node.pxi"
include "./function.pxi"
include "./packed_func.pxi"
include "./ndarray.pxi"
......@@ -96,6 +96,6 @@ cdef class ObjectBase:
self.chandle = NULL
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
(<PackedFuncBase>fconstructor).chandle,
kTVMObjectHandle, args, &chandle)
self.chandle = chandle
......@@ -20,7 +20,6 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
......@@ -67,6 +66,13 @@ cdef int tvm_callback(TVMValue* args,
return 0
cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global):
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
(<PackedFuncBase>obj).chandle = chandle
(<PackedFuncBase>obj).is_global = is_global
return obj
def convert_to_tvm_func(object pyfunc):
"""Convert a python function to TVM function
......@@ -80,15 +86,13 @@ def convert_to_tvm_func(object pyfunc):
tvmfunc: tvm.Function
The converted tvm function.
"""
cdef TVMFunctionHandle chandle
cdef TVMPackedFuncHandle chandle
Py_INCREF(pyfunc)
CALL(TVMFuncCreateFromCFunc(tvm_callback,
<void*>(pyfunc),
tvm_callback_finalize,
&chandle))
ret = _CLASS_FUNCTION(None, False)
(<FunctionBase>ret).chandle = chandle
return ret
return make_packed_func(chandle, False)
cdef inline int make_arg(object arg,
......@@ -149,29 +153,30 @@ cdef inline int make_arg(object arg,
value[0].v_str = tstr
tcode[0] = kTVMStr
temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
arg = _FUNC_CONVERT_TO_OBJECT(arg)
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kTVMObjectHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kTVMModuleHandle
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
elif isinstance(arg, PackedFuncBase):
value[0].v_handle = (<PackedFuncBase>arg).chandle
tcode[0] = kTVMPackedFuncHandle
elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg)
tcode[0] = kTVMOpaqueHandle
elif callable(arg):
arg = convert_to_tvm_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle
value[0].v_handle = (<PackedFuncBase>arg).chandle
tcode[0] = kTVMPackedFuncHandle
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return 0
cdef inline bytearray make_ret_bytes(void* chandle):
handle = ctypes_handle(chandle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
......@@ -182,6 +187,7 @@ cdef inline bytearray make_ret_bytes(void* chandle):
raise RuntimeError('memmove failed')
return res
cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kTVMObjectHandle:
......@@ -205,9 +211,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
elif tcode == kTVMModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kTVMPackedFuncHandle:
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
return make_packed_func(value.v_handle, False)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
......@@ -264,8 +268,8 @@ cdef inline int ConstructorCall(void* constructor_handle,
return 0
cdef class FunctionBase:
cdef TVMFunctionHandle chandle
cdef class PackedFuncBase:
cdef TVMPackedFuncHandle chandle
cdef int is_global
cdef inline _set_handle(self, handle):
......@@ -305,19 +309,39 @@ cdef class FunctionBase:
return make_ret(ret_val, ret_tcode)
_CLASS_FUNCTION = None
def _get_global_func(name, allow_missing):
cdef TVMPackedFuncHandle chandle
CALL(TVMFuncGetGlobal(c_str(name), &chandle))
if chandle != NULL:
return make_packed_func(chandle, True)
if allow_missing:
return None
raise ValueError("Cannot find global function %s" % name)
_CLASS_PACKED_FUNC = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
_CLASS_OBJECT_GENERIC = None
_FUNC_CONVERT_TO_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_packed_func(func_class):
global _CLASS_PACKED_FUNC
_CLASS_PACKED_FUNC = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _CLASS_OBJECT_GENERIC
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object
......@@ -18,6 +18,9 @@
"""
import sys
#----------------------------
# Python3 version.
#----------------------------
if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 5):
PY3STATEMENT = """TVM project proudly dropped support of Python2.
The minimal Python requirement is Python 3.5
......
......@@ -17,8 +17,6 @@
# coding: utf-8
# pylint: disable=invalid-name
"""Base library for TVM FFI."""
from __future__ import absolute_import
import sys
import os
import ctypes
......@@ -28,27 +26,22 @@ from . import libinfo
#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
string_types = (str,)
integer_types = (int, np.int32)
numeric_types = integer_types + (float, np.float32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
if sys.platform == "win32":
def _py_str(x):
try:
return x.decode('utf-8')
except UnicodeDecodeError:
encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
return x.decode(encoding)
py_str = _py_str
else:
py_str = lambda x: x.decode('utf-8')
string_types = (str,)
integer_types = (int, np.int32)
numeric_types = integer_types + (float, np.float32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
if sys.platform == "win32":
def _py_str(x):
try:
return x.decode('utf-8')
except UnicodeDecodeError:
encoding = 'cp' + str(ctypes.cdll.kernel32.GetACP())
return x.decode(encoding)
py_str = _py_str
else:
string_types = (basestring,)
integer_types = (int, long, np.int32)
numeric_types = integer_types + (float, np.float32)
py_str = lambda x: x
py_str = lambda x: x.decode('utf-8')
def _load_lib():
......
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Library information."""
from __future__ import absolute_import
import sys
import os
......@@ -39,6 +38,7 @@ def split_env_var(env_var, split):
return [p.strip() for p in os.environ[env_var].split(split)]
return []
def find_lib_path(name=None, search_path=None, optional=False):
"""Find dynamic library files.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime Module namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry", "entry_name"]
def __init__(self, handle):
self.handle = handle
self._entry = None
self.entry_name = "__tvm_main__"
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
self._entry = self.get_function(self.entry_name)
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = PackedFuncHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return PackedFunc(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __call__(self, *args):
if self._entry:
return self._entry(*args)
f = self.entry_func
return f(*args)
......@@ -16,35 +16,22 @@
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime NDArray api"""
from __future__ import absolute_import
import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import _reg_extension
else:
from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import _reg_extension
except IMPORT_EXCEPT:
from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import _reg_extension
def context(dev_type, dev_id=0):
......@@ -297,59 +284,3 @@ class NDArrayBase(_NDArrayBase):
res = empty(self.shape, self.dtype, target)
return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target)))
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
After the class is registered, the class will be able
to directly pass as Function argument generated by TVM.
Parameters
----------
cls : class
The class object to be registered as extension.
fcreate : function, optional
The creation function to create a class object given handle value.
Note
----
The registered class is requires one property: _tvm_handle.
If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_code.
Otherwise, it is required to have a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
code of the class.
Returns
-------
cls : class
The class being registered.
Example
-------
The following code registers user defined class
MyTensor to be DLTensor compatible.
.. code-block:: python
@tvm.register_extension
class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self):
self.handle = _LIB.NewDLTensor()
@property
def _tvm_handle(self):
return self.handle.value
"""
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
......@@ -16,33 +16,20 @@
# under the License.
# pylint: disable=invalid-name, unused-import
"""Runtime Object API"""
from __future__ import absolute_import
import sys
import ctypes
from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .object_generic import ObjectGeneric, convert_to_object, const
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
from ._cy3.core import _set_class_object, _set_class_object_generic
from ._cy3.core import ObjectBase
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from ._ctypes.function import _set_class_object
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
from ._ctypes.packed_func import _set_class_object, _set_class_object_generic
from ._ctypes.object import ObjectBase
def _new_object(cls):
......@@ -50,7 +37,7 @@ def _new_object(cls):
return cls.__new__(cls)
class Object(_ObjectBase):
class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
def __repr__(self):
return _api_internal._format_str(self)
......@@ -104,52 +91,6 @@ class Object(_ObjectBase):
return self.__hash__() == other.__hash__()
def register_object(type_key=None):
"""register object type.
Parameters
----------
type_key : str or cls
The type key of the node
Examples
--------
The following code registers MyObject
using type key "test.MyObject"
.. code-block:: python
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
if hasattr(cls, "_type_index"):
tindex = cls._type_index
else:
tidx = ctypes.c_uint()
if not _RUNTIME_ONLY:
check_call(_LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx)))
else:
# directly skip unknown objects during runtime.
ret = _LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx))
if ret != 0:
return cls
tindex = tidx.value
_register_object(tindex, cls)
return cls
if isinstance(type_key, str):
return register
return register(type_key)
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
......
......@@ -16,35 +16,14 @@
# under the License.
"""Common implementation of object generic related logic"""
# pylint: disable=unused-import
from __future__ import absolute_import
from numbers import Number, Integral
from .. import _api_internal
from .base import string_types
# Object base class
_CLASS_OBJECTS = None
def _set_class_objects(cls):
global _CLASS_OBJECTS
_CLASS_OBJECTS = cls
def _scalar_type_inference(value):
if hasattr(value, 'dtype'):
dtype = str(value.dtype)
elif isinstance(value, bool):
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically inference the type.'
' value={}'.format(value))
return dtype
from .base import string_types
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import ModuleBase
class ObjectGeneric(object):
......@@ -54,6 +33,9 @@ class ObjectGeneric(object):
raise NotImplementedError()
_CLASS_OBJECTS = (ObjectBase, NDArrayBase, ModuleBase)
def convert_to_object(value):
"""Convert a python value to corresponding object type.
......@@ -95,22 +77,65 @@ def convert_to_object(value):
raise ValueError("don't know how to convert type %s to object" % type(value))
def convert(value):
"""Convert value to TVM object or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (PackedFuncBase, ObjectBase)):
return value
if callable(value):
return convert_to_tvm_func(value)
return convert_to_object(value)
def _scalar_type_inference(value):
if hasattr(value, 'dtype'):
dtype = str(value.dtype)
elif isinstance(value, bool):
dtype = 'bool'
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = 'float32'
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = 'int32'
else:
raise NotImplementedError('Cannot automatically inference the type.'
' value={}'.format(value))
return dtype
def const(value, dtype=None):
"""Construct a constant value for a given type.
"""construct a constant
Parameters
----------
value : int or float
The input value
value : number
The content of the constant number.
dtype : str or None, optional
The data type.
Returns
-------
expr : Expr
Constant expression corresponds to the value.
const_val: tvm.Expr
The result expression.
"""
if dtype is None:
dtype = _scalar_type_inference(value)
if dtype == "uint64" and value >= (1 << 63):
return _api_internal._LargeUIntImm(
dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype)
_set_class_object_generic(ObjectGeneric, convert_to_object)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""Packed Function namespace."""
import ctypes
from .base import _LIB, check_call, c_str, string_types, _FFI_MODE
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _set_class_packed_func, _set_class_module
from ._cy3.core import PackedFuncBase
from ._cy3.core import convert_to_tvm_func
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position
from ._ctypes.packed_func import _set_class_packed_func, _set_class_module
from ._ctypes.packed_func import PackedFuncBase
from ._ctypes.packed_func import convert_to_tvm_func
PackedFuncHandle = ctypes.c_void_p
class PackedFunc(PackedFuncBase):
"""The PackedFunc object used in TVM.
Function plays an key role to bridge front and backend in TVM.
Function provide a type-erased interface, you can call function with positional arguments.
The compiled module returns Function.
TVM backend also registers and exposes its API as Functions.
For example, the developer function exposed in tvm.ir_pass are actually
C++ functions that are registered as PackedFunc
The following are list of common usage scenario of tvm.Function.
- Automatic exposure of C++ API into python
- To call PackedFunc from python side
- To call python callbacks to inspect results in generated code
- Bring python hook into C++ backend
See Also
--------
tvm.register_func: How to register global function.
tvm.get_global_func: How to get global function.
"""
_set_class_packed_func(PackedFunc)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-import
"""FFI registry to register function and objects."""
import sys
import ctypes
from .. import _api_internal
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
try:
# pylint: disable=wrong-import-position,unused-import
if _FFI_MODE == "ctypes":
raise ImportError()
from ._cy3.core import _register_object
from ._cy3.core import _reg_extension
from ._cy3.core import convert_to_tvm_func, _get_global_func, PackedFuncBase
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from ._ctypes.object import _register_object
from ._ctypes.ndarray import _reg_extension
from ._ctypes.packed_func import convert_to_tvm_func, _get_global_func, PackedFuncBase
def register_object(type_key=None):
"""register object type.
Parameters
----------
type_key : str or cls
The type key of the node
Examples
--------
The following code registers MyObject
using type key "test.MyObject"
.. code-block:: python
@tvm.register_object("test.MyObject")
class MyObject(Object):
pass
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
if hasattr(cls, "_type_index"):
tindex = cls._type_index
else:
tidx = ctypes.c_uint()
if not _RUNTIME_ONLY:
check_call(_LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx)))
else:
# directly skip unknown objects during runtime.
ret = _LIB.TVMObjectTypeKey2Index(
c_str(object_name), ctypes.byref(tidx))
if ret != 0:
return cls
tindex = tidx.value
_register_object(tindex, cls)
return cls
if isinstance(type_key, str):
return register
return register(type_key)
def register_extension(cls, fcreate=None):
"""Register a extension class to TVM.
After the class is registered, the class will be able
to directly pass as Function argument generated by TVM.
Parameters
----------
cls : class
The class object to be registered as extension.
fcreate : function, optional
The creation function to create a class object given handle value.
Note
----
The registered class is requires one property: _tvm_handle.
If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_code.
Otherwise, it is required to have a class attribute _tvm_tcode.
- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` or ```_array_type_code``` gives integer represents type
code of the class.
Returns
-------
cls : class
The class being registered.
Example
-------
The following code registers user defined class
MyTensor to be DLTensor compatible.
.. code-block:: python
@tvm.register_extension
class MyTensor(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE
def __init__(self):
self.handle = _LIB.NewDLTensor()
@property
def _tvm_handle(self):
return self.handle.value
"""
assert hasattr(cls, "_tvm_tcode")
if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN:
raise ValueError("Cannot register create when extension tcode is same as buildin")
_reg_extension(cls, fcreate)
return cls
def register_func(func_name, f=None, override=False):
"""Register global function
Parameters
----------
func_name : str or function
The function name
f : function, optional
The function to be registered.
override: boolean optional
Whether override existing entry.
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
The following code registers my_packed_func as global function.
Note that we simply get it back from global function table to invoke
it from python side. However, we can also invoke the same function
from C++ backend, or in the compiled TVM code.
.. code-block:: python
targs = (10, 10.0, "hello")
@tvm.register_func
def my_packed_func(*args):
assert(tuple(args) == targs)
return 10
# Get it out from global function table
f = tvm.get_global_func("my_packed_func")
assert isinstance(f, tvm.PackedFunc)
y = f(*targs)
assert y == 10
"""
if callable(func_name):
f = func_name
func_name = f.__name__
if not isinstance(func_name, str):
raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
if not isinstance(myf, PackedFuncBase):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(
c_str(func_name), myf.handle, ioverride))
return myf
if f:
return register(f)
return register
def get_global_func(name, allow_missing=False):
"""Get a global function by name
Parameters
----------
name : str
The name of the global function
allow_missing : bool
Whether allow missing function or raise an error.
Returns
-------
func : PackedFunc
The function to be returned, None if function is missing.
"""
return _get_global_func(name, allow_missing)
def list_global_func_names():
"""Get list of global functions registered.
Returns
-------
names : list
List of global functions names.
"""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
ctypes.byref(plist)))
fnames = []
for i in range(size.value):
fnames.append(py_str(plist[i]))
return fnames
def extract_ext_funcs(finit):
"""
Extract the extension PackedFuncs from a C module.
Parameters
----------
finit : ctypes function
a ctypes that takes signature of TVMExtensionDeclarer
Returns
-------
fdict : dict of str to Function
The extracted functions
"""
fdict = {}
def _list(name, func):
fdict[name] = func
myf = convert_to_tvm_func(_list)
ret = finit(myf.handle)
_ = myf
if ret != 0:
raise RuntimeError("cannot initialize with %s" % finit)
return fdict
def _get_api(f):
flocal = f
flocal.is_global = True
return flocal
def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
namespace : str
The namespace of the source registry
target_module_name : str
The target module name if different from namespace
"""
target_module_name = (
target_module_name if target_module_name else namespace)
if namespace.startswith("tvm."):
_init_api_prefix(target_module_name, namespace[4:])
else:
_init_api_prefix(target_module_name, namespace)
def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
if prefix == "api":
fname = name
if name.startswith("_"):
target_module = sys.modules["tvm._api_internal"]
else:
target_module = module
else:
if not name.startswith(prefix):
continue
fname = name[len(prefix)+1:]
target_module = module
if fname.find(".") != -1:
continue
f = get_global_func(name)
ff = _get_api(f)
ff.__name__ = fname
ff.__doc__ = ("TVM PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff)
......@@ -16,8 +16,6 @@
# under the License.
"""Common runtime ctypes."""
# pylint: disable=invalid-name
from __future__ import absolute_import
import ctypes
import json
import numpy as np
......
......@@ -16,17 +16,13 @@
# under the License.
"""Functions defined in TVM."""
# pylint: disable=invalid-name,unused-import,redefined-builtin
from __future__ import absolute_import as _abs
from numbers import Integral as _Integral
import tvm._ffi
from ._ffi.base import string_types, TVMError
from ._ffi.object import register_object, Object
from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.object_generic import _scalar_type_inference
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.object_generic import convert, const
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
......@@ -75,30 +71,6 @@ def max_value(dtype):
return _api_internal._max_value(dtype)
def const(value, dtype=None):
"""construct a constant
Parameters
----------
value : number
The content of the constant number.
dtype : str or None, optional
The data type.
Returns
-------
const_val: tvm.Expr
The result expression.
"""
if dtype is None:
dtype = _scalar_type_inference(value)
if dtype == "uint64" and value >= (1 << 63):
return _api_internal._LargeUIntImm(
dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype)
def get_env_func(name):
"""Get an EnvFunc by a global name.
......@@ -121,27 +93,6 @@ def get_env_func(name):
return _api_internal._EnvFuncGet(name)
def convert(value):
"""Convert value to TVM node or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (Function, Object)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_object(value)
def load_json(json_str):
"""Load tvm object from json_str.
......@@ -1073,10 +1024,9 @@ def floormod(a, b):
"""
return _make._OpFloorMod(a, b)
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')
tvm._ffi._init_api("tvm.api")
......@@ -16,9 +16,9 @@
# under the License.
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
import tvm._ffi
from ._ffi.object import Object, register_object
from ._ffi.function import _init_api
from ._ffi.object import Object
from . import _api_internal
class IntSet(Object):
......@@ -32,7 +32,7 @@ class IntSet(Object):
return _api_internal._IntSetIsEverything(self)
@register_object("arith.IntervalSet")
@tvm._ffi.register_object("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value]
......@@ -49,7 +49,7 @@ class IntervalSet(IntSet):
_make_IntervalSet, min_value, max_value)
@register_object("arith.ModularSet")
@tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base):
......@@ -57,7 +57,7 @@ class ModularSet(Object):
_make_ModularSet, coeff, base)
@register_object("arith.ConstIntBound")
@tvm._ffi.register_object("arith.ConstIntBound")
class ConstIntBound(Object):
"""Represent constant integer bound
......@@ -258,4 +258,4 @@ class Analyzer:
"Do not know how to handle type {}".format(type(info)))
_init_api("tvm.arith")
tvm._ffi._init_api("tvm.arith")
......@@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
from ._ffi.object import Object, register_object
from ._ffi.function import _init_api
import tvm._ffi
from ._ffi.object import Object
from . import _api_internal
@register_object
@tvm._ffi.register_object
class Attrs(Object):
"""Attribute node, which is mainly use for defining attributes of relay operators.
......@@ -92,4 +93,4 @@ class Attrs(Object):
return self.__getattr__(item)
_init_api("tvm.attrs")
tvm._ffi._init_api("tvm.attrs")
......@@ -19,11 +19,10 @@
This module provides the functions to transform schedule to
LoweredFunc and compiled Module.
"""
from __future__ import absolute_import as _abs
import warnings
import tvm._ffi
from ._ffi.function import Function
from ._ffi.object import Object, register_object
from ._ffi.object import Object
from . import api
from . import _api_internal
from . import tensor
......@@ -115,7 +114,7 @@ class DumpIR(object):
DumpIR.scope_level -= 1
@register_object
@tvm._ffi.register_object
class BuildConfig(Object):
"""Configuration scope to set a build config option.
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Code generation related functions."""
from ._ffi.function import _init_api
import tvm._ffi
def build_module(lowered_func, target):
"""Build lowered_func into Module.
......@@ -35,4 +35,4 @@ def build_module(lowered_func, target):
"""
return _Build(lowered_func, target)
_init_api("tvm.codegen")
tvm._ffi._init_api("tvm.codegen")
......@@ -15,13 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm import ndarray as _nd
from . import _api_internal
from ._ffi.object import Object, register_object, getitem_helper
from ._ffi.function import _init_api
from ._ffi.object import Object, getitem_helper
@register_object
@tvm._ffi.register_object
class Array(Object):
"""Array container of TVM.
......@@ -52,7 +53,7 @@ class Array(Object):
return _api_internal._ArraySize(self)
@register_object
@tvm._ffi.register_object
class EnvFunc(Object):
"""Environment function.
......@@ -66,7 +67,7 @@ class EnvFunc(Object):
return _api_internal._EnvFuncGetPackedFunc(self)
@register_object
@tvm._ffi.register_object
class Map(Object):
"""Map container of TVM.
......@@ -89,7 +90,7 @@ class Map(Object):
return _api_internal._MapSize(self)
@register_object
@tvm._ffi.register_object
class StrMap(Map):
"""A special map container that has str as key.
......@@ -101,7 +102,7 @@ class StrMap(Map):
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_object
@tvm._ffi.register_object
class Range(Object):
"""Represent a range in TVM.
......@@ -110,7 +111,7 @@ class Range(Object):
"""
@register_object
@tvm._ffi.register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
......@@ -118,7 +119,7 @@ class LoweredFunc(Object):
DeviceFunc = 2
@register_object("vm.ADT")
@tvm._ffi.register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
......@@ -168,4 +169,4 @@ def tuple_object(fields=None):
return _Tuple(*fields)
_init_api("tvm.container")
tvm._ffi._init_api("tvm.container")
......@@ -19,8 +19,9 @@
import os
import tempfile
import shutil
import tvm._ffi
from tvm._ffi.base import string_types
from tvm._ffi.function import get_global_func
from tvm.contrib import graph_runtime
from tvm.ndarray import array
from . import debug_result
......@@ -64,7 +65,7 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.create")
else:
fcreate = get_global_func("tvm.graph_runtime_debug.create")
fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_debug.create")
except ValueError:
raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
......
......@@ -16,9 +16,9 @@
# under the License.
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
import numpy as np
import tvm._ffi
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base
......@@ -54,7 +54,7 @@ def create(graph_json_str, libmod, ctx):
if num_rpc_ctx == len(ctx):
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
else:
fcreate = get_global_func("tvm.graph_runtime.create")
fcreate = tvm._ffi.get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
......
......@@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to NNPACK libraries."""
from __future__ import absolute_import as _abs
import tvm._ffi
from .. import api as _api
from .. import intrin as _intrin
from .._ffi.function import _init_api
def is_available():
"""Check whether NNPACK is available, that is, `nnp_initialize()`
......@@ -202,4 +202,4 @@ def convolution_inference_weight_transform(
"tvm.contrib.nnpack.convolution_inference_weight_transform",
ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
_init_api("tvm.contrib.nnpack")
tvm._ffi._init_api("tvm.contrib.nnpack")
......@@ -15,11 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""External function interface to random library."""
from __future__ import absolute_import as _abs
import tvm._ffi
from .. import api as _api
from .. import intrin as _intrin
from .._ffi.function import _init_api
def randint(low, high, size, dtype='int32'):
......@@ -96,4 +95,4 @@ def normal(loc, scale, size):
"tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32')
_init_api("tvm.contrib.random")
tvm._ffi._init_api("tvm.contrib.random")
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""TFLite runtime that load and run tflite models."""
from .._ffi.function import get_global_func
import tvm._ffi
from ..rpc import base as rpc_base
def create(tflite_model_bytes, ctx, runtime_target='cpu'):
......@@ -44,7 +44,7 @@ def create(tflite_model_bytes, ctx, runtime_target='cpu'):
if device_type >= rpc_base.RPC_SESS_MASK:
fcreate = ctx._rpc_sess.get_function(runtime_func)
else:
fcreate = get_global_func(runtime_func)
fcreate = tvm._ffi.get_global_func(runtime_func)
return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
......
......@@ -15,9 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Custom datatype functionality"""
from __future__ import absolute_import as _abs
import tvm._ffi
from ._ffi.function import register_func as _register_func
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
......@@ -111,7 +110,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
else:
lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \
+ type_name
_register_func(lower_func_name, lower_func)
tvm._ffi.register_func(lower_func_name, lower_func)
def create_lower_func(extern_func_name):
......
......@@ -32,7 +32,10 @@ For example, you can use addexp.a to get the left operand of an Add node.
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object, ObjectGeneric
import tvm._ffi
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
......@@ -261,7 +264,7 @@ class CmpExpr(PrimExpr):
class LogicalExpr(PrimExpr):
pass
@register_object("Variable")
@tvm._ffi.register_object("Variable")
class Var(PrimExpr):
"""Symbolic variable.
......@@ -278,7 +281,7 @@ class Var(PrimExpr):
_api_internal._Var, name, dtype)
@register_object
@tvm._ffi.register_object
class SizeVar(Var):
"""Symbolic variable to represent a tensor index size
which is greater or equal to zero
......@@ -297,7 +300,7 @@ class SizeVar(Var):
_api_internal._SizeVar, name, dtype)
@register_object
@tvm._ffi.register_object
class Reduce(PrimExpr):
"""Reduce node.
......@@ -324,7 +327,7 @@ class Reduce(PrimExpr):
condition, value_index)
@register_object
@tvm._ffi.register_object
class FloatImm(ConstExpr):
"""Float constant.
......@@ -340,7 +343,7 @@ class FloatImm(ConstExpr):
self.__init_handle_by_constructor__(
_make.FloatImm, dtype, value)
@register_object
@tvm._ffi.register_object
class IntImm(ConstExpr):
"""Int constant.
......@@ -360,7 +363,7 @@ class IntImm(ConstExpr):
return self.value
@register_object
@tvm._ffi.register_object
class StringImm(ConstExpr):
"""String constant.
......@@ -384,7 +387,7 @@ class StringImm(ConstExpr):
return self.value != other
@register_object
@tvm._ffi.register_object
class Cast(PrimExpr):
"""Cast expression.
......@@ -401,7 +404,7 @@ class Cast(PrimExpr):
_make.Cast, dtype, value)
@register_object
@tvm._ffi.register_object
class Add(BinaryOpExpr):
"""Add node.
......@@ -418,7 +421,7 @@ class Add(BinaryOpExpr):
_make.Add, a, b)
@register_object
@tvm._ffi.register_object
class Sub(BinaryOpExpr):
"""Sub node.
......@@ -435,7 +438,7 @@ class Sub(BinaryOpExpr):
_make.Sub, a, b)
@register_object
@tvm._ffi.register_object
class Mul(BinaryOpExpr):
"""Mul node.
......@@ -452,7 +455,7 @@ class Mul(BinaryOpExpr):
_make.Mul, a, b)
@register_object
@tvm._ffi.register_object
class Div(BinaryOpExpr):
"""Div node.
......@@ -469,7 +472,7 @@ class Div(BinaryOpExpr):
_make.Div, a, b)
@register_object
@tvm._ffi.register_object
class Mod(BinaryOpExpr):
"""Mod node.
......@@ -486,7 +489,7 @@ class Mod(BinaryOpExpr):
_make.Mod, a, b)
@register_object
@tvm._ffi.register_object
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
......@@ -503,7 +506,7 @@ class FloorDiv(BinaryOpExpr):
_make.FloorDiv, a, b)
@register_object
@tvm._ffi.register_object
class FloorMod(BinaryOpExpr):
"""FloorMod node.
......@@ -520,7 +523,7 @@ class FloorMod(BinaryOpExpr):
_make.FloorMod, a, b)
@register_object
@tvm._ffi.register_object
class Min(BinaryOpExpr):
"""Min node.
......@@ -537,7 +540,7 @@ class Min(BinaryOpExpr):
_make.Min, a, b)
@register_object
@tvm._ffi.register_object
class Max(BinaryOpExpr):
"""Max node.
......@@ -554,7 +557,7 @@ class Max(BinaryOpExpr):
_make.Max, a, b)
@register_object
@tvm._ffi.register_object
class EQ(CmpExpr):
"""EQ node.
......@@ -571,7 +574,7 @@ class EQ(CmpExpr):
_make.EQ, a, b)
@register_object
@tvm._ffi.register_object
class NE(CmpExpr):
"""NE node.
......@@ -588,7 +591,7 @@ class NE(CmpExpr):
_make.NE, a, b)
@register_object
@tvm._ffi.register_object
class LT(CmpExpr):
"""LT node.
......@@ -605,7 +608,7 @@ class LT(CmpExpr):
_make.LT, a, b)
@register_object
@tvm._ffi.register_object
class LE(CmpExpr):
"""LE node.
......@@ -622,7 +625,7 @@ class LE(CmpExpr):
_make.LE, a, b)
@register_object
@tvm._ffi.register_object
class GT(CmpExpr):
"""GT node.
......@@ -639,7 +642,7 @@ class GT(CmpExpr):
_make.GT, a, b)
@register_object
@tvm._ffi.register_object
class GE(CmpExpr):
"""GE node.
......@@ -656,7 +659,7 @@ class GE(CmpExpr):
_make.GE, a, b)
@register_object
@tvm._ffi.register_object
class And(LogicalExpr):
"""And node.
......@@ -673,7 +676,7 @@ class And(LogicalExpr):
_make.And, a, b)
@register_object
@tvm._ffi.register_object
class Or(LogicalExpr):
"""Or node.
......@@ -690,7 +693,7 @@ class Or(LogicalExpr):
_make.Or, a, b)
@register_object
@tvm._ffi.register_object
class Not(LogicalExpr):
"""Not node.
......@@ -704,7 +707,7 @@ class Not(LogicalExpr):
_make.Not, a)
@register_object
@tvm._ffi.register_object
class Select(PrimExpr):
"""Select node.
......@@ -732,7 +735,7 @@ class Select(PrimExpr):
_make.Select, condition, true_value, false_value)
@register_object
@tvm._ffi.register_object
class Load(PrimExpr):
"""Load node.
......@@ -755,7 +758,7 @@ class Load(PrimExpr):
_make.Load, dtype, buffer_var, index, predicate)
@register_object
@tvm._ffi.register_object
class Ramp(PrimExpr):
"""Ramp node.
......@@ -775,7 +778,7 @@ class Ramp(PrimExpr):
_make.Ramp, base, stride, lanes)
@register_object
@tvm._ffi.register_object
class Broadcast(PrimExpr):
"""Broadcast node.
......@@ -792,7 +795,7 @@ class Broadcast(PrimExpr):
_make.Broadcast, value, lanes)
@register_object
@tvm._ffi.register_object
class Shuffle(PrimExpr):
"""Shuffle node.
......@@ -809,7 +812,7 @@ class Shuffle(PrimExpr):
_make.Shuffle, vectors, indices)
@register_object
@tvm._ffi.register_object
class Call(PrimExpr):
"""Call node.
......@@ -844,7 +847,7 @@ class Call(PrimExpr):
_make.Call, dtype, name, args, call_type, func, value_index)
@register_object
@tvm._ffi.register_object
class Let(PrimExpr):
"""Let node.
......
......@@ -28,13 +28,10 @@ HalideIR.
# TODO(@were): Make this module more complete.
# 1. Support HalideIR dumping to Hybrid Script
# 2. Support multi-level HalideIR
from __future__ import absolute_import as _abs
import inspect
import tvm._ffi
from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body
from .module import HybridModule
......@@ -97,4 +94,4 @@ def build(sch, inputs, outputs, name="hybrid_func"):
return HybridModule(src, name)
_init_api("tvm.hybrid")
tvm._ffi._init_api("tvm.hybrid")
......@@ -16,9 +16,9 @@
# under the License.
"""Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs
import tvm._ffi
import tvm.codegen
from ._ffi.function import register_func as _register_func
from . import make as _make
from .api import convert, const
from .expr import Call as _Call
......@@ -189,7 +189,6 @@ def call_llvm_intrin(dtype, name, *args):
call : Expr
The call expression.
"""
import tvm
llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args)
......@@ -596,7 +595,7 @@ def register_intrin_rule(target, intrin, f=None, override=False):
register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
"""
return _register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
def _rule_float_suffix(op):
......@@ -650,7 +649,7 @@ def _rule_float_direct(op):
return call_pure_extern(op.dtype, op.name, *op.args)
return None
@_register_func("tvm.default_trace_action")
@tvm._ffi.register_func("tvm.default_trace_action")
def _tvm_default_trace_action(*args):
print(list(args))
......
......@@ -24,7 +24,7 @@ from . import make as _make
from . import ir_pass as _pass
from . import container as _container
from ._ffi.base import string_types
from ._ffi.object import ObjectGeneric
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call
......
......@@ -23,6 +23,6 @@ Each api is a PackedFunc that can be called in a positional argument manner.
You can read "include/tvm/tir/ir_pass.h" for the function signature and
"src/api/api_pass.cc" for the PackedFunc's body of these functions.
"""
from ._ffi.function import _init_api
import tvm._ffi
_init_api("tvm.ir_pass")
tvm._ffi._init_api("tvm.ir_pass")
......@@ -22,8 +22,7 @@ The functions are automatically exported from C++ side via PackedFunc.
Each api is a PackedFunc that can be called in a positional argument manner.
You can use make function to build the IR node.
"""
from __future__ import absolute_import as _abs
from ._ffi.function import _init_api
import tvm._ffi
def range_by_min_extent(min_value, extent):
......@@ -85,4 +84,4 @@ def node(type_key, **kwargs):
return _Node(*args)
_init_api("tvm.make")
tvm._ffi._init_api("tvm.make")
......@@ -23,9 +23,11 @@ import sys
from enum import Enum
import tvm
import tvm._ffi
from tvm.contrib import util as _util
from tvm.contrib import cc as _cc
from .._ffi.function import _init_api
class LibType(Enum):
"""Enumeration of library types that can be compiled and loaded onto a device"""
......@@ -222,4 +224,4 @@ def get_micro_device_dir():
return micro_device_dir
_init_api("tvm.micro", "tvm.micro.base")
tvm._ffi._init_api("tvm.micro", "tvm.micro.base")
......@@ -19,9 +19,9 @@ from __future__ import absolute_import as _abs
import struct
from collections import namedtuple
import tvm._ffi
from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api
from ._ffi.module import ModuleBase, _set_class_module
from ._ffi.libinfo import find_include_path
from .contrib import cc as _cc, tar as _tar, util as _util
......@@ -333,5 +333,5 @@ def enabled(target):
return _Enabled(target)
_init_api("tvm.module")
tvm._ffi._init_api("tvm.module")
_set_class_module(Module)
......@@ -20,17 +20,15 @@ tvm.ndarray provides a minimum runtime array API to test
the correctness of the program.
"""
# pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs
import tvm._ffi
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension
from ._ffi.object import register_object
@register_object
@tvm._ffi.register_object
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Node is the base class of all TVM AST.
Normally user do not need to touch this api.
"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the passes for Relay program analysis."""
import tvm._ffi
from tvm._ffi.function import _init_api
_init_api("relay._analysis", __name__)
tvm._ffi._init_api("relay._analysis", __name__)
......@@ -16,6 +16,6 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
import tvm._ffi
_init_api("relay._base", __name__)
tvm._ffi._init_api("relay._base", __name__)
......@@ -16,6 +16,6 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++."""
from tvm._ffi.function import _init_api
import tvm._ffi
_init_api("relay.build_module", __name__)
tvm._ffi._init_api("relay.build_module", __name__)
......@@ -16,6 +16,6 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
import tvm._ffi
_init_api("relay._expr", __name__)
tvm._ffi._init_api("relay._expr", __name__)
......@@ -20,6 +20,6 @@ The constructors for all Relay AST nodes exposed from C++.
This module includes MyPy type signatures for all of the
exposed modules.
"""
from .._ffi.function import _init_api
import tvm._ffi
_init_api("relay._make", __name__)
tvm._ffi._init_api("relay._make", __name__)
......@@ -16,6 +16,6 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Module exposed from C++."""
from tvm._ffi.function import _init_api
import tvm._ffi
_init_api("relay._module", __name__)
tvm._ffi._init_api("relay._module", __name__)
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the Relay type inference and checking."""
import tvm._ffi
from tvm._ffi.function import _init_api
_init_api("relay._transform", __name__)
tvm._ffi._init_api("relay._transform", __name__)
......@@ -15,14 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""The interface of expr function exposed from C++."""
from __future__ import absolute_import
import tvm._ffi
from ... import build_module as _build
from ... import container as _container
from ..._ffi.function import _init_api, register_func
@register_func("relay.backend.lower")
@tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
"""Backend function for lowering.
......@@ -61,7 +60,7 @@ def lower(sch, inputs, func_name, source_func):
f, (_container.Array, tuple, list)) else [f]
@register_func("relay.backend.build")
@tvm._ffi.register_func("relay.backend.build")
def build(funcs, target, target_host=None):
"""Backend build function.
......@@ -88,14 +87,14 @@ def build(funcs, target, target_host=None):
return _build.build(funcs, target=target, target_host=target_host)
@register_func("relay._tensor_value_repr")
@tvm._ffi.register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue):
return str(tvalue.data.asnumpy())
@register_func("relay._constant_repr")
@tvm._ffi.register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
return str(tvalue.data.asnumpy())
_init_api("relay.backend", __name__)
tvm._ffi._init_api("relay.backend", __name__)
......@@ -16,6 +16,6 @@
# under the License.
"""The Relay virtual machine FFI namespace.
"""
from tvm._ffi.function import _init_api
import tvm._ffi
_init_api("relay._vm", __name__)
tvm._ffi._init_api("relay._vm", __name__)
......@@ -25,7 +25,7 @@ import numpy as np
import tvm
import tvm.ndarray as _nd
from tvm import autotvm, container
from tvm.object import Object
from tvm._ffi.object import Object
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
......
......@@ -16,8 +16,8 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language."""
from __future__ import absolute_import as _abs
from .._ffi.object import register_object as _register_tvm_node
import tvm._ffi
from .._ffi.object import Object
from . import _make
from . import _expr
......@@ -34,9 +34,9 @@ def register_relay_node(type_key=None):
The type key of the node.
"""
if not isinstance(type_key, str):
return _register_tvm_node(
return tvm._ffi.register_object(
"relay." + type_key.__name__)(type_key)
return _register_tvm_node(type_key)
return tvm._ffi.register_object(type_key)
def register_relay_attr_node(type_key=None):
......@@ -48,9 +48,9 @@ def register_relay_attr_node(type_key=None):
The type key of the node.
"""
if not isinstance(type_key, str):
return _register_tvm_node(
return tvm._ffi.register_object(
"relay.attrs." + type_key.__name__)(type_key)
return _register_tvm_node(type_key)
return tvm._ffi.register_object(type_key)
class RelayNode(Object):
......
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ..._ffi.function import _init_api
import tvm._ffi
_init_api("relay.op._make", __name__)
tvm._ffi._init_api("relay.op._make", __name__)
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
import tvm._ffi
_init_api("relay.op.annotation._make", __name__)
tvm._ffi._init_api("relay.op.annotation._make", __name__)
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
import tvm._ffi
_init_api("relay.op.contrib._make", __name__)
tvm._ffi._init_api("relay.op.contrib._make", __name__)
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
import tvm._ffi
_init_api("relay.op.image._make", __name__)
tvm._ffi._init_api("relay.op.image._make", __name__)
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
import tvm._ffi
_init_api("relay.op.memory._make", __name__)
tvm._ffi._init_api("relay.op.memory._make", __name__)
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
import tvm._ffi
_init_api("relay.op.nn._make", __name__)
tvm._ffi._init_api("relay.op.nn._make", __name__)
......@@ -17,8 +17,7 @@
#pylint: disable=unused-argument
"""The base node types for the Relay language."""
import topi
from ..._ffi.function import _init_api
import tvm._ffi
from ..base import register_relay_node
from ..expr import Expr
......@@ -283,8 +282,6 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level)
_init_api("relay.op", __name__)
@register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
......@@ -320,3 +317,5 @@ def debug(expr, debug_func=None):
name = ''
return _make.debug(expr, name)
tvm._ffi._init_api("relay.op", __name__)
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
import tvm._ffi
_init_api("relay.op.vision._make", __name__)
tvm._ffi._init_api("relay.op.vision._make", __name__)
......@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api
import tvm._ffi
_init_api("relay.qnn.op._make", __name__)
tvm._ffi._init_api("relay.qnn.op._make", __name__)
......@@ -16,11 +16,10 @@
# under the License.
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
import warnings
import topi
from ..._ffi.function import register_func
import tvm._ffi
from .. import expr as _expr
from .. import analysis as _analysis
from .. import op as _op
......@@ -144,7 +143,8 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
qctx.qnode_map[key] = qnode
return qnode
register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
tvm._ffi.register_func(
"relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
@register_annotate_function("nn.contrib_conv2d_NCHWc")
......
......@@ -16,7 +16,6 @@
# under the License.
#pylint: disable=unused-argument
"""Internal module for quantization."""
from __future__ import absolute_import
from tvm._ffi.function import _init_api
import tvm._ffi
_init_api("relay._quantize", __name__)
tvm._ffi._init_api("relay._quantize", __name__)
......@@ -26,8 +26,8 @@ import errno
import struct
import random
import logging
import tvm._ffi
from .._ffi.function import _init_api
from .._ffi.base import py_str
# Magic header for RPC data plane
......@@ -179,4 +179,4 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
# Still use tvm.rpc for the foreign functions
_init_api("tvm.rpc", "tvm.rpc.base")
tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base")
......@@ -21,11 +21,11 @@ import os
import socket
import struct
import time
import tvm._ffi
from . import base
from ..contrib import util
from .._ffi.base import TVMError
from .._ffi import function
from .._ffi import ndarray as nd
from ..module import load as _load_module
......@@ -185,7 +185,7 @@ class LocalSession(RPCSession):
def __init__(self):
# pylint: disable=super-init-not-called
self.context = nd.context
self.get_function = function.get_global_func
self.get_function = tvm._ffi.get_global_func
self._temp = util.tempdir()
def upload(self, data, target=None):
......
......@@ -25,9 +25,6 @@ Server is TCP based with the following protocol:
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
from __future__ import absolute_import
import os
import ctypes
import socket
......@@ -39,8 +36,8 @@ import subprocess
import time
import sys
import signal
import tvm._ffi
from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
......@@ -58,11 +55,11 @@ def _server_env(load_library, work_path=None):
temp = util.tempdir()
# pylint: disable=unused-variable
@register_func("tvm.rpc.server.workpath")
@tvm._ffi.register_func("tvm.rpc.server.workpath")
def get_workpath(path):
return temp.relpath(path)
@register_func("tvm.rpc.server.load_module", override=True)
@tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
......
......@@ -15,38 +15,19 @@
# specific language governing permissions and limitations
# under the License.
"""The computation schedule api of TVM."""
from __future__ import absolute_import as _abs
import tvm._ffi
from ._ffi.base import string_types
from ._ffi.object import Object, register_object
from ._ffi.object import convert_to_object as _convert_to_object
from ._ffi.function import _init_api, Function
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.object import Object
from ._ffi.object_generic import convert
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import container as _container
def convert(value):
"""Convert value to TVM object or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (Function, Object)):
return value
if callable(value):
return _convert_tvm_func(value)
return _convert_to_object(value)
@register_object
@tvm._ffi.register_object
class Buffer(Object):
"""Symbolic data buffer in TVM.
......@@ -156,22 +137,22 @@ class Buffer(Object):
return _api_internal._BufferVStore(self, begin, value)
@register_object
@tvm._ffi.register_object
class Split(Object):
"""Split operation on axis."""
@register_object
@tvm._ffi.register_object
class Fuse(Object):
"""Fuse operation on axis."""
@register_object
@tvm._ffi.register_object
class Singleton(Object):
"""Singleton axis."""
@register_object
@tvm._ffi.register_object
class IterVar(Object, _expr.ExprOp):
"""Represent iteration variable.
......@@ -214,7 +195,7 @@ def create_schedule(ops):
return _api_internal._CreateSchedule(ops)
@register_object
@tvm._ffi.register_object
class Schedule(Object):
"""Schedule for all the stages."""
def __getitem__(self, k):
......@@ -348,7 +329,7 @@ class Schedule(Object):
return factored[0] if len(factored) == 1 else factored
@register_object
@tvm._ffi.register_object
class Stage(Object):
"""A Stage represents schedule for one operation."""
def split(self, parent, factor=None, nparts=None):
......@@ -670,4 +651,4 @@ class Stage(Object):
"""
_api_internal._StageOpenGL(self)
_init_api("tvm.schedule")
tvm._ffi._init_api("tvm.schedule")
......@@ -29,15 +29,15 @@ Each statement node have subfields that can be visited from python side.
assert isinstance(st, tvm.stmt.Store)
assert(st.buffer_var == a)
"""
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object
import tvm._ffi
from ._ffi.object import Object
from . import make as _make
class Stmt(Object):
pass
@register_object
@tvm._ffi.register_object
class LetStmt(Stmt):
"""LetStmt node.
......@@ -57,7 +57,7 @@ class LetStmt(Stmt):
_make.LetStmt, var, value, body)
@register_object
@tvm._ffi.register_object
class AssertStmt(Stmt):
"""AssertStmt node.
......@@ -77,7 +77,7 @@ class AssertStmt(Stmt):
_make.AssertStmt, condition, message, body)
@register_object
@tvm._ffi.register_object
class ProducerConsumer(Stmt):
"""ProducerConsumer node.
......@@ -97,7 +97,7 @@ class ProducerConsumer(Stmt):
_make.ProducerConsumer, func, is_producer, body)
@register_object
@tvm._ffi.register_object
class For(Stmt):
"""For node.
......@@ -137,7 +137,7 @@ class For(Stmt):
for_type, device_api, body)
@register_object
@tvm._ffi.register_object
class Store(Stmt):
"""Store node.
......@@ -160,7 +160,7 @@ class Store(Stmt):
_make.Store, buffer_var, value, index, predicate)
@register_object
@tvm._ffi.register_object
class Provide(Stmt):
"""Provide node.
......@@ -183,7 +183,7 @@ class Provide(Stmt):
_make.Provide, func, value_index, value, args)
@register_object
@tvm._ffi.register_object
class Allocate(Stmt):
"""Allocate node.
......@@ -215,7 +215,7 @@ class Allocate(Stmt):
extents, condition, body)
@register_object
@tvm._ffi.register_object
class AttrStmt(Stmt):
"""AttrStmt node.
......@@ -238,7 +238,7 @@ class AttrStmt(Stmt):
_make.AttrStmt, node, attr_key, value, body)
@register_object
@tvm._ffi.register_object
class Free(Stmt):
"""Free node.
......@@ -252,7 +252,7 @@ class Free(Stmt):
_make.Free, buffer_var)
@register_object
@tvm._ffi.register_object
class Realize(Stmt):
"""Realize node.
......@@ -288,7 +288,7 @@ class Realize(Stmt):
bounds, condition, body)
@register_object
@tvm._ffi.register_object
class SeqStmt(Stmt):
"""Sequence of statements.
......@@ -308,7 +308,7 @@ class SeqStmt(Stmt):
return len(self.seq)
@register_object
@tvm._ffi.register_object
class IfThenElse(Stmt):
"""IfThenElse node.
......@@ -328,7 +328,7 @@ class IfThenElse(Stmt):
_make.IfThenElse, condition, then_case, else_case)
@register_object
@tvm._ffi.register_object
class Evaluate(Stmt):
"""Evaluate node.
......@@ -342,7 +342,7 @@ class Evaluate(Stmt):
_make.Evaluate, value)
@register_object
@tvm._ffi.register_object
class Prefetch(Stmt):
"""Prefetch node.
......
......@@ -54,12 +54,11 @@ The list of options include:
We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific function in this module to create specific targets.
"""
from __future__ import absolute_import
import warnings
import tvm._ffi
from ._ffi.base import _LIB_NAME
from ._ffi.object import Object, register_object
from ._ffi.object import Object
from . import _api_internal
try:
......@@ -80,7 +79,7 @@ def _merge_opts(opts, new_opts):
return opts
@register_object
@tvm._ffi.register_object
class Target(Object):
"""Target device information, use through TVM API.
......@@ -146,7 +145,7 @@ class Target(Object):
_api_internal._ExitTargetScope(self)
@register_object
@tvm._ffi.register_object
class GenericFunc(Object):
"""GenericFunc node reference. This represents a generic function
that may be specialized for different targets. When this object is
......
......@@ -16,9 +16,11 @@
# under the License.
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
from ._ffi.object import Object, register_object, ObjectGeneric, \
convert_to_object
import tvm._ffi
from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric, convert_to_object
from . import _api_internal
from . import make as _make
from . import expr as _expr
......@@ -47,7 +49,7 @@ class TensorSlice(ObjectGeneric, _expr.ExprOp):
"""Data content of the tensor."""
return self.tensor.dtype
@register_object
@tvm._ffi.register_object
class TensorIntrinCall(Object):
"""Intermediate structure for calling a tensor intrinsic."""
......@@ -55,7 +57,7 @@ class TensorIntrinCall(Object):
itervar_cls = None
@register_object
@tvm._ffi.register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
......@@ -157,12 +159,12 @@ class Operation(Object):
return _api_internal._OpInputTensors(self)
@register_object
@tvm._ffi.register_object
class PlaceholderOp(Operation):
"""Placeholder operation."""
@register_object
@tvm._ffi.register_object
class BaseComputeOp(Operation):
"""Compute operation."""
@property
......@@ -176,18 +178,18 @@ class BaseComputeOp(Operation):
return self.__getattr__("reduce_axis")
@register_object
@tvm._ffi.register_object
class ComputeOp(BaseComputeOp):
"""Scalar operation."""
pass
@register_object
@tvm._ffi.register_object
class TensorComputeOp(BaseComputeOp):
"""Tensor operation."""
@register_object
@tvm._ffi.register_object
class ScanOp(Operation):
"""Scan operation."""
@property
......@@ -196,12 +198,12 @@ class ScanOp(Operation):
return self.__getattr__("scan_axis")
@register_object
@tvm._ffi.register_object
class ExternOp(Operation):
"""External operation."""
@register_object
@tvm._ffi.register_object
class HybridOp(Operation):
"""Hybrid operation."""
@property
......@@ -210,7 +212,7 @@ class HybridOp(Operation):
return self.__getattr__("axis")
@register_object
@tvm._ffi.register_object
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
......@@ -270,7 +272,7 @@ class Layout(Object):
return _api_internal._LayoutFactorOf(self, axis)
@register_object
@tvm._ffi.register_object
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
......
......@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Tensor intrinsics"""
from __future__ import absolute_import as _abs
import tvm._ffi
from . import _api_internal
from . import api as _api
from . import expr as _expr
......@@ -24,7 +25,7 @@ from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.object import Object, register_object
from ._ffi.object import Object
def _get_region(tslice):
......@@ -41,7 +42,7 @@ def _get_region(tslice):
region.append(_make.range_by_min_extent(begin, 1))
return region
@register_object
@tvm._ffi.register_object
class TensorIntrin(Object):
"""Tensor intrinsic functions for certain computation.
......
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI for CUDA TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.cuda", "topi.cuda")
tvm._ffi._init_api("topi.cuda", "topi.cpp.cuda")
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI for generic TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.generic", "topi.generic")
tvm._ffi._init_api("topi.generic", "topi.cpp.generic")
......@@ -18,8 +18,8 @@
import sys
import os
import ctypes
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
from tvm._ffi import libinfo
def _get_lib_names():
......@@ -41,4 +41,4 @@ def _load_lib():
_LIB, _LIB_NAME = _load_lib()
_init_api_prefix("topi.cpp", "topi")
tvm._ffi._init_api("topi", "topi.cpp")
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI for NN TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.nn", "topi.nn")
tvm._ffi._init_api("topi.nn", "topi.cpp.nn")
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI for Rocm TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.rocm", "topi.rocm")
tvm._ffi._init_api("topi.rocm", "topi.cpp.rocm")
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI for TOPI utility functions"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.util", "topi.util")
tvm._ffi._init_api("topi.util", "topi.cpp.util")
......@@ -16,9 +16,8 @@
# under the License.
"""FFI for vision TOPI ops and schedules"""
from tvm._ffi.function import _init_api_prefix
import tvm._ffi
from . import yolo
_init_api_prefix("topi.cpp.vision", "topi.vision")
tvm._ffi._init_api("topi.vision", "topi.cpp.vision")
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI for Yolo TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo")
tvm._ffi._init_api("topi.vision.yolo", "topi.cpp.vision.yolo")
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""FFI for x86 TOPI ops and schedules"""
import tvm._ffi
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.x86", "topi.x86")
tvm._ffi._init_api("topi.x86", "topi.cpp.x86")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment