Commit 535a97c9 by Tianqi Chen Committed by GitHub

[PYTHON/FFI] Enable Cython FFI (#106)

* [PYTHON/FFI] Enable Cython FFI

* fix cython
parent 26d91985
......@@ -11,7 +11,7 @@ endif
include $(config)
# specify tensor path
.PHONY: clean all test doc pylint cpplint lint verilog
.PHONY: clean all test doc pylint cpplint lint verilog cython cython2 cython3
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
......@@ -95,8 +95,6 @@ lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)
......@@ -117,6 +115,19 @@ lint: cpplint pylint
doc:
doxygen docs/Doxyfile
# Cython build
cython:
cd python; python setup.py build_ext --inplace
cython2:
cd python; python2 setup.py build_ext --inplace
cython3:
cd python; python3 setup.py build_ext --inplace
cyclean:
rm -rf python/tvm/*/*/*.so python/tvm/*/*/*.cpp
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
......
......@@ -39,6 +39,24 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
TVM_DLL int TVMNodeFree(NodeHandle handle);
/*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMNodeTypeKey2Index(const char* type_key,
int* out_index);
/*!
* \brief Get runtime type index of the node.
* \param handle the node handle.
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index);
/*!
* \brief get attributes given key
* \param handle The node handle
* \param key The attribute name
......
build
*.cpp
\ No newline at end of file
......@@ -6,21 +6,74 @@ import os
import sys
import setuptools
# need to use distutils.core for correct placement of cython dll
if "--inplace" in sys.argv:
from distutils.core import setup
from distutils.extension import Extension
else:
from setuptools import setup
from setuptools.extension import Extension
CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, 'tvm/libinfo.py')
libinfo_py = os.path.join(CURRENT_DIR, 'tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, 'rb').read(), libinfo_py, 'exec'), libinfo, libinfo)
LIB_PATH = libinfo['find_lib_path']()
print(LIB_PATH)
__version__ = libinfo['__version__']
def config_cython():
"""Try to configure cython and return cython configuration"""
if os.name == 'nt':
print("WARNING: Cython is not supported on Windows, will compile without cython module")
return []
try:
from Cython.Build import cythonize
# from setuptools.extension import Extension
if sys.version_info >= (3, 0):
subdir = "_cy3"
else:
subdir = "_cy2"
ret = []
path = "tvm/_ffi/_cython"
if os.name == 'nt':
library_dirs = ['tvm', '../build/Release', '../build']
libraries = ['libtvm']
else:
library_dirs = None
libraries = None
for fn in os.listdir(path):
if not fn.endswith(".pyx"):
continue
ret.append(Extension(
"tvm._ffi.%s.%s" % (subdir, fn[:-4]),
["tvm/_ffi/_cython/%s" % fn],
include_dirs=["../include/",
"../dmlc-core/include",
"../dlpack/include",
],
library_dirs=library_dirs,
libraries=libraries,
language="c++"))
return cythonize(ret)
except ImportError:
print("WARNING: Cython is not installed, will compile without cython module")
return []
setuptools.setup(
name='tvm',
version=__version__,
description='A domain specific language(DSL) for tensor computations.',
packages=setuptools.find_packages(),
install_requires=[
'numpy',
],
data_files=[('tvm', [LIB_PATH[0]])]
)
zip_safe=False,
packages=[
'tvm', 'tvm.addon',
'tvm._ffi', 'tvm._ffi._ctypes',
'tvm._ffi._cy2', 'tvm._ffi._cy3'
],
data_files=[('tvm', [LIB_PATH[0]])],
url='https://github.com/tqchen/tvm',
ext_modules=config_cython())
......@@ -19,9 +19,7 @@ from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, cl, vpi
from ._ffi.function import Function
from ._base import TVMError
from ._base import __version__
from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .node import register_node
......
"""ctypes support"""
"""C interfacing code.
This namespace contains everything that interacts with C code.
Most TVM C related object are ctypes compatible, which means
they contains a handle field that is ctypes.c_void_p and can
be used via ctypes function calls.
Some performance critical functions are implemented by cython
and have a ctypes fallback implementation.
"""
"""ctypes specific implementation of FFI"""
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement
"""Function configuration API."""
from __future__ import absolute_import
import ctypes
import traceback
from numbers import Number, Integral
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import TVMType, TVMByteArray, NDArrayBase
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .node import NodeBase
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.nd.Function
The converted tvm function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args))
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
_LIB.TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
_ = temp_args
_ = rv
return 0
handle = FunctionHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
return _CLASS_FUNCTION(handle, False)
def _make_tvm_args(args, temp_args):
"""Pack arguments into c args tvm call accept"""
num_args = len(args)
values = (TVMValue * num_args)()
type_codes = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if isinstance(arg, NodeBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
elif arg is None:
values[i].v_handle = None
type_codes[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = TypeCode.INT
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = TypeCode.BYTES
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
class FunctionBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
is_global : bool
Whether this is a global function in python
"""
self.handle = handle
self.is_global = is_global
def __del__(self):
if not self.is_global:
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
"""Call the function with positional arguments
args : list
The positional arguments to the function call.
"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
def _return_module(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, ModuleHandle):
handle = ModuleHandle(handle)
return _CLASS_MODULE(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
_CLASS_MODULE = None
_CLASS_FUNCTION = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..node_generic import _set_class_node_base
from .types import TVMValue, TypeCode
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
NodeHandle = ctypes.c_void_p
"""Maps node type to its constructor"""
NODE_TYPE = {}
def _register_node(index, cls):
"""register node class"""
NODE_TYPE[index] = cls
def _return_node(x):
"""Return node function"""
handle = x.v_handle
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
tindex = ctypes.c_int()
check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex)))
return NODE_TYPE.get(tindex.value, NodeBase)(handle)
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class NodeBase(object):
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name):
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)
_set_class_node_base(NodeBase)
......@@ -3,10 +3,8 @@
from __future__ import absolute_import as _abs
import ctypes
import numpy as np
from .._base import py_str, check_call, _LIB
tvm_shape_index_t = ctypes.c_int64
from ..base import py_str, check_call, _LIB
from ..ndarray import TVMByteArray
class TypeCode(object):
"""Type code used in API calls"""
......@@ -23,66 +21,6 @@ class TypeCode(object):
STR = 10
BYTES = 11
def _api_type(code):
"""create a type accepted by API"""
t = TVMType()
t.bits = 64
t.lanes = 1
t.type_code = code
return t
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
def __ne__(self, other):
return not self.__eq__(other)
class TVMValue(ctypes.Union):
"""TVMValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
......@@ -90,11 +28,6 @@ class TVMValue(ctypes.Union):
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)]
class TVMByteArray(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
TVMPackedCFunc = ctypes.CFUNCTYPE(
ctypes.c_int,
......
"""cython2 namespace"""
"""cython3 namespace"""
from ..base import TVMError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from libc.stdint cimport int64_t, uint8_t, uint16_t
import ctypes
cdef enum TVMTypeCode:
kInt = 0
kUInt = 1
kFloat = 2
kHandle = 3
kNull = 4
kArrayHandle = 5
kTVMType = 6
kNodeHandle = 7
kModuleHandle = 8
kFuncHandle = 9
kStr = 10
kBytes = 11
cdef extern from "tvm/runtime/c_runtime_api.h":
struct DLType:
uint8_t code
uint8_t bits
uint16_t lanes
ctypedef struct TVMValue:
int64_t v_int64
double v_float64
void* v_handle
const char* v_str
DLType v_type
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle
ctypedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle)
ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg);
const char *TVMGetLastError();
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code)
int TVMFuncFree(TVMFunctionHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue value,
int type_code)
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
cdef extern from "tvm/c_api.h":
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMNodeFree(NodeHandle handle)
TVMNodeTypeKey2Index(const char* type_key,
int* out_index)
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index)
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
TVMValue* out_value,
int* out_type_code,
int* out_success)
cdef inline py_str(const char* x):
if PY_MAJOR_VERSION < 3:
return x
else:
return x.decode("utf-8")
cdef inline c_str(pystr):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return pystr.encode("utf-8")
cdef inline CALL(int ret):
if ret != 0:
raise TVMError(TVMGetLastError())
cdef inline object ctypes_handle(void* chandle):
"""Cast C handle to ctypes handle."""
return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)
cdef inline void* c_handle(object handle):
"""Cast C types handle to c handle."""
cdef unsigned long long v_ptr
v_ptr = handle.value
return <void*>(v_ptr)
include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import NDArrayBase, TVMType, TVMByteArray
print("TVM: Initializing cython mode...")
cdef void tvm_callback_finalize(void* fhandle):
local_pyfunc = <object>(fhandle)
Py_DECREF(local_pyfunc)
cdef int tvm_callback(TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* fhandle):
cdef list pyargs
cdef TVMValue value
cdef int tcode
local_pyfunc = <object>(fhandle)
pyargs = []
for i in range(num_args):
value = args[i]
tcode = type_codes[i]
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle):
CALL(TVMCbArgToReturn(&value, tcode))
pyargs.append(make_ret(value, tcode))
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
make_arg(rv, &value, &tcode, temp_args)
CALL(TVMCFuncSetReturn(ret, value, tcode))
return 0
def convert_to_tvm_func(object pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.Function
The converted tvm function.
"""
cdef TVMFunctionHandle chandle
Py_INCREF(pyfunc)
CALL(TVMFuncCreateFromCFunc(tvm_callback,
<void*>(pyfunc),
tvm_callback_finalize,
&chandle))
return _CLASS_FUNCTION(ctypes_handle(chandle), False)
cdef inline void make_arg(object arg,
TVMValue* value,
int* tcode,
list temp_args):
"""Pack arguments into c args tvm call accept"""
if isinstance(arg, NodeBase):
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
elif isinstance(arg, NDArrayBase):
value[0].v_handle = c_handle(
ctypes.cast(arg.handle, ctypes.c_void_p))
tcode[0] = kArrayHandle
elif isinstance(arg, Integral):
value[0].v_int64 = arg
tcode[0] = kInt
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif arg is None:
value[0].v_handle = NULL
tcode[0] = kNull
elif isinstance(arg, TVMType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
value[0].v_handle = <void*>(
<unsigned long long>ctypes.addressof(arr))
tcode[0] = kBytes
temp_args.append(arr)
elif isinstance(arg, string_types):
tstr = c_str(arg)
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
elif callable(arg):
arg = convert_to_tvm_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
cdef inline bytearray make_ret_bytes(void* chandle):
handle = ctypes_handle(chandle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
size = arr.size
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
return res
cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kNodeHandle:
return make_ret_node(value.v_handle)
elif tcode == kNull:
return None
elif tcode == kInt:
return value.v_int64
elif tcode == kFloat:
return value.v_float64
elif tcode == kStr:
return py_str(value.v_str)
elif tcode == kBytes:
return make_ret_bytes(value.v_handle)
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle:
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
else:
raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall2(void* chandle, tuple args, int nargs):
cdef TVMValue[2] values
cdef int[2] tcodes
cdef TVMValue ret_val
cdef int ret_code
nargs = len(args)
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
cdef inline object FuncCall(void* chandle, tuple args):
cdef int nargs
nargs = len(args)
if nargs <= 2:
return FuncCall2(chandle, args, nargs)
cdef vector[TVMValue] values
cdef vector[int] tcodes
cdef TVMValue ret_val
cdef int ret_code
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
cdef class FunctionBase:
cdef TVMFunctionHandle chandle
cdef int is_global
cdef _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property is_global:
def __get__(self):
return self.c_is_global != 0
def __set__(self, value):
self.c_is_global = value
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle, is_global):
self._set_handle(handle)
self.c_is_global = is_global
def __dealloc__(self):
if self.is_global == 0:
CALL(TVMFuncFree(self.chandle))
def __call__(self, *args):
return FuncCall(self.chandle, args)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
from ..base import string_types
from ..node_generic import _set_class_node_base
"""Maps node type to its constructor"""
NODE_TYPE = []
def _register_node(int index, object cls):
"""register node class"""
while len(NODE_TYPE) <= index:
NODE_TYPE.append(None)
NODE_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle):
global NODE_TYPE
cdef int tindex
cdef list node_type
cdef object cls
node_type = NODE_TYPE
CALL(TVMNodeGetTypeIndex(chandle, &tindex))
if tindex < len(node_type):
cls = node_type[tindex]
if cls is not None:
obj = cls(None)
else:
obj = NodeBase(None)
(<NodeBase>obj).chandle = chandle
return obj
cdef class NodeBase:
cdef void* chandle
cdef _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
ptr = handle.value
self.chandle = <void*>(ptr)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes_handle(self.chandle)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
def __dealloc__(self):
CALL(TVMNodeFree(self.chandle))
def __getattr__(self, name):
cdef TVMValue ret_val
cdef int ret_type_code, ret_succ
CALL(TVMNodeGetAttr(self.chandle, c_str(name),
&ret_val, &ret_type_code, &ret_succ))
if ret_succ == 0:
raise AttributeError(
"'%s' object has no attribute '%s'" % (type(self), name))
return make_ret(ret_val, ret_type_code)
_set_class_node_base(NodeBase)
......@@ -4,11 +4,11 @@
from __future__ import absolute_import
import sys
import os
import ctypes
import numpy as np
from . import libinfo
__all__ = ['TVMError']
#----------------------------
# library loading
#----------------------------
......@@ -25,7 +25,7 @@ else:
class TVMError(Exception):
"""Error that will be throwed by all functions"""
"""Error thrown by TVM function"""
pass
def _load_lib():
......@@ -40,9 +40,11 @@ def _load_lib():
__version__ = libinfo.__version__
# library instance of nnvm
_LIB = _load_lib()
# The FFI mode of TVM
_FFI_MODE = os.environ.get("TVM_FFI", "auto")
#----------------------------
# helper function definition
# helper function in ctypes.
#----------------------------
def check_call(ret):
"""Check the return value of C API call
......
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement
"""Function configuration API."""
# pylint: disable=invalid-name, unused-import
"""Function namespace."""
from __future__ import absolute_import
import ctypes
import sys
import traceback
from numbers import Number, Integral
from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from .types import TVMValue, TypeCode, TVMType, TVMByteArray
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .node import NodeBase, NodeGeneric, convert_to_node
from .ndarray import NDArrayBase
import ctypes
from .base import _LIB, check_call, py_str, c_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.nd.Function
The converted tvm function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
_LIB.TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one return value")
temp_args = []
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
_ = temp_args
_ = rv
return 0
handle = FunctionHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
return Function(handle, False)
def _make_tvm_args(args, temp_args):
"""Pack arguments into c args tvm call accept"""
num_args = len(args)
values = (TVMValue * num_args)()
type_codes = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if isinstance(arg, NodeBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
elif arg is None:
values[i].v_handle = None
type_codes[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = TypeCode.INT
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = TypeCode.BYTES
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
temp_args.append(arg)
elif isinstance(arg, ModuleBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, Function):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
class Function(object):
class Function(_FunctionBase):
"""The PackedFunc object used in TVM.
Function plays an key role to bridge front and backend in TVM.
......@@ -158,42 +51,7 @@ class Function(object):
tvm.register_func: How to register global function.
tvm.get_global_func: How to get global function.
"""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
is_global : bool
Whether this is a global function in python
"""
self.handle = handle
self.is_global = is_global
def __del__(self):
if not self.is_global:
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
"""Call the function with positional arguments
args : list
The positional arguments to the function call.
"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
pass
class ModuleBase(object):
......@@ -214,9 +72,8 @@ class ModuleBase(object):
"""
if self._entry:
return self._entry
else:
self._entry = self.get_function("__tvm_main__")
return self._entry
self._entry = self.get_function("__tvm_main__")
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
......@@ -273,39 +130,11 @@ class ModuleBase(object):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __call__(self, *args):
if self._entry:
return self._entry(*args)
else:
f = self.entry_func
return f(*args)
_module_cls = None
def _return_module(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, ModuleHandle):
handle = ModuleHandle(handle)
return _module_cls(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return Function(handle, False)
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
f = self.entry_func
return f(*args)
def register_func(func_name, f=None, override=False):
......@@ -456,8 +285,4 @@ def _init_api(namespace):
ff.__doc__ = ("TVM PackedFunc %s. " % fname)
setattr(target_module, ff.__name__, ff)
def _init_module_module(module_class):
"""Initialize the module."""
global _module_cls
_module_cls = module_class
_set_class_function(Function)
# coding: utf-8
"""Information about nnvm."""
"""Library information."""
from __future__ import absolute_import
import sys
import os
......@@ -16,17 +15,17 @@ def find_lib_path():
"""
use_runtime = os.environ.get("TVM_USE_RUNTIME_LIB", False)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/')
api_path = os.path.join(curr_path, '../../../lib/')
cmake_build_path = os.path.join(curr_path, '../../../build/Release/')
dll_path = [curr_path, api_path, cmake_build_path]
if os.name == 'nt':
vs_configuration = 'Release'
if platform.architecture()[0] == '64bit':
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows/x64', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../../windows/x64', vs_configuration))
else:
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../../windows', vs_configuration))
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
......@@ -40,7 +39,7 @@ def find_lib_path():
dll_path = runtime_dll_path if use_runtime else lib_dll_path
lib_found = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_found) == 0:
if not lib_found:
raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path)))
if use_runtime:
......
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API."""
"""Runtime NDArray api"""
from __future__ import absolute_import
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array
tvm_shape_index_t = ctypes.c_int64
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
from .._base import _LIB, check_call
from .._base import c_array
from .types import TVMType, tvm_shape_index_t
def __ne__(self, other):
return not self.__eq__(other)
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
......@@ -104,10 +157,6 @@ def numpyasarray(np_data):
arr.ctx = cpu(0)
return arr, shape
_ndarray_cls = None
def empty(shape, dtype="float32", ctx=cpu(0)):
"""Create an empty array given shape and device
......@@ -133,7 +182,7 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle)))
return _ndarray_cls(handle)
return _CLASS_NDARRAY(handle)
def sync(ctx):
......@@ -253,7 +302,8 @@ class NDArrayBase(object):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
_CLASS_NDARRAY = None
def _init_ndarray_module(ndarray_class):
global _ndarray_cls
_ndarray_cls = ndarray_class
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
# coding: utf-8
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
"""Symbolic configuration API."""
"""Node namespace"""
# pylint: disable=unused-import
from __future__ import absolute_import
import ctypes
from numbers import Number, Integral
from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
import sys
from .. import _api_internal
from .types import TVMValue, TypeCode
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
NodeHandle = ctypes.c_void_p
"""Maps node type to its constructor"""
NODE_TYPE = {
}
def _return_node(x):
"""Return node function"""
handle = x.v_handle
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
raise NotImplementedError()
from .node_generic import NodeGeneric, convert_to_node, const
from .base import _LIB, check_call, c_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _register_node, NodeBase as _NodeBase
else:
from ._cy2.core import _register_node, NodeBase as _NodeBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.node import _register_node, NodeBase as _NodeBase
class NodeBase(object):
class NodeBase(_NodeBase):
"""NodeBase is the base class of all TVM language AST object."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __repr__(self):
return _api_internal._format_str(self)
def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name):
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMNodeListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
return names
def __hash__(self):
return _api_internal._raw_ptr(self)
......@@ -93,16 +47,6 @@ class NodeBase(object):
def __ne__(self, other):
return not self.__eq__(other)
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMNodeListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
return names
def __reduce__(self):
return (type(self), (None,), self.__getstate__())
......@@ -110,8 +54,7 @@ class NodeBase(object):
handle = self.handle
if handle is not None:
return {'handle': _api_internal._save_json(self)}
else:
return {'handle': None}
return {'handle': None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
......@@ -125,66 +68,6 @@ class NodeBase(object):
self.handle = None
def const(value, dtype=None):
"""Construct a constant value for a given type.
Parameters
----------
value : int or float
The input value
dtype : str
The data type.
Returns
-------
expr : Expr
Constant expression corresponds to the value.
"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _api_internal._const(value, dtype)
def convert_to_node(value):
"""Convert a python value to corresponding node type.
Parameters
----------
value : str
The value to be inspected.
Returns
-------
node : Node
The corresponding node value.
"""
if isinstance(value, NodeBase):
return value
elif isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._Array(*value)
elif isinstance(value, dict):
vlist = []
for it in value.items():
if not isinstance(it[0], NodeBase):
raise ValueError("key of map must already been a container type")
vlist.append(it[0])
vlist.append(convert_to_node(it[1]))
return _api_internal._Map(*vlist)
elif isinstance(value, NodeGeneric):
return value.asnode()
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
def register_node(type_key=None):
"""register node type
......@@ -197,10 +80,14 @@ def register_node(type_key=None):
def register(cls):
"""internal register function"""
NODE_TYPE[node_name] = cls
tindex = ctypes.c_int()
try:
check_call(_LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex)))
_register_node(tindex.value, cls)
except AttributeError:
pass
return cls
if isinstance(type_key, str):
return register
else:
return register(type_key)
return register(type_key)
"""Common implementation of Node generic related logic"""
# pylint: disable=unused-import
from __future__ import absolute_import
from numbers import Number, Integral
from .. import _api_internal
from .base import string_types
# Node base class
_CLASS_NODE_BASE = None
def _set_class_node_base(cls):
global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
raise NotImplementedError()
def convert_to_node(value):
"""Convert a python value to corresponding node type.
Parameters
----------
value : str
The value to be inspected.
Returns
-------
node : Node
The corresponding node value.
"""
if isinstance(value, _CLASS_NODE_BASE):
return value
elif isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._Array(*value)
elif isinstance(value, dict):
vlist = []
for item in value.items():
if not isinstance(item[0], _CLASS_NODE_BASE):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_node(item[1]))
return _api_internal._Map(*vlist)
elif isinstance(value, NodeGeneric):
return value.asnode()
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
def const(value, dtype=None):
"""Construct a constant value for a given type.
Parameters
----------
value : int or float
The input value
dtype : str
The data type.
Returns
-------
expr : Expr
Constant expression corresponds to the value.
"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _api_internal._const(value, dtype)
......@@ -7,7 +7,7 @@ import os
import ctypes
from .. import _api_internal
from .._base import string_types
from .._ffi.base import string_types
from .._ffi.node import NodeBase, register_node
from .._ffi.function import register_func
from . import testing
......@@ -23,7 +23,10 @@ class VPISession(NodeBase):
def __del__(self):
self.proc.kill()
super(VPISession, self).__del__()
try:
super(VPISession, self).__del__()
except AttributeError:
pass
def arg(self, index):
"""Get handle passed to host session.
......@@ -143,10 +146,9 @@ def find_file(file_name):
ver_path = search_path()
flist = [os.path.join(p, file_name) for p in ver_path]
found = [p for p in flist if os.path.exists(p) and os.path.isfile(p)]
if len(found):
return found[0]
else:
if not found:
raise ValueError("Cannot find %s in %s" % (file_name, flist))
return found[0]
def compile_file(file_name, file_target, options=None):
......
......@@ -4,13 +4,13 @@ from __future__ import absolute_import as _abs
from numbers import Integral as _Integral
from ._ffi.base import string_types
from ._ffi.node import register_node, NodeBase
from ._ffi.node import convert_to_node as _convert_to_node
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal
from . import _base
from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
......@@ -57,8 +57,8 @@ def convert(value):
if callable(value):
return _convert_tvm_func(value)
else:
return _convert_to_node(value)
return _convert_to_node(value)
def load_json(json_str):
......@@ -396,9 +396,9 @@ def thread_axis(dom=None, tag='', name=''):
axis : IterVar
The thread itervar.
"""
if isinstance(dom, _base.string_types):
if isinstance(dom, string_types):
tag, dom = dom, None
if len(tag) == 0:
if not tag:
raise ValueError("tag must be given as Positional or keyword argument")
name = name if name else tag
return _IterVar(dom, name, 1, tag)
......
......@@ -157,8 +157,7 @@ def _rule_float_suffix(op):
return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
elif op.dtype == "float64":
return call_pure_extern(op.dtype, op.name, *op.args)
else:
return op
return op
def _rule_float_direct(op):
......@@ -183,8 +182,7 @@ def _rule_float_direct(op):
"""
if str(op.dtype).startswith("float"):
return call_pure_extern(op.dtype, op.name, *op.args)
else:
return None
return None
# opencl pattern for exp
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
......
......@@ -7,7 +7,7 @@ from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from . import collections as _collections
from ._base import string_types
from ._ffi.base import string_types
from ._ffi.node import NodeGeneric
class WithScope(object):
......@@ -89,7 +89,7 @@ class IRBuilder(object):
def _pop_seq(self):
"""Pop sequence from stack"""
seq = self._seq_stack.pop()
if len(seq) == 0 or callable(seq[-1]):
if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0))
stmt = seq[-1]
for s in reversed(seq[:-1]):
......@@ -232,7 +232,7 @@ class IRBuilder(object):
with ib.else_scope():
x[i] = x[i - 1] + 2
"""
if len(self._seq_stack[-1]) == 0:
if not self._seq_stack[-1]:
raise RuntimeError("else_scope can only follow an if_scope")
prev = self._seq_stack[-1][-1]
if not isinstance(prev, _stmt.IfThenElse) or prev.else_case:
......@@ -317,7 +317,7 @@ class IRBuilder(object):
The result statement.
"""
seq = self._pop_seq()
if len(self._seq_stack) != 0:
if self._seq_stack:
raise RuntimeError("cannot call get inside construction scope")
return seq
......
"""Container of compiled functions of TVM."""
from __future__ import absolute_import as _abs
from ._ffi.function import ModuleBase, _init_module_module
from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api
......@@ -97,4 +97,4 @@ def enabled(target):
_init_api("tvm.module")
_init_module_module(Module)
_set_class_module(Module)
......@@ -9,7 +9,7 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import cpu, gpu, opencl, vpi, empty, sync
from ._ffi.ndarray import _init_ndarray_module
from ._ffi.ndarray import _set_class_ndarray
cl = opencl
......@@ -49,5 +49,4 @@ def array(arr, ctx=cpu(0)):
ret[:] = arr
return ret
_init_ndarray_module(NDArray)
_set_class_ndarray(NDArray)
......@@ -93,8 +93,7 @@ class Tensor(NodeBase):
op = self.op
if op.num_outputs == 1:
return op.name
else:
return "%s.v%d" % (op.name, self.value_index)
return "%s.v%d" % (op.name, self.value_index)
class Operation(NodeBase):
......
......@@ -113,10 +113,18 @@ int TVMCbArgToReturn(TVMValue* value, int code) {
API_END();
}
int TVMNodeDupe(NodeHandle handle, NodeHandle* out_handle) {
int TVMNodeTypeKey2Index(const char* type_key,
int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(Node::TypeKey2Index(type_key));
API_END();
}
*out_handle = new TVMAPINode(*static_cast<TVMAPINode*>(handle));
int TVMNodeGetTypeIndex(NodeHandle handle,
int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(
(*static_cast<TVMAPINode*>(handle))->type_index());
API_END();
}
......
......@@ -9,7 +9,7 @@
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
ignore=CVS, _cy2, _cy3
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
......@@ -67,7 +67,6 @@ enable=indexing-exception,old-raise-syntax
# --disable=W"
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
......
......@@ -7,7 +7,9 @@ import subprocess
runtime_py = """
import os
import sys
os.environ["TVM_USE_RUNTIME_LIB"] = "1"
os.environ["TVM_FFI"] = "ctypes"
import tvm
import numpy as np
path_dso = sys.argv[1]
......
......@@ -54,6 +54,10 @@ if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
python -m nose -v tests/python/unittest || exit -1
python3 -m nose -v tests/python/unittest || exit -1
make cython || exit -1
make cython3 || exit -1
TVM_FFI=cython python -m nose -v tests/python/unittest || exit -1
TVM_FFI=cython python3 -m nose -v tests/python/unittest || exit -1
else
nosetests -v tests/python/unittest || exit -1
nosetests3 -v tests/python/unittest || exit -1
......
#!/bin/bash
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
brew update
brew install python3
python -m pip install --user nose numpy
python3 -m pip install --user nose numpy
python -m pip install --user nose numpy cython
python3 -m pip install --user nose numpy cython
fi
fi
......
......@@ -73,11 +73,11 @@ def test_ram_read():
# yield until read is done
for i in range(a.shape[0] * 3):
sess.yield_until_next_cycle()
sess.shutdown()
# check if result matches
r = np.concatenate((a_np, a_np[2:]))
np.testing.assert_equal(np.array(reader.data), r)
def test_ram_write():
n = 10
# read from offset
......@@ -122,7 +122,7 @@ def test_ram_write():
# yield until write is done
for i in range(a.shape[0]+2):
sess.yield_until_next_cycle()
sess.shutdown()
# check if result matches
np.testing.assert_equal(a.asnumpy()[2:], r_data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment