Commit 535a97c9 by Tianqi Chen Committed by GitHub

[PYTHON/FFI] Enable Cython FFI (#106)

* [PYTHON/FFI] Enable Cython FFI

* fix cython
parent 26d91985
...@@ -11,7 +11,7 @@ endif ...@@ -11,7 +11,7 @@ endif
include $(config) include $(config)
# specify tensor path # specify tensor path
.PHONY: clean all test doc pylint cpplint lint verilog .PHONY: clean all test doc pylint cpplint lint verilog cython cython2 cython3
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
...@@ -95,8 +95,6 @@ lib/libtvm_runtime.so: $(RUNTIME_DEP) ...@@ -95,8 +95,6 @@ lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm.a: $(ALL_DEP) lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
ar crv $@ $(filter %.o, $?) ar crv $@ $(filter %.o, $?)
...@@ -117,6 +115,19 @@ lint: cpplint pylint ...@@ -117,6 +115,19 @@ lint: cpplint pylint
doc: doc:
doxygen docs/Doxyfile doxygen docs/Doxyfile
# Cython build
cython:
cd python; python setup.py build_ext --inplace
cython2:
cd python; python2 setup.py build_ext --inplace
cython3:
cd python; python3 setup.py build_ext --inplace
cyclean:
rm -rf python/tvm/*/*/*.so python/tvm/*/*/*.cpp
clean: clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
......
...@@ -39,6 +39,24 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code); ...@@ -39,6 +39,24 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
TVM_DLL int TVMNodeFree(NodeHandle handle); TVM_DLL int TVMNodeFree(NodeHandle handle);
/*! /*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMNodeTypeKey2Index(const char* type_key,
int* out_index);
/*!
* \brief Get runtime type index of the node.
* \param handle the node handle.
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index);
/*!
* \brief get attributes given key * \brief get attributes given key
* \param handle The node handle * \param handle The node handle
* \param key The attribute name * \param key The attribute name
......
build
*.cpp
\ No newline at end of file
...@@ -6,21 +6,74 @@ import os ...@@ -6,21 +6,74 @@ import os
import sys import sys
import setuptools import setuptools
# need to use distutils.core for correct placement of cython dll
if "--inplace" in sys.argv:
from distutils.core import setup
from distutils.extension import Extension
else:
from setuptools import setup
from setuptools.extension import Extension
CURRENT_DIR = os.path.dirname(__file__) CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, 'tvm/libinfo.py') libinfo_py = os.path.join(CURRENT_DIR, 'tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py} libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, 'rb').read(), libinfo_py, 'exec'), libinfo, libinfo) exec(compile(open(libinfo_py, 'rb').read(), libinfo_py, 'exec'), libinfo, libinfo)
LIB_PATH = libinfo['find_lib_path']() LIB_PATH = libinfo['find_lib_path']()
print(LIB_PATH)
__version__ = libinfo['__version__'] __version__ = libinfo['__version__']
def config_cython():
"""Try to configure cython and return cython configuration"""
if os.name == 'nt':
print("WARNING: Cython is not supported on Windows, will compile without cython module")
return []
try:
from Cython.Build import cythonize
# from setuptools.extension import Extension
if sys.version_info >= (3, 0):
subdir = "_cy3"
else:
subdir = "_cy2"
ret = []
path = "tvm/_ffi/_cython"
if os.name == 'nt':
library_dirs = ['tvm', '../build/Release', '../build']
libraries = ['libtvm']
else:
library_dirs = None
libraries = None
for fn in os.listdir(path):
if not fn.endswith(".pyx"):
continue
ret.append(Extension(
"tvm._ffi.%s.%s" % (subdir, fn[:-4]),
["tvm/_ffi/_cython/%s" % fn],
include_dirs=["../include/",
"../dmlc-core/include",
"../dlpack/include",
],
library_dirs=library_dirs,
libraries=libraries,
language="c++"))
return cythonize(ret)
except ImportError:
print("WARNING: Cython is not installed, will compile without cython module")
return []
setuptools.setup( setuptools.setup(
name='tvm', name='tvm',
version=__version__, version=__version__,
description='A domain specific language(DSL) for tensor computations.', description='A domain specific language(DSL) for tensor computations.',
packages=setuptools.find_packages(),
install_requires=[ install_requires=[
'numpy', 'numpy',
], ],
data_files=[('tvm', [LIB_PATH[0]])] zip_safe=False,
) packages=[
'tvm', 'tvm.addon',
'tvm._ffi', 'tvm._ffi._ctypes',
'tvm._ffi._cy2', 'tvm._ffi._cy3'
],
data_files=[('tvm', [LIB_PATH[0]])],
url='https://github.com/tqchen/tvm',
ext_modules=config_cython())
...@@ -19,9 +19,7 @@ from . import ndarray as nd ...@@ -19,9 +19,7 @@ from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, cl, vpi from .ndarray import cpu, gpu, opencl, cl, vpi
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from ._base import TVMError
from ._base import __version__
from .api import * from .api import *
from .intrin import * from .intrin import *
from .node import register_node from .node import register_node
......
"""ctypes support""" """C interfacing code.
This namespace contains everything that interacts with C code.
Most TVM C related object are ctypes compatible, which means
they contains a handle field that is ctypes.c_void_p and can
be used via ctypes function calls.
Some performance critical functions are implemented by cython
and have a ctypes fallback implementation.
"""
"""ctypes specific implementation of FFI"""
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement
"""Function configuration API."""
from __future__ import absolute_import
import ctypes
import traceback
from numbers import Number, Integral
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import TVMType, TVMByteArray, NDArrayBase
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .node import NodeBase
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.nd.Function
The converted tvm function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args))
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
_LIB.TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
_ = temp_args
_ = rv
return 0
handle = FunctionHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
return _CLASS_FUNCTION(handle, False)
def _make_tvm_args(args, temp_args):
"""Pack arguments into c args tvm call accept"""
num_args = len(args)
values = (TVMValue * num_args)()
type_codes = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if isinstance(arg, NodeBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
elif arg is None:
values[i].v_handle = None
type_codes[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = TypeCode.INT
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = TypeCode.BYTES
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
class FunctionBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
is_global : bool
Whether this is a global function in python
"""
self.handle = handle
self.is_global = is_global
def __del__(self):
if not self.is_global:
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
"""Call the function with positional arguments
args : list
The positional arguments to the function call.
"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
def _return_module(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, ModuleHandle):
handle = ModuleHandle(handle)
return _CLASS_MODULE(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
_CLASS_MODULE = None
_CLASS_FUNCTION = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..node_generic import _set_class_node_base
from .types import TVMValue, TypeCode
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
NodeHandle = ctypes.c_void_p
"""Maps node type to its constructor"""
NODE_TYPE = {}
def _register_node(index, cls):
"""register node class"""
NODE_TYPE[index] = cls
def _return_node(x):
"""Return node function"""
handle = x.v_handle
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
tindex = ctypes.c_int()
check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex)))
return NODE_TYPE.get(tindex.value, NodeBase)(handle)
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class NodeBase(object):
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name):
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)
_set_class_node_base(NodeBase)
...@@ -3,10 +3,8 @@ ...@@ -3,10 +3,8 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import ctypes import ctypes
import numpy as np from ..base import py_str, check_call, _LIB
from .._base import py_str, check_call, _LIB from ..ndarray import TVMByteArray
tvm_shape_index_t = ctypes.c_int64
class TypeCode(object): class TypeCode(object):
"""Type code used in API calls""" """Type code used in API calls"""
...@@ -23,66 +21,6 @@ class TypeCode(object): ...@@ -23,66 +21,6 @@ class TypeCode(object):
STR = 10 STR = 10
BYTES = 11 BYTES = 11
def _api_type(code):
"""create a type accepted by API"""
t = TVMType()
t.bits = 64
t.lanes = 1
t.type_code = code
return t
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
def __ne__(self, other):
return not self.__eq__(other)
class TVMValue(ctypes.Union): class TVMValue(ctypes.Union):
"""TVMValue in C API""" """TVMValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64), _fields_ = [("v_int64", ctypes.c_int64),
...@@ -90,11 +28,6 @@ class TVMValue(ctypes.Union): ...@@ -90,11 +28,6 @@ class TVMValue(ctypes.Union):
("v_handle", ctypes.c_void_p), ("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)] ("v_str", ctypes.c_char_p)]
class TVMByteArray(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
TVMPackedCFunc = ctypes.CFUNCTYPE( TVMPackedCFunc = ctypes.CFUNCTYPE(
ctypes.c_int, ctypes.c_int,
......
"""cython2 namespace"""
"""cython3 namespace"""
from ..base import TVMError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from libc.stdint cimport int64_t, uint8_t, uint16_t
import ctypes
cdef enum TVMTypeCode:
kInt = 0
kUInt = 1
kFloat = 2
kHandle = 3
kNull = 4
kArrayHandle = 5
kTVMType = 6
kNodeHandle = 7
kModuleHandle = 8
kFuncHandle = 9
kStr = 10
kBytes = 11
cdef extern from "tvm/runtime/c_runtime_api.h":
struct DLType:
uint8_t code
uint8_t bits
uint16_t lanes
ctypedef struct TVMValue:
int64_t v_int64
double v_float64
void* v_handle
const char* v_str
DLType v_type
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle
ctypedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle)
ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg);
const char *TVMGetLastError();
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code)
int TVMFuncFree(TVMFunctionHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue value,
int type_code)
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
cdef extern from "tvm/c_api.h":
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMNodeFree(NodeHandle handle)
TVMNodeTypeKey2Index(const char* type_key,
int* out_index)
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index)
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* out_value,
int* out_type_code,
int* out_success)
cdef inline py_str(const char* x):
if PY_MAJOR_VERSION < 3:
return x
else:
return x.decode("utf-8")
cdef inline c_str(pystr):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return pystr.encode("utf-8")
cdef inline CALL(int ret):
if ret != 0:
raise TVMError(TVMGetLastError())
cdef inline object ctypes_handle(void* chandle):
"""Cast C handle to ctypes handle."""
return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)
cdef inline void* c_handle(object handle):
"""Cast C types handle to c handle."""
cdef unsigned long long v_ptr
v_ptr = handle.value
return <void*>(v_ptr)
include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import NDArrayBase, TVMType, TVMByteArray
print("TVM: Initializing cython mode...")
cdef void tvm_callback_finalize(void* fhandle):
local_pyfunc = <object>(fhandle)
Py_DECREF(local_pyfunc)
cdef int tvm_callback(TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* fhandle):
cdef list pyargs
cdef TVMValue value
cdef int tcode
local_pyfunc = <object>(fhandle)
pyargs = []
for i in range(num_args):
value = args[i]
tcode = type_codes[i]
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle):
CALL(TVMCbArgToReturn(&value, tcode))
pyargs.append(make_ret(value, tcode))
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
make_arg(rv, &value, &tcode, temp_args)
CALL(TVMCFuncSetReturn(ret, value, tcode))
return 0
def convert_to_tvm_func(object pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.Function
The converted tvm function.
"""
cdef TVMFunctionHandle chandle
Py_INCREF(pyfunc)
CALL(TVMFuncCreateFromCFunc(tvm_callback,
<void*>(pyfunc),
tvm_callback_finalize,
&chandle))
return _CLASS_FUNCTION(ctypes_handle(chandle), False)
cdef inline void make_arg(object arg,
TVMValue* value,
int* tcode,
list temp_args):
"""Pack arguments into c args tvm call accept"""
if isinstance(arg, NodeBase):
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
elif isinstance(arg, NDArrayBase):
value[0].v_handle = c_handle(
ctypes.cast(arg.handle, ctypes.c_void_p))
tcode[0] = kArrayHandle
elif isinstance(arg, Integral):
value[0].v_int64 = arg
tcode[0] = kInt
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif arg is None:
value[0].v_handle = NULL
tcode[0] = kNull
elif isinstance(arg, TVMType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
value[0].v_handle = <void*>(
<unsigned long long>ctypes.addressof(arr))
tcode[0] = kBytes
temp_args.append(arr)
elif isinstance(arg, string_types):
tstr = c_str(arg)
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
elif callable(arg):
arg = convert_to_tvm_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
cdef inline bytearray make_ret_bytes(void* chandle):
handle = ctypes_handle(chandle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
size = arr.size
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
return res
cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kNodeHandle:
return make_ret_node(value.v_handle)
elif tcode == kNull:
return None
elif tcode == kInt:
return value.v_int64
elif tcode == kFloat:
return value.v_float64
elif tcode == kStr:
return py_str(value.v_str)
elif tcode == kBytes:
return make_ret_bytes(value.v_handle)
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle:
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
else:
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall2(void* chandle, tuple args, int nargs):
cdef TVMValue[2] values
cdef int[2] tcodes
cdef TVMValue ret_val
cdef int ret_code
nargs = len(args)
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
cdef inline object FuncCall(void* chandle, tuple args):
cdef int nargs
nargs = len(args)
if nargs <= 2:
return FuncCall2(chandle, args, nargs)
cdef vector[TVMValue] values
cdef vector[int] tcodes
cdef TVMValue ret_val
cdef int ret_code
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
cdef class FunctionBase:
cdef TVMFunctionHandle chandle
cdef int is_global
cdef _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property is_global:
def __get__(self):
return self.c_is_global != 0
def __set__(self, value):
self.c_is_global = value
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle, is_global):
self._set_handle(handle)
self.c_is_global = is_global
def __dealloc__(self):
if self.is_global == 0:
CALL(TVMFuncFree(self.chandle))
def __call__(self, *args):
return FuncCall(self.chandle, args)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
from ..base import string_types
from ..node_generic import _set_class_node_base
"""Maps node type to its constructor"""
NODE_TYPE = []
def _register_node(int index, object cls):
"""register node class"""
while len(NODE_TYPE) <= index:
NODE_TYPE.append(None)
NODE_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle):
global NODE_TYPE
cdef int tindex
cdef list node_type
cdef object cls
node_type = NODE_TYPE
CALL(TVMNodeGetTypeIndex(chandle, &tindex))
if tindex < len(node_type):
cls = node_type[tindex]
if cls is not None:
obj = cls(None)
else:
obj = NodeBase(None)
(<NodeBase>obj).chandle = chandle
return obj
cdef class NodeBase:
cdef void* chandle
cdef _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
ptr = handle.value
self.chandle = <void*>(ptr)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes_handle(self.chandle)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
def __dealloc__(self):
CALL(TVMNodeFree(self.chandle))
def __getattr__(self, name):
cdef TVMValue ret_val
cdef int ret_type_code, ret_succ
CALL(TVMNodeGetAttr(self.chandle, c_str(name),
&ret_val, &ret_type_code, &ret_succ))
if ret_succ == 0:
raise AttributeError(
"'%s' object has no attribute '%s'" % (type(self), name))
return make_ret(ret_val, ret_type_code)
_set_class_node_base(NodeBase)
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
import os
import ctypes import ctypes
import numpy as np import numpy as np
from . import libinfo from . import libinfo
__all__ = ['TVMError']
#---------------------------- #----------------------------
# library loading # library loading
#---------------------------- #----------------------------
...@@ -25,7 +25,7 @@ else: ...@@ -25,7 +25,7 @@ else:
class TVMError(Exception): class TVMError(Exception):
"""Error that will be throwed by all functions""" """Error thrown by TVM function"""
pass pass
def _load_lib(): def _load_lib():
...@@ -40,9 +40,11 @@ def _load_lib(): ...@@ -40,9 +40,11 @@ def _load_lib():
__version__ = libinfo.__version__ __version__ = libinfo.__version__
# library instance of nnvm # library instance of nnvm
_LIB = _load_lib() _LIB = _load_lib()
# The FFI mode of TVM
_FFI_MODE = os.environ.get("TVM_FFI", "auto")
#---------------------------- #----------------------------
# helper function definition # helper function in ctypes.
#---------------------------- #----------------------------
def check_call(ret): def check_call(ret):
"""Check the return value of C API call """Check the return value of C API call
......
# coding: utf-8 # pylint: disable=invalid-name, unused-import
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement """Function namespace."""
"""Function configuration API."""
from __future__ import absolute_import from __future__ import absolute_import
import ctypes
import sys import sys
import traceback import ctypes
from numbers import Number, Integral from .base import _LIB, check_call, py_str, c_str, _FFI_MODE
from .._base import _LIB, check_call IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
from .._base import c_str, py_str, string_types
from .types import TVMValue, TypeCode, TVMType, TVMByteArray try:
from .types import TVMPackedCFunc, TVMCFuncFinalizer # pylint: disable=wrong-import-position
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func if _FFI_MODE == "ctypes":
from .node import NodeBase, NodeGeneric, convert_to_node raise ImportError()
from .ndarray import NDArrayBase if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.nd.Function
The converted tvm function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
_LIB.TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
_ = temp_args
_ = rv
return 0
handle = FunctionHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
return Function(handle, False)
def _make_tvm_args(args, temp_args):
"""Pack arguments into c args tvm call accept"""
num_args = len(args)
values = (TVMValue * num_args)()
type_codes = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if isinstance(arg, NodeBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
elif arg is None:
values[i].v_handle = None
type_codes[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = TypeCode.INT
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = TypeCode.BYTES
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
temp_args.append(arg)
elif isinstance(arg, ModuleBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, Function):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
class Function(_FunctionBase):
class Function(object):
"""The PackedFunc object used in TVM. """The PackedFunc object used in TVM.
Function plays an key role to bridge front and backend in TVM. Function plays an key role to bridge front and backend in TVM.
...@@ -158,42 +51,7 @@ class Function(object): ...@@ -158,42 +51,7 @@ class Function(object):
tvm.register_func: How to register global function. tvm.register_func: How to register global function.
tvm.get_global_func: How to get global function. tvm.get_global_func: How to get global function.
""" """
__slots__ = ["handle", "is_global"] pass
# pylint: disable=no-member
def __init__(self, handle, is_global):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
is_global : bool
Whether this is a global function in python
"""
self.handle = handle
self.is_global = is_global
def __del__(self):
if not self.is_global:
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
"""Call the function with positional arguments
args : list
The positional arguments to the function call.
"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
class ModuleBase(object): class ModuleBase(object):
...@@ -214,9 +72,8 @@ class ModuleBase(object): ...@@ -214,9 +72,8 @@ class ModuleBase(object):
""" """
if self._entry: if self._entry:
return self._entry return self._entry
else: self._entry = self.get_function("__tvm_main__")
self._entry = self.get_function("__tvm_main__") return self._entry
return self._entry
def get_function(self, name, query_imports=False): def get_function(self, name, query_imports=False):
"""Get function from the module. """Get function from the module.
...@@ -273,39 +130,11 @@ class ModuleBase(object): ...@@ -273,39 +130,11 @@ class ModuleBase(object):
raise ValueError("Can only take string as function name") raise ValueError("Can only take string as function name")
return self.get_function(name) return self.get_function(name)
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __call__(self, *args): def __call__(self, *args):
if self._entry: if self._entry:
return self._entry(*args) return self._entry(*args)
else: f = self.entry_func
f = self.entry_func return f(*args)
return f(*args)
_module_cls = None
def _return_module(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, ModuleHandle):
handle = ModuleHandle(handle)
return _module_cls(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return Function(handle, False)
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
def register_func(func_name, f=None, override=False): def register_func(func_name, f=None, override=False):
...@@ -456,8 +285,4 @@ def _init_api(namespace): ...@@ -456,8 +285,4 @@ def _init_api(namespace):
ff.__doc__ = ("TVM PackedFunc %s. " % fname) ff.__doc__ = ("TVM PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff) setattr(target_module, ff.__name__, ff)
_set_class_function(Function)
def _init_module_module(module_class):
"""Initialize the module."""
global _module_cls
_module_cls = module_class
# coding: utf-8 """Library information."""
"""Information about nnvm."""
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
import os import os
...@@ -16,17 +15,17 @@ def find_lib_path(): ...@@ -16,17 +15,17 @@ def find_lib_path():
""" """
use_runtime = os.environ.get("TVM_USE_RUNTIME_LIB", False) use_runtime = os.environ.get("TVM_USE_RUNTIME_LIB", False)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../lib/') api_path = os.path.join(curr_path, '../../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/') cmake_build_path = os.path.join(curr_path, '../../../build/Release/')
dll_path = [curr_path, api_path, cmake_build_path] dll_path = [curr_path, api_path, cmake_build_path]
if os.name == 'nt': if os.name == 'nt':
vs_configuration = 'Release' vs_configuration = 'Release'
if platform.architecture()[0] == '64bit': if platform.architecture()[0] == '64bit':
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration)) dll_path.append(os.path.join(curr_path, '../../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows/x64', vs_configuration)) dll_path.append(os.path.join(curr_path, '../../../windows/x64', vs_configuration))
else: else:
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration)) dll_path.append(os.path.join(curr_path, '../../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows', vs_configuration)) dll_path.append(os.path.join(curr_path, '../../../windows', vs_configuration))
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None): elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
...@@ -40,7 +39,7 @@ def find_lib_path(): ...@@ -40,7 +39,7 @@ def find_lib_path():
dll_path = runtime_dll_path if use_runtime else lib_dll_path dll_path = runtime_dll_path if use_runtime else lib_dll_path
lib_found = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] lib_found = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_found) == 0: if not lib_found:
raise RuntimeError('Cannot find the files.\n' + raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path))) 'List of candidates:\n' + str('\n'.join(dll_path)))
if use_runtime: if use_runtime:
......
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement # pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring # pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API.""" """Runtime NDArray api"""
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array
tvm_shape_index_t = ctypes.c_int64
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
from .._base import _LIB, check_call def __ne__(self, other):
from .._base import c_array return not self.__eq__(other)
from .types import TVMType, tvm_shape_index_t
class TVMContext(ctypes.Structure): class TVMContext(ctypes.Structure):
"""TVM context strucure.""" """TVM context strucure."""
...@@ -104,10 +157,6 @@ def numpyasarray(np_data): ...@@ -104,10 +157,6 @@ def numpyasarray(np_data):
arr.ctx = cpu(0) arr.ctx = cpu(0)
return arr, shape return arr, shape
_ndarray_cls = None
def empty(shape, dtype="float32", ctx=cpu(0)): def empty(shape, dtype="float32", ctx=cpu(0)):
"""Create an empty array given shape and device """Create an empty array given shape and device
...@@ -133,7 +182,7 @@ def empty(shape, dtype="float32", ctx=cpu(0)): ...@@ -133,7 +182,7 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
dtype = TVMType(dtype) dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc( check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle))) shape, ndim, dtype, ctx, ctypes.byref(handle)))
return _ndarray_cls(handle) return _CLASS_NDARRAY(handle)
def sync(ctx): def sync(ctx):
...@@ -253,7 +302,8 @@ class NDArrayBase(object): ...@@ -253,7 +302,8 @@ class NDArrayBase(object):
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
_CLASS_NDARRAY = None
def _init_ndarray_module(ndarray_class): def _set_class_ndarray(cls):
global _ndarray_cls global _CLASS_NDARRAY
_ndarray_cls = ndarray_class _CLASS_NDARRAY = cls
# coding: utf-8 """Node namespace"""
# pylint: disable=invalid-name, protected-access # pylint: disable=unused-import
# pylint: disable=no-member, missing-docstring
"""Symbolic configuration API."""
from __future__ import absolute_import from __future__ import absolute_import
import ctypes import ctypes
from numbers import Number, Integral import sys
from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from .. import _api_internal from .. import _api_internal
from .types import TVMValue, TypeCode from .node_generic import NodeGeneric, convert_to_node, const
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .base import _LIB, check_call, c_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
NodeHandle = ctypes.c_void_p try:
# pylint: disable=wrong-import-position
"""Maps node type to its constructor""" if _FFI_MODE == "ctypes":
NODE_TYPE = { raise ImportError()
} if sys.version_info >= (3, 0):
from ._cy3.core import _register_node, NodeBase as _NodeBase
def _return_node(x): else:
"""Return node function""" from ._cy2.core import _register_node, NodeBase as _NodeBase
handle = x.v_handle except IMPORT_EXCEPT:
if not isinstance(handle, NodeHandle): # pylint: disable=wrong-import-position
handle = NodeHandle(handle) from ._ctypes.node import _register_node, NodeBase as _NodeBase
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
raise NotImplementedError()
class NodeBase(object): class NodeBase(_NodeBase):
"""NodeBase is the base class of all TVM language AST object.""" """NodeBase is the base class of all TVM language AST object."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __repr__(self): def __repr__(self):
return _api_internal._format_str(self) return _api_internal._format_str(self)
def __del__(self): def __dir__(self):
check_call(_LIB.TVMNodeFree(self.handle)) plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
def __getattr__(self, name): check_call(_LIB.TVMNodeListAttrNames(
ret_val = TVMValue() self.handle, ctypes.byref(size), ctypes.byref(plist)))
ret_type_code = ctypes.c_int() names = []
ret_success = ctypes.c_int() for i in range(size.value):
check_call(_LIB.TVMNodeGetAttr( names.append(py_str(plist[i]))
self.handle, c_str(name), return names
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)
def __hash__(self): def __hash__(self):
return _api_internal._raw_ptr(self) return _api_internal._raw_ptr(self)
...@@ -93,16 +47,6 @@ class NodeBase(object): ...@@ -93,16 +47,6 @@ class NodeBase(object):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMNodeListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
return names
def __reduce__(self): def __reduce__(self):
return (type(self), (None,), self.__getstate__()) return (type(self), (None,), self.__getstate__())
...@@ -110,8 +54,7 @@ class NodeBase(object): ...@@ -110,8 +54,7 @@ class NodeBase(object):
handle = self.handle handle = self.handle
if handle is not None: if handle is not None:
return {'handle': _api_internal._save_json(self)} return {'handle': _api_internal._save_json(self)}
else: return {'handle': None}
return {'handle': None}
def __setstate__(self, state): def __setstate__(self, state):
# pylint: disable=assigning-non-slot # pylint: disable=assigning-non-slot
...@@ -125,66 +68,6 @@ class NodeBase(object): ...@@ -125,66 +68,6 @@ class NodeBase(object):
self.handle = None self.handle = None
def const(value, dtype=None):
"""Construct a constant value for a given type.
Parameters
----------
value : int or float
The input value
dtype : str
The data type.
Returns
-------
expr : Expr
Constant expression corresponds to the value.
"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _api_internal._const(value, dtype)
def convert_to_node(value):
"""Convert a python value to corresponding node type.
Parameters
----------
value : str
The value to be inspected.
Returns
-------
node : Node
The corresponding node value.
"""
if isinstance(value, NodeBase):
return value
elif isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._Array(*value)
elif isinstance(value, dict):
vlist = []
for it in value.items():
if not isinstance(it[0], NodeBase):
raise ValueError("key of map must already been a container type")
vlist.append(it[0])
vlist.append(convert_to_node(it[1]))
return _api_internal._Map(*vlist)
elif isinstance(value, NodeGeneric):
return value.asnode()
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
def register_node(type_key=None): def register_node(type_key=None):
"""register node type """register node type
...@@ -197,10 +80,14 @@ def register_node(type_key=None): ...@@ -197,10 +80,14 @@ def register_node(type_key=None):
def register(cls): def register(cls):
"""internal register function""" """internal register function"""
NODE_TYPE[node_name] = cls tindex = ctypes.c_int()
try:
check_call(_LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex)))
_register_node(tindex.value, cls)
except AttributeError:
pass
return cls return cls
if isinstance(type_key, str): if isinstance(type_key, str):
return register return register
else: return register(type_key)
return register(type_key)
"""Common implementation of Node generic related logic"""
# pylint: disable=unused-import
from __future__ import absolute_import
from numbers import Number, Integral
from .. import _api_internal
from .base import string_types
# Node base class
_CLASS_NODE_BASE = None
def _set_class_node_base(cls):
global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
raise NotImplementedError()
def convert_to_node(value):
"""Convert a python value to corresponding node type.
Parameters
----------
value : str
The value to be inspected.
Returns
-------
node : Node
The corresponding node value.
"""
if isinstance(value, _CLASS_NODE_BASE):
return value
elif isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._Array(*value)
elif isinstance(value, dict):
vlist = []
for item in value.items():
if not isinstance(item[0], _CLASS_NODE_BASE):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_node(item[1]))
return _api_internal._Map(*vlist)
elif isinstance(value, NodeGeneric):
return value.asnode()
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
def const(value, dtype=None):
"""Construct a constant value for a given type.
Parameters
----------
value : int or float
The input value
dtype : str
The data type.
Returns
-------
expr : Expr
Constant expression corresponds to the value.
"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _api_internal._const(value, dtype)
...@@ -7,7 +7,7 @@ import os ...@@ -7,7 +7,7 @@ import os
import ctypes import ctypes
from .. import _api_internal from .. import _api_internal
from .._base import string_types from .._ffi.base import string_types
from .._ffi.node import NodeBase, register_node from .._ffi.node import NodeBase, register_node
from .._ffi.function import register_func from .._ffi.function import register_func
from . import testing from . import testing
...@@ -23,7 +23,10 @@ class VPISession(NodeBase): ...@@ -23,7 +23,10 @@ class VPISession(NodeBase):
def __del__(self): def __del__(self):
self.proc.kill() self.proc.kill()
super(VPISession, self).__del__() try:
super(VPISession, self).__del__()
except AttributeError:
pass
def arg(self, index): def arg(self, index):
"""Get handle passed to host session. """Get handle passed to host session.
...@@ -143,10 +146,9 @@ def find_file(file_name): ...@@ -143,10 +146,9 @@ def find_file(file_name):
ver_path = search_path() ver_path = search_path()
flist = [os.path.join(p, file_name) for p in ver_path] flist = [os.path.join(p, file_name) for p in ver_path]
found = [p for p in flist if os.path.exists(p) and os.path.isfile(p)] found = [p for p in flist if os.path.exists(p) and os.path.isfile(p)]
if len(found): if not found:
return found[0]
else:
raise ValueError("Cannot find %s in %s" % (file_name, flist)) raise ValueError("Cannot find %s in %s" % (file_name, flist))
return found[0]
def compile_file(file_name, file_target, options=None): def compile_file(file_name, file_target, options=None):
......
...@@ -4,13 +4,13 @@ from __future__ import absolute_import as _abs ...@@ -4,13 +4,13 @@ from __future__ import absolute_import as _abs
from numbers import Integral as _Integral from numbers import Integral as _Integral
from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func from ._ffi.function import _init_api, register_func, get_global_func
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal from . import _api_internal
from . import _base
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
from . import tensor as _tensor from . import tensor as _tensor
...@@ -57,8 +57,8 @@ def convert(value): ...@@ -57,8 +57,8 @@ def convert(value):
if callable(value): if callable(value):
return _convert_tvm_func(value) return _convert_tvm_func(value)
else:
return _convert_to_node(value) return _convert_to_node(value)
def load_json(json_str): def load_json(json_str):
...@@ -396,9 +396,9 @@ def thread_axis(dom=None, tag='', name=''): ...@@ -396,9 +396,9 @@ def thread_axis(dom=None, tag='', name=''):
axis : IterVar axis : IterVar
The thread itervar. The thread itervar.
""" """
if isinstance(dom, _base.string_types): if isinstance(dom, string_types):
tag, dom = dom, None tag, dom = dom, None
if len(tag) == 0: if not tag:
raise ValueError("tag must be given as Positional or keyword argument") raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag name = name if name else tag
return _IterVar(dom, name, 1, tag) return _IterVar(dom, name, 1, tag)
......
...@@ -157,8 +157,7 @@ def _rule_float_suffix(op): ...@@ -157,8 +157,7 @@ def _rule_float_suffix(op):
return call_pure_extern(op.dtype, "%sf" % op.name, *op.args) return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
elif op.dtype == "float64": elif op.dtype == "float64":
return call_pure_extern(op.dtype, op.name, *op.args) return call_pure_extern(op.dtype, op.name, *op.args)
else: return op
return op
def _rule_float_direct(op): def _rule_float_direct(op):
...@@ -183,8 +182,7 @@ def _rule_float_direct(op): ...@@ -183,8 +182,7 @@ def _rule_float_direct(op):
""" """
if str(op.dtype).startswith("float"): if str(op.dtype).startswith("float"):
return call_pure_extern(op.dtype, op.name, *op.args) return call_pure_extern(op.dtype, op.name, *op.args)
else: return None
return None
# opencl pattern for exp # opencl pattern for exp
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
......
...@@ -7,7 +7,7 @@ from . import expr as _expr ...@@ -7,7 +7,7 @@ from . import expr as _expr
from . import make as _make from . import make as _make
from . import ir_pass as _pass from . import ir_pass as _pass
from . import collections as _collections from . import collections as _collections
from ._base import string_types from ._ffi.base import string_types
from ._ffi.node import NodeGeneric from ._ffi.node import NodeGeneric
class WithScope(object): class WithScope(object):
...@@ -89,7 +89,7 @@ class IRBuilder(object): ...@@ -89,7 +89,7 @@ class IRBuilder(object):
def _pop_seq(self): def _pop_seq(self):
"""Pop sequence from stack""" """Pop sequence from stack"""
seq = self._seq_stack.pop() seq = self._seq_stack.pop()
if len(seq) == 0 or callable(seq[-1]): if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0)) seq.append(_make.Evaluate(0))
stmt = seq[-1] stmt = seq[-1]
for s in reversed(seq[:-1]): for s in reversed(seq[:-1]):
...@@ -232,7 +232,7 @@ class IRBuilder(object): ...@@ -232,7 +232,7 @@ class IRBuilder(object):
with ib.else_scope(): with ib.else_scope():
x[i] = x[i - 1] + 2 x[i] = x[i - 1] + 2
""" """
if len(self._seq_stack[-1]) == 0: if not self._seq_stack[-1]:
raise RuntimeError("else_scope can only follow an if_scope") raise RuntimeError("else_scope can only follow an if_scope")
prev = self._seq_stack[-1][-1] prev = self._seq_stack[-1][-1]
if not isinstance(prev, _stmt.IfThenElse) or prev.else_case: if not isinstance(prev, _stmt.IfThenElse) or prev.else_case:
...@@ -317,7 +317,7 @@ class IRBuilder(object): ...@@ -317,7 +317,7 @@ class IRBuilder(object):
The result statement. The result statement.
""" """
seq = self._pop_seq() seq = self._pop_seq()
if len(self._seq_stack) != 0: if self._seq_stack:
raise RuntimeError("cannot call get inside construction scope") raise RuntimeError("cannot call get inside construction scope")
return seq return seq
......
"""Container of compiled functions of TVM.""" """Container of compiled functions of TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.function import ModuleBase, _init_module_module from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api from ._ffi.function import _init_api
...@@ -97,4 +97,4 @@ def enabled(target): ...@@ -97,4 +97,4 @@ def enabled(target):
_init_api("tvm.module") _init_api("tvm.module")
_init_module_module(Module) _set_class_module(Module)
...@@ -9,7 +9,7 @@ import numpy as _np ...@@ -9,7 +9,7 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import cpu, gpu, opencl, vpi, empty, sync from ._ffi.ndarray import cpu, gpu, opencl, vpi, empty, sync
from ._ffi.ndarray import _init_ndarray_module from ._ffi.ndarray import _set_class_ndarray
cl = opencl cl = opencl
...@@ -49,5 +49,4 @@ def array(arr, ctx=cpu(0)): ...@@ -49,5 +49,4 @@ def array(arr, ctx=cpu(0)):
ret[:] = arr ret[:] = arr
return ret return ret
_set_class_ndarray(NDArray)
_init_ndarray_module(NDArray)
...@@ -93,8 +93,7 @@ class Tensor(NodeBase): ...@@ -93,8 +93,7 @@ class Tensor(NodeBase):
op = self.op op = self.op
if op.num_outputs == 1: if op.num_outputs == 1:
return op.name return op.name
else: return "%s.v%d" % (op.name, self.value_index)
return "%s.v%d" % (op.name, self.value_index)
class Operation(NodeBase): class Operation(NodeBase):
......
...@@ -113,10 +113,18 @@ int TVMCbArgToReturn(TVMValue* value, int code) { ...@@ -113,10 +113,18 @@ int TVMCbArgToReturn(TVMValue* value, int code) {
API_END(); API_END();
} }
int TVMNodeDupe(NodeHandle handle, NodeHandle* out_handle) { int TVMNodeTypeKey2Index(const char* type_key,
int* out_index) {
API_BEGIN(); API_BEGIN();
*out_index = static_cast<int>(Node::TypeKey2Index(type_key));
API_END();
}
*out_handle = new TVMAPINode(*static_cast<TVMAPINode*>(handle)); int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(
(*static_cast<TVMAPINode*>(handle))->type_index());
API_END(); API_END();
} }
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
# Add files or directories to the blacklist. They should be base names, not # Add files or directories to the blacklist. They should be base names, not
# paths. # paths.
ignore=CVS ignore=CVS, _cy2, _cy3
# Add files or directories matching the regex patterns to the blacklist. The # Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths. # regex matches against base names, not paths.
...@@ -67,7 +67,6 @@ enable=indexing-exception,old-raise-syntax ...@@ -67,7 +67,6 @@ enable=indexing-exception,old-raise-syntax
# --disable=W" # --disable=W"
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access
[REPORTS] [REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs # Set the output format. Available formats are text, parseable, colorized, msvs
......
...@@ -7,7 +7,9 @@ import subprocess ...@@ -7,7 +7,9 @@ import subprocess
runtime_py = """ runtime_py = """
import os import os
import sys import sys
os.environ["TVM_USE_RUNTIME_LIB"] = "1" os.environ["TVM_USE_RUNTIME_LIB"] = "1"
os.environ["TVM_FFI"] = "ctypes"
import tvm import tvm
import numpy as np import numpy as np
path_dso = sys.argv[1] path_dso = sys.argv[1]
......
...@@ -54,6 +54,10 @@ if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then ...@@ -54,6 +54,10 @@ if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
python -m nose -v tests/python/unittest || exit -1 python -m nose -v tests/python/unittest || exit -1
python3 -m nose -v tests/python/unittest || exit -1 python3 -m nose -v tests/python/unittest || exit -1
make cython || exit -1
make cython3 || exit -1
TVM_FFI=cython python -m nose -v tests/python/unittest || exit -1
TVM_FFI=cython python3 -m nose -v tests/python/unittest || exit -1
else else
nosetests -v tests/python/unittest || exit -1 nosetests -v tests/python/unittest || exit -1
nosetests3 -v tests/python/unittest || exit -1 nosetests3 -v tests/python/unittest || exit -1
......
#!/bin/bash #!/bin/bash
if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
brew update brew update
brew install python3 brew install python3
python -m pip install --user nose numpy python -m pip install --user nose numpy cython
python3 -m pip install --user nose numpy python3 -m pip install --user nose numpy cython
fi fi
fi fi
......
...@@ -73,11 +73,11 @@ def test_ram_read(): ...@@ -73,11 +73,11 @@ def test_ram_read():
# yield until read is done # yield until read is done
for i in range(a.shape[0] * 3): for i in range(a.shape[0] * 3):
sess.yield_until_next_cycle() sess.yield_until_next_cycle()
sess.shutdown()
# check if result matches # check if result matches
r = np.concatenate((a_np, a_np[2:])) r = np.concatenate((a_np, a_np[2:]))
np.testing.assert_equal(np.array(reader.data), r) np.testing.assert_equal(np.array(reader.data), r)
def test_ram_write(): def test_ram_write():
n = 10 n = 10
# read from offset # read from offset
...@@ -122,7 +122,7 @@ def test_ram_write(): ...@@ -122,7 +122,7 @@ def test_ram_write():
# yield until write is done # yield until write is done
for i in range(a.shape[0]+2): for i in range(a.shape[0]+2):
sess.yield_until_next_cycle() sess.yield_until_next_cycle()
sess.shutdown()
# check if result matches # check if result matches
np.testing.assert_equal(a.asnumpy()[2:], r_data) np.testing.assert_equal(a.asnumpy()[2:], r_data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment