Commit ff06917c by Tianqi Chen Committed by GitHub

[API/Refactor] Unified PackedFunc for API and Generated Functions (#26)

parent 4242b9cf
...@@ -4,7 +4,7 @@ language: cpp ...@@ -4,7 +4,7 @@ language: cpp
os: os:
- linux - linux
- osx # - osx
env: env:
# code analysis # code analysis
......
/*!
* Copyright (c) 2016 by Contributors
* \file api_registry.h
* \brief This file defines the TVM API registry.
*
* The API registry stores type-erased functions.
* Each registered function is automatically exposed
* to front-end language(e.g. python).
* Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end.
*
* \code
* // register the function as MyAPIFuncName
* TVM_REGISTER_API(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#include <dmlc/base.h>
#include <string>
#include "./base.h"
#include "./runtime/packed_func.h"
#include "./packed_func_ext.h"
namespace tvm {
/*! \brief Utility to register API. */
class APIRegistry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
*/
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)
private:
/*! \brief name of the function */
std::string name_;
};
/*!
* \brief Get API function by name.
*
* \param name The name of the function.
* \return the corresponding API function.
* \note It is really PackedFunc::GetGlobal under the hood.
*/
inline PackedFunc GetAPIFunc(const std::string& name) {
return PackedFunc::GetGlobal(name);
}
#define _TVM_REGISTER_VAR_DEF_ \
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_
/*!
* \brief Register API function globally.
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) \
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
::tvm::APIRegistry::__REGISTER__(#OpName)
} // namespace tvm
#endif // TVM_API_REGISTRY_H_
...@@ -2,6 +2,13 @@ ...@@ -2,6 +2,13 @@
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file c_api.h * \file c_api.h
* \brief C API of TVM DSL * \brief C API of TVM DSL
*
* \note The API is designed in a minimum way.
* Most of the API functions are registered and can be pulled out.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/ */
#ifndef TVM_C_API_H_ #ifndef TVM_C_API_H_
#define TVM_C_API_H_ #define TVM_C_API_H_
...@@ -9,77 +16,10 @@ ...@@ -9,77 +16,10 @@
#include "./runtime/c_runtime_api.h" #include "./runtime/c_runtime_api.h"
TVM_EXTERN_C { TVM_EXTERN_C {
/*! \brief handle to functions */
typedef void* APIFuncHandle;
/*! \brief handle to node */ /*! \brief handle to node */
typedef void* NodeHandle; typedef void* NodeHandle;
/*! /*!
* \brief List all the node function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMListAPIFuncNames(int *out_size,
const char*** out_array);
/*!
* \brief get function handle by name
* \param name The name of function
* \param handle The returning function handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFuncHandle(const char* name,
APIFuncHandle *handle);
/*!
* \brief Get the detailed information about function.
* \param handle The operator handle.
* \param real_name The returned name of the function.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
/*!
* \brief Push an argument to the function calling stack.
* If push fails, the stack will be reset to empty
*
* \param arg The argument
* \param type_code The type_code of argument as in TVMTypeCode
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIPushStack(TVMValue arg,
int type_code);
/*!
* \brief call a function by using arguments in the stack.
* The stack will be cleanup to empty after this call, whether the call is successful.
*
* \param handle The function handle
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
TVMValue* ret_val,
int* ret_type_code);
/*!
* \brief free the node handle * \brief free the node handle
* \param handle The node handle to be freed. * \param handle The node handle to be freed.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <string> #include <string>
#include <algorithm> #include <algorithm>
#include "./base.h" #include "./base.h"
#include "./runtime/packed_func.h"
namespace tvm { namespace tvm {
......
/*!
* Copyright (c) 2016 by Contributors
* \file packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass NodeRef types into/from PackedFunc.
*/
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
#include <sstream>
#include <string>
#include <memory>
#include <type_traits>
#include "./base.h"
#include "./expr.h"
namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct NodeTypeChecker {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}
// extensions for tvm arg value
template<typename TNodeRef, typename>
inline TVMArgValue::operator TNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
return TNodeRef(sptr);
}
inline TVMArgValue::operator Halide::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kInt) {
return Expr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kFloat) {
return Expr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
}
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Expr(sptr);
}
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<std::shared_ptr<Node> >();
}
template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr =
*ptr<std::shared_ptr<Node> >();
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}
// extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=(
const std::shared_ptr<Node>& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
return *this;
}
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
return *this;
}
template<typename TNodeRef, typename>
inline TVMRetValue::operator TNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return TNodeRef(*ptr<std::shared_ptr<Node> >());
}
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = &(other.node_);
type_codes_[i] = kNodeHandle;
}
// Type related stuffs
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}
inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
return this->operator=(Type2TVMType(t));
}
inline TVMRetValue::operator Halide::Type() const {
return TVMType2Type(operator TVMType());
}
inline TVMArgValue::operator Halide::Type() const {
return TVMType2Type(operator TVMType());
}
inline void TVMArgsSetter::operator()(
size_t i, const Halide::Type& t) const {
this->operator()(i, Type2TVMType(t));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
...@@ -36,18 +36,6 @@ ...@@ -36,18 +36,6 @@
TVM_EXTERN_C { TVM_EXTERN_C {
/*! \brief type of array index. */ /*! \brief type of array index. */
typedef uint32_t tvm_index_t; typedef uint32_t tvm_index_t;
/*!
* \brief Union type of values
* being passed through API and function calls.
*/
typedef union {
int64_t v_int64;
double v_float64;
void* v_handle;
const char* v_str;
} TVMValue;
/*! /*!
* \brief The type code in TVMType * \brief The type code in TVMType
* \note TVMType is used in two places. * \note TVMType is used in two places.
...@@ -60,9 +48,11 @@ typedef enum { ...@@ -60,9 +48,11 @@ typedef enum {
// The next few fields are extension types // The next few fields are extension types
// that is used by TVM API calls. // that is used by TVM API calls.
kNull = 4U, kNull = 4U,
kNodeHandle = 5U, kArrayHandle = 5U,
kStr = 6U, kTVMType = 6U,
kFuncHandle = 7U kNodeHandle = 7U,
kStr = 8U,
kFuncHandle = 9U
} TVMTypeCode; } TVMTypeCode;
/*! /*!
...@@ -77,7 +67,7 @@ typedef enum { ...@@ -77,7 +67,7 @@ typedef enum {
*/ */
typedef struct { typedef struct {
/*! \brief type code, in TVMTypeCode */ /*! \brief type code, in TVMTypeCode */
uint8_t type_code; uint8_t code;
/*! \brief number of bits of the type */ /*! \brief number of bits of the type */
uint8_t bits; uint8_t bits;
/*! \brief number of lanes, */ /*! \brief number of lanes, */
...@@ -85,6 +75,18 @@ typedef struct { ...@@ -85,6 +75,18 @@ typedef struct {
} TVMType; } TVMType;
/*! /*!
* \brief Union type of values
* being passed through API and function calls.
*/
typedef union {
int64_t v_int64;
double v_float64;
void* v_handle;
const char* v_str;
TVMType v_type;
} TVMValue;
/*!
* \brief The device type * \brief The device type
*/ */
typedef enum { typedef enum {
...@@ -133,11 +135,10 @@ typedef struct { ...@@ -133,11 +135,10 @@ typedef struct {
* can be NULL, which indicates the default one. * can be NULL, which indicates the default one.
*/ */
typedef void* TVMStreamHandle; typedef void* TVMStreamHandle;
/*! /*! \brief Handle to packed function handle. */
* \brief Pointer to function handle that points to
* a generated TVM function.
*/
typedef void* TVMFunctionHandle; typedef void* TVMFunctionHandle;
/*! \brief Handle to hold return value. */
typedef void* TVMRetValueHandle;
/*! \brief the array handle */ /*! \brief the array handle */
typedef TVMArray* TVMArrayHandle; typedef TVMArray* TVMArrayHandle;
...@@ -228,20 +229,45 @@ TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); ...@@ -228,20 +229,45 @@ TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
TVM_DLL int TVMFuncFree(TVMFunctionHandle func); TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
/*! /*!
* \brief Call a function whose parameters are all packed. * \brief Call a Packed TVM Function.
* *
* \param func node handle of the function. * \param func node handle of the function.
* \param args The arguments * \param arg_values The arguments
* \param type_codes The type codes of the arguments * \param type_codes The type codes of the arguments
* \param num_args Number of arguments. * \param num_args Number of arguments.
* *
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
*
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
* \note TVM calls always exchanges with type bits=64, lanes=1 * \note TVM calls always exchanges with type bits=64, lanes=1
*
* \note API calls always exchanges with type bits=64, lanes=1
* If API call returns container handles (e.g. FunctionHandle)
* these handles should be managed by the front-end.
* The front-end need to call free function (e.g. TVMFuncFree)
* to free these handles.
*/ */
TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args, TVMValue* arg_values,
int* type_codes, int* type_codes,
int num_args); int num_args,
TVMValue* ret_val,
int* ret_type_code);
/*!
* \brief Set the return value of TVMPackedCFunc.
*
* This function is called by TVMPackedCFunc to set the return value.
* When this function is not called, the function returns null by default.
*
* \param ret The return value handle, pass by ret in TVMPackedCFunc
* \param value The value to be returned.
* \param type_code The type of the value to be returned.
*/
TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue value,
int type_code);
/*! /*!
* \brief C type of packed function. * \brief C type of packed function.
...@@ -249,10 +275,17 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func, ...@@ -249,10 +275,17 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
* \param args The arguments * \param args The arguments
* \param type_codes The type codes of the arguments * \param type_codes The type codes of the arguments
* \param num_args Number of arguments. * \param num_args Number of arguments.
* \param ret The return value handle.
* \param resource_handle The handle additional resouce handle from fron-end. * \param resource_handle The handle additional resouce handle from fron-end.
*
* \sa TVMCFuncSetReturn
*/ */
typedef void (*TVMPackedCFunc)( typedef void (*TVMPackedCFunc)(
TVMValue* args, int* type_codes, int num_args, void* resource_handle); TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle);
/*! /*!
* \brief C callback to free the resource handle in C packed function. * \brief C callback to free the resource handle in C packed function.
...@@ -291,8 +324,20 @@ TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f); ...@@ -291,8 +324,20 @@ TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);
* *
* \param name The name of the function. * \param name The name of the function.
* \param out the result function pointer. * \param out the result function pointer.
*
* \note The function handle of global function is managed by TVM runtime,
* So TVMFuncFree is should not be called when it get deleted.
*/ */
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
/*!
* \brief List all the globally registered function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncListGlobalNames(int *out_size,
const char*** out_array);
} // TVM_EXTERN_C } // TVM_EXTERN_C
#endif // TVM_RUNTIME_C_RUNTIME_API_H_ #endif // TVM_RUNTIME_C_RUNTIME_API_H_
...@@ -10,5 +10,6 @@ ...@@ -10,5 +10,6 @@
#include "./expr.h" #include "./expr.h"
#include "./tensor.h" #include "./tensor.h"
#include "./operation.h" #include "./operation.h"
#include "./packed_func_ext.h"
#endif // TVM_TVM_H_ #endif // TVM_TVM_H_
# pylint: disable=redefined-builtin, wildcard-import # pylint: disable=redefined-builtin, wildcard-import
"""C++ backend related python scripts""" """C++ backend related python scripts"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import register_node from ._ctypes._node import register_node
from . import tensor from . import tensor
from . import expr from . import expr
......
...@@ -91,45 +91,3 @@ def c_array(ctype, values): ...@@ -91,45 +91,3 @@ def c_array(ctype, values):
Created ctypes array Created ctypes array
""" """
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
# coding: utf-8
# pylint: disable=invalid-name, protected-access
"""Symbolic configuration API."""
from __future__ import absolute_import
import ctypes
import sys
from numbers import Number, Integral
from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from ._types import TVMValue, TypeCode, TVMType
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
from ._node import NodeBase, SliceBase, convert_to_node
from ._ndarray import NDArrayBase
FunctionHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Parameters
----------
pyfunc : python function
The python function to be converted.
Returns
-------
tvmfunc: tvm.nd.Function
The converted tvm function.
"""
local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _):
""" ctypes function """
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = [C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)]
rv = local_pyfunc(*pyargs)
if rv is not None:
if isinstance(rv, tuple):
raise ValueError("PackedFunction can only support one reurn value")
temp_args = []
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
_ = temp_args
_ = rv
handle = FunctionHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
return Function(handle)
def _make_tvm_args(args, temp_args):
"""Pack arguments into c args tvm call accept"""
num_args = len(args)
values = (TVMValue * num_args)()
type_codes = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if arg is None:
values[i].v_handle = None
type_codes[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, NodeBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = TypeCode.INT
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
values[i].v_type = arg
type_codes[i] = TypeCode.TVM_TYPE
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, SliceBase)):
arg = convert_to_node(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
temp_args.append(arg)
elif isinstance(arg, Function):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
class Function(object):
"""A function object at runtime."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global=False):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
is_global : bool, optional
Whether it is global function
"""
self.handle = handle
self.is_global = is_global
def __del__(self):
if not self.is_global:
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return Function(handle, False)
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
def register_func(func_name, f=None):
"""Register global function
Parameters
----------
func_name : str or function
The function name
f : function
The function to be registered.
Returns
-------
fregister : function
Register function if f is not specified.
"""
if callable(func_name):
f = func_name
func_name = f.__name__
if not isinstance(func_name, str):
raise ValueError("expect string function name")
def register(myf):
"""internal register function"""
if not isinstance(myf, Function):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(
c_str(func_name), myf.handle))
if f:
register(f)
else:
return register
def get_global_func(name):
"""Get a global function by name
Parameters
----------
name : str
The name of the global function
Returns
-------
func : tvm.nd.Function
The function to be returned.
"""
handle = FunctionHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
return Function(handle, True)
def list_global_func_names():
"""Get list of global functions registered.
Returns
-------
names : list
List of global functions names.
"""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size),
ctypes.byref(plist)))
fnames = []
for i in range(size.value):
fnames.append(py_str(plist[i]))
return fnames
def _init_api_functions(root_namespace):
"""List and add all the functions to current module."""
module_obj = sys.modules["%s.api" % root_namespace]
module_internal = sys.modules["%s._api_internal" % root_namespace]
namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
}
for name in list_global_func_names():
fname = name
target_module = module_internal if name.startswith('_') else module_obj
for k, v in namespace_match.items():
if name.startswith(k):
fname = name[len(k):]
target_module = v
f = get_global_func(name)
setattr(target_module, fname, f)
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement # pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring # pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API.""" """Symbolic configuration API."""
from __future__ import absolute_import as _abs from __future__ import absolute_import
import ctypes import ctypes
from numbers import Number, Integral
import numpy as np import numpy as np
from .._base import _LIB from .._base import _LIB, check_call
from .._base import c_array, c_str, string_types from .._base import c_array, c_str
from .._base import check_call from ._types import TVMType, tvm_index_t
from ._types import TVMValue, TypeCode, TVMType
tvm_index_t = ctypes.c_uint32
class TVMContext(ctypes.Structure): class TVMContext(ctypes.Structure):
"""TVM context strucure.""" """TVM context strucure."""
...@@ -39,6 +36,19 @@ class TVMContext(ctypes.Structure): ...@@ -39,6 +36,19 @@ class TVMContext(ctypes.Structure):
return ret.value != 0 return ret.value != 0
class TVMArray(ctypes.Structure):
"""TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("shape", ctypes.POINTER(tvm_index_t)),
("strides", ctypes.POINTER(tvm_index_t)),
("ndim", tvm_index_t),
("dtype", TVMType),
("ctx", TVMContext)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
def cpu(dev_id=0): def cpu(dev_id=0):
"""Construct a CPU device """Construct a CPU device
...@@ -72,18 +82,6 @@ def opencl(dev_id=0): ...@@ -72,18 +82,6 @@ def opencl(dev_id=0):
return TVMContext(4, dev_id) return TVMContext(4, dev_id)
class TVMArray(ctypes.Structure):
"""TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("shape", ctypes.POINTER(tvm_index_t)),
("strides", ctypes.POINTER(tvm_index_t)),
("ndim", tvm_index_t),
("dtype", TVMType),
("ctx", TVMContext)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
def numpyasarray(np_data): def numpyasarray(np_data):
"""Return a TVMArray representation of a numpy array. """Return a TVMArray representation of a numpy array.
""" """
...@@ -102,7 +100,6 @@ def numpyasarray(np_data): ...@@ -102,7 +100,6 @@ def numpyasarray(np_data):
_ndarray_cls = None _ndarray_cls = None
_function_cls = None
def empty(shape, dtype="float32", ctx=cpu(0)): def empty(shape, dtype="float32", ctx=cpu(0)):
...@@ -275,51 +272,6 @@ class NDArrayBase(object): ...@@ -275,51 +272,6 @@ class NDArrayBase(object):
return target return target
class FunctionBase(object): def _init_ndarray_module(ndarray_class):
"""A function object at runtim."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : FunctionHandle
the handle to the underlying function.
"""
self.handle = handle
def __del__(self):
check_call(_LIB.TVMFuncFree(self.handle))
def __call__(self, *args):
num_args = len(args)
tvm_args = (TVMValue * num_args)()
tvm_type_code = (ctypes.c_int * num_args)()
for i, arg in enumerate(args):
if arg is None:
tvm_args[i].v_handle = None
tvm_type_code[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase):
tvm_args[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
tvm_type_code[i] = TypeCode.HANDLE
elif isinstance(arg, Integral):
tvm_args[i].v_int64 = arg
tvm_type_code[i] = TypeCode.INT
elif isinstance(arg, Number):
tvm_args[i].v_float64 = arg
tvm_type_code[i] = TypeCode.FLOAT
elif isinstance(arg, string_types):
tvm_args[i].v_str = c_str(arg)
tvm_type_code[i] = TypeCode.STR
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
check_call(_LIB.TVMFuncCall(
self.handle, tvm_args, tvm_type_code, ctypes.c_int(num_args)))
def _init_runtime_module(ndarray_class, function_class):
global _ndarray_cls global _ndarray_cls
global _function_cls
_ndarray_cls = ndarray_class _ndarray_cls = ndarray_class
_function_cls = function_class
# coding: utf-8
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
"""Symbolic configuration API."""
from __future__ import absolute_import
import ctypes
from numbers import Number, Integral
from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from .. import _api_internal
from ._types import TVMValue, TypeCode, RETURN_SWITCH
NodeHandle = ctypes.c_void_p
"""Maps node type to its constructor"""
NODE_TYPE = {
}
def _return_node(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
return NODE_TYPE.get(py_str(ret_val.v_str), NodeBase)(handle)
RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
class SliceBase(object):
"""base class of slice object"""
pass
class NodeBase(object):
"""Symbol is symbolic graph."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __repr__(self):
return _api_internal._format_str(self)
def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name):
ret_val = TVMValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)
def __hash__(self):
return _api_internal._raw_ptr(self)
def __eq__(self, other):
if not isinstance(other, NodeBase):
return False
return self.__hash__() == other.__hash__()
def __ne__(self, other):
return not self.__eq__(other)
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMNodeListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
return names
def __reduce__(self):
return (type(self), (None,), self.__getstate__())
def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _api_internal._save_json(self)}
else:
return {'handle': None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
other = _api_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _api_internal._const(value, dtype)
def convert_to_node(value):
"""Convert a python value to corresponding node type.
Parameters
----------
value : str
The value to be inspected.
Returns
-------
node : Node
The corresponding node value.
"""
if isinstance(value, NodeBase):
return value
elif isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._Array(*value)
elif isinstance(value, dict):
vlist = []
for it in value.items():
if not isinstance(it[0], NodeBase):
raise ValueError("key of map must already been a container type")
vlist.append(it[0])
vlist.append(convert_to_node(it[1]))
return _api_internal._Map(*vlist)
elif isinstance(value, SliceBase):
return value.tensor(*value.indices)
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
def register_node(type_key=None):
"""register node type
Parameters
----------
type_key : str or cls
The type key of the node
"""
if isinstance(type_key, str):
def register(cls):
"""internal register function"""
NODE_TYPE[type_key] = cls
return cls
return register
else:
cls = type_key
NODE_TYPE[cls.__name__] = cls
return cls
...@@ -4,13 +4,9 @@ from __future__ import absolute_import as _abs ...@@ -4,13 +4,9 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
import numpy as np import numpy as np
from .._base import py_str
class TVMValue(ctypes.Union): tvm_index_t = ctypes.c_uint32
"""TVMValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)]
class TypeCode(object): class TypeCode(object):
"""Type code used in API calls""" """Type code used in API calls"""
...@@ -19,9 +15,11 @@ class TypeCode(object): ...@@ -19,9 +15,11 @@ class TypeCode(object):
FLOAT = 2 FLOAT = 2
HANDLE = 3 HANDLE = 3
NULL = 4 NULL = 4
NODE_HANDLE = 5 ARRAY_HANDLE = 5
STR = 6 TVM_TYPE = 6
FUNC_HANDLE = 7 NODE_HANDLE = 7
STR = 8
FUNC_HANDLE = 9
def _api_type(code): def _api_type(code):
"""create a type accepted by API""" """create a type accepted by API"""
...@@ -40,13 +38,13 @@ class TVMType(ctypes.Structure): ...@@ -40,13 +38,13 @@ class TVMType(ctypes.Structure):
CODE2STR = { CODE2STR = {
0 : 'int', 0 : 'int',
1 : 'uint', 1 : 'uint',
2 : 'float' 2 : 'float',
4 : 'handle'
} }
def __init__(self, type_str, lanes=1): def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__() super(TVMType, self).__init__()
if isinstance(type_str, np.dtype): if isinstance(type_str, np.dtype):
type_str = str(type_str) type_str = str(type_str)
if type_str.startswith("int"): if type_str.startswith("int"):
self.type_code = 0 self.type_code = 0
bits = int(type_str[3:]) bits = int(type_str[3:])
...@@ -56,6 +54,9 @@ class TVMType(ctypes.Structure): ...@@ -56,6 +54,9 @@ class TVMType(ctypes.Structure):
elif type_str.startswith("float"): elif type_str.startswith("float"):
self.type_code = 2 self.type_code = 2
bits = int(type_str[5:]) bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
else: else:
raise ValueError("Donot know how to handle type %s" % type_str) raise ValueError("Donot know how to handle type %s" % type_str)
...@@ -71,15 +72,61 @@ class TVMType(ctypes.Structure): ...@@ -71,15 +72,61 @@ class TVMType(ctypes.Structure):
x += "x%d" % self.lanes x += "x%d" % self.lanes
return x return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
def __ne__(self, other):
return not self.__eq__(other)
class TVMValue(ctypes.Union):
"""TVMValue in C API"""
_fields_ = [("v_int64", ctypes.c_int64),
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p),
("v_type", TVMType)]
TVMPackedCFunc = ctypes.CFUNCTYPE( TVMPackedCFunc = ctypes.CFUNCTYPE(
None, None,
ctypes.POINTER(TVMValue), ctypes.POINTER(TVMValue),
ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int),
ctypes.c_int, ctypes.c_int,
ctypes.c_void_p,
ctypes.c_void_p) ctypes.c_void_p)
TVMCFuncFinalizer = ctypes.CFUNCTYPE( TVMCFuncFinalizer = ctypes.CFUNCTYPE(
None, None,
ctypes.c_void_p) ctypes.c_void_p)
def _return_handle(x):
"""return handle"""
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
return handle
RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.TVM_TYPE: lambda x: x.v_type,
TypeCode.STR: lambda x: py_str(x.v_str)
}
C_TO_PY_ARG_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.TVM_TYPE: lambda x: x.v_type,
TypeCode.STR: lambda x: py_str(x.v_str)
}
...@@ -2,16 +2,23 @@ ...@@ -2,16 +2,23 @@
# pylint: disable=redefined-builtin, undefined-variable, unused-import # pylint: disable=redefined-builtin, undefined-variable, unused-import
"""Functions defined in TVM.""" """Functions defined in TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from numbers import Integral as _Integral from numbers import Integral as _Integral
from ._ctypes._api import _init_api_module, convert, register_func, get_global_func
from ._ctypes._types import TVMType
from ._ctypes._node import register_node, NodeBase
from ._ctypes._node import convert_to_node as _convert_to_node
from ._ctypes._function import Function
from ._ctypes._function import _init_api_functions, register_func, get_global_func
from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
from . import collections as _collections from . import collections as _collections
int32 = "int32" int32 = TVMType("int32")
float32 = "float32" float32 = TVMType("float32")
handle = "handle" handle = TVMType("handle")
def const(value, dtype=None): def const(value, dtype=None):
"""construct a constant""" """construct a constant"""
...@@ -266,4 +273,25 @@ def Schedule(ops): ...@@ -266,4 +273,25 @@ def Schedule(ops):
return _api_internal._Schedule(ops) return _api_internal._Schedule(ops)
_init_api_module("tvm") def convert(value):
"""Convert value to TVM node or function.
Parameters
----------
value : python value
Returns
-------
tvm_val : Node or function
Converted value in TVM
"""
if isinstance(value, (Function, NodeBase)):
return value
if callable(value):
return _convert_tvm_func(value)
else:
return _convert_to_node(value)
_init_api_functions("tvm")
# pylint: disable=protected-access, no-member # pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
from . import _api_internal from . import _api_internal
from . import expr as _expr from . import expr as _expr
......
# pylint: disable=protected-access, no-member, missing-docstring # pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
from . import make as _make from . import make as _make
class ExprOp(object): class ExprOp(object):
......
...@@ -6,11 +6,11 @@ This is a simplified runtime API for quick testing and proptyping. ...@@ -6,11 +6,11 @@ This is a simplified runtime API for quick testing and proptyping.
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import numpy as _np import numpy as _np
from ._ctypes._runtime_api import TVMContext, TVMType, NDArrayBase, FunctionBase from ._ctypes._ndarray import TVMContext, TVMType, NDArrayBase
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync from ._ctypes._ndarray import cpu, gpu, opencl, empty, sync
from ._ctypes._runtime_api import _init_runtime_module from ._ctypes._ndarray import _init_ndarray_module
from ._ctypes._runtime_api import init_opencl from ._ctypes._ndarray import init_opencl
from ._ctypes._function import Function
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
...@@ -26,11 +26,6 @@ class NDArray(NDArrayBase): ...@@ -26,11 +26,6 @@ class NDArray(NDArrayBase):
pass pass
class Function(FunctionBase):
"""Function class that can executed a generated code."""
pass
def array(arr, ctx=cpu(0)): def array(arr, ctx=cpu(0)):
"""Create an array from source arr. """Create an array from source arr.
...@@ -54,4 +49,4 @@ def array(arr, ctx=cpu(0)): ...@@ -54,4 +49,4 @@ def array(arr, ctx=cpu(0)):
return ret return ret
_init_runtime_module(NDArray, Function) _init_ndarray_module(NDArray)
# pylint: disable=protected-access, no-member # pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL.""" """Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
from . import _api_internal from . import _api_internal
from . import tensor as _tensor from . import tensor as _tensor
......
# pylint: disable=protected-access, no-member, missing-docstring # pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._node import NodeBase, register_node
class Stmt(NodeBase): class Stmt(NodeBase):
pass pass
......
# pylint: disable=protected-access, no-member, invalid-name # pylint: disable=protected-access, no-member, invalid-name
"""Tensor related abstractions""" """Tensor related abstractions"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, SliceBase, register_node, convert from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node
from . import collections as _collections from . import collections as _collections
from . import _api_internal from . import _api_internal
from . import make as _make from . import make as _make
...@@ -26,7 +26,7 @@ class Tensor(NodeBase): ...@@ -26,7 +26,7 @@ class Tensor(NodeBase):
ndim = self.ndim ndim = self.ndim
if len(indices) != ndim: if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim) raise ValueError("Need to provide %d index in tensor slice" % ndim)
indices = convert(indices) indices = convert_to_node(indices)
args = [] args = []
for x in indices: for x in indices:
if isinstance(x, _collections.IterVar): if isinstance(x, _collections.IterVar):
......
/*!
* Copyright (c) 2017 by Contributors
* Implementation of basic API functions
* \file api_base.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
namespace tvm {
TVM_REGISTER_API(_format_str)
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle);
std::ostringstream os;
os << args[0].operator NodeRef();
*ret = os.str();
});
TVM_REGISTER_API(_raw_ptr)
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle);
*ret = reinterpret_cast<int64_t>(
args[0].node_sptr().get());
});
TVM_REGISTER_API(_save_json)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SaveJSON(args[0]);
});
TVM_REGISTER_API(_load_json)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = NodeRef(LoadJSON_(args[0]));
});
} // namespace tvm
...@@ -6,55 +6,51 @@ ...@@ -6,55 +6,51 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <tvm/api_registry.h>
#include "./c_api_registry.h"
#include "../codegen/codegen_c.h" #include "../codegen/codegen_c.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_codegen_CompileToC) TVM_REGISTER_API(_codegen_CompileToC)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CodeGenC().Compile(args.at(0), args.at(1)); *ret = CodeGenC().Compile(args[0], args[1]);
}); });
TVM_REGISTER_API(_codegen_MakeAPI) TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = MakeAPI( *ret = MakeAPI(
args.at(0), args.at(1), args.at(2), args.at(3)); args[0], args[1], args[2], args[3]);
}); });
TVM_REGISTER_API(_codegen_SplitHostDevice) TVM_REGISTER_API(_codegen_SplitHostDevice)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SplitHostDevice(args.at(0)); *ret = SplitHostDevice(args[0]);
}); });
// generate a dummy packed function for testing // generate a dummy packed function for testing
void DummyHelloFunction(const TVMValue* args, const int* type_code, int num_args) { void DummyHelloFunction(TVMArgs args, TVMRetValue* rv) {
LOG(INFO) << num_args << " arguments"; LOG(INFO) << args.size() << " arguments";
for (int i = 0; i < num_args; ++i) { for (int i = 0; i < args.size(); ++i) {
switch (type_code[i]) { switch (args.type_codes[i]) {
case kNull: LOG(INFO) << i << ":nullptr"; break; case kNull: LOG(INFO) << i << ":nullptr"; break;
case kFloat: LOG(INFO) << i << ": double=" << args[i].v_float64; break; case kFloat: LOG(INFO) << i << ": double=" << args.values[i].v_float64; break;
case kInt: LOG(INFO) << i << ": long=" << args[i].v_int64; break; case kInt: LOG(INFO) << i << ": long=" << args.values[i].v_int64; break;
case kHandle: LOG(INFO) << i << ": handle=" << args[i].v_handle; break; case kHandle: LOG(INFO) << i << ": handle=" << args.values[i].v_handle; break;
default: LOG(FATAL) << "unhandled type " << TVMTypeCode2Str(type_code[i]); case kArrayHandle: LOG(INFO) << i << ": array_handle=" << args.values[i].v_handle; break;
default: LOG(FATAL) << "unhandled type " << runtime::TypeCode2Str(args.type_codes[i]);
} }
} }
} }
TVM_REGISTER_API(_codegen_DummyHelloFunction) TVM_REGISTER_API(_codegen_DummyHelloFunction)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = runtime::PackedFunc(DummyHelloFunction); *ret = runtime::PackedFunc(DummyHelloFunction);
}); });
TVM_REGISTER_API(_codegen_BuildStackVM) TVM_REGISTER_API(_codegen_BuildStackVM)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildStackVM(args.at(0)); *ret = BuildStackVM(args[0]);
}); });
} // namespace codegen } // namespace codegen
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build * Implementation of API functions related to IR build
* \file c_api_ir.cc * \file api_ir.cc
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <ir/IROperator.h> #include <ir/IROperator.h>
#include "./c_api_registry.h" #include <tvm/api_registry.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_Var) TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Variable::make(args.at(1), args.at(0)); *ret = Variable::make(args[1], args[0]);
}); });
TVM_REGISTER_API(_make_For) TVM_REGISTER_API(_make_For)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = For::make(args.at(0), *ret = For::make(args[0],
args.at(1), args[1],
args.at(2), args[2],
static_cast<ForType>(args.at(3).operator int()), static_cast<ForType>(args[3].operator int()),
static_cast<Halide::DeviceAPI>(args.at(4).operator int()), static_cast<Halide::DeviceAPI>(args[4].operator int()),
args.at(5)); args[5]);
}); });
TVM_REGISTER_API(_make_Realize) TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Realize::make(args.at(0), *ret = Realize::make(args[0],
args.at(1), args[1],
args.at(2), args[2],
args.at(3), args[3],
args.at(4), args[4],
args.at(5)); args[5]);
}); });
TVM_REGISTER_API(_make_Call) TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Call::make(args.at(0), *ret = Call::make(args[0],
args.at(1), args[1],
args.at(2), args[2],
static_cast<Call::CallType>(args.at(3).operator int()), static_cast<Call::CallType>(args[3].operator int()),
args.at(4), args[4],
args.at(5)); args[5]);
}); });
TVM_REGISTER_API(_make_Allocate) TVM_REGISTER_API(_make_Allocate)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Allocate::make(args.at(0), *ret = Allocate::make(args[0],
args.at(1), args[1],
args.at(2), args[2],
args.at(3), args[3],
args.at(4)); args[4]);
}); });
// make from two arguments // make from two arguments
#define REGISTER_MAKE1(Node) \ #define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args.at(0)); \ *ret = Node::make(args[0]); \
}) \ }) \
#define REGISTER_MAKE2(Node) \ #define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1)); \ *ret = Node::make(args[0], args[1]); \
}) \ }) \
#define REGISTER_MAKE3(Node) \ #define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \ *ret = Node::make(args[0], args[1], args[2]); \
}) \ }) \
#define REGISTER_MAKE4(Node) \ #define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \ *ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \ }) \
#define REGISTER_MAKE_BINARY_OP(Node) \ #define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args.at(0), b = args.at(1); \ Expr a = args[0], b = args[1]; \
match_types(a, b); \ match_types(a, b); \
*ret = Node::make(a, b); \ *ret = Node::make(a, b); \
}) \ })
.add_argument("lhs", "Expr", "left operand") \
.add_argument("rhs", "Expr", "right operand")
REGISTER_MAKE3(Reduce); REGISTER_MAKE3(Reduce);
REGISTER_MAKE4(AttrStmt); REGISTER_MAKE4(AttrStmt);
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to Higher DSL build.
* \file api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
#include <tvm/api_registry.h>
namespace tvm {
TVM_REGISTER_API(_const)
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args[0].type_code() == kInt) {
*ret = make_const(args[1], args[0].operator int64_t());
} else if (args[0].type_code() == kFloat) {
*ret = make_const(args[1], args[0].operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
});
TVM_REGISTER_API(_str)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ir::StringImm::make(args[0]);
});
TVM_REGISTER_API(_Array)
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<std::shared_ptr<Node> > data;
for (int i = 0; i < args.size(); ++i) {
data.push_back(args[i].node_sptr());
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
*ret = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](TVMArgs args, TVMRetValue* ret) {
int64_t i = args[1];
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
*ret = n->data[i];
});
TVM_REGISTER_API(_ArraySize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_Map)
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kNodeHandle)
<< "need content of array to be NodeBase";
CHECK(args[i + 1].type_code() == kNodeHandle)
<< "need content of array to be NodeBase";
data.emplace(std::make_pair(args[i].node_sptr(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
*ret = node;
});
TVM_REGISTER_API(_MapSize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
});
TVM_REGISTER_API(_MapGetItem)
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto it = n->data.find(args[1].node_sptr());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
});
TVM_REGISTER_API(_MapCount)
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(
n->data.count(args[1].node_sptr()));
});
TVM_REGISTER_API(_MapItems)
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
});
TVM_REGISTER_API(Range)
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = Range(0, args[0]);
} else {
*ret = Range(args[0], args[1]);
}
});
TVM_REGISTER_API(_Buffer)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BufferNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
TVM_REGISTER_API(_Tensor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
args[1],
args[2],
args[3]);
});
TVM_REGISTER_API(_TensorEqual)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor();
});
TVM_REGISTER_API(_TensorHash)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(
std::hash<Tensor>()(args[0].operator Tensor()));
});
TVM_REGISTER_API(_Placeholder)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Placeholder(args[0],
args[1],
args[2]);
});
TVM_REGISTER_API(_ComputeOp)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ComputeOpNode::make(args[0],
args[1],
args[2]);
});
TVM_REGISTER_API(_OpGetOutput)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Operation().output(
args[1].operator int64_t());
});
TVM_REGISTER_API(_IterVar)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IterVar(args[0], args[1], args[2]);
});
TVM_REGISTER_API(_Schedule)
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Schedule(args[0].operator Array<Operation>());
});
TVM_REGISTER_API(_StageSetScope)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.set_scope(args[1]);
});
TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
args[0].operator Stage()
.split(args[1], &outer, &inner, args[2]);
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageSplitByOuter)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar inner;
args[0].operator Stage()
.split(args[1], args[2], &inner, args[3]);
*ret = inner;
});
TVM_REGISTER_API(_StageFuse)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar fused;
args[0].operator Stage()
.split(args[1], args[2], &fused);
*ret = fused;
});
TVM_REGISTER_API(_StageComputeAt)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.compute_at(args[1], args[2]);
});
TVM_REGISTER_API(_StageComputeInline)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.compute_inline();
});
TVM_REGISTER_API(_StageComputeRoot)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.compute_root();
});
TVM_REGISTER_API(_StageReorder)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.reorder(args[1]);
});
TVM_REGISTER_API(_StageTile)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args[0].operator Stage()
.tile(args[1], args[2], &x_outer, &y_outer,
&x_inner, &y_inner, args[3], args[4]);
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
} // namespace tvm
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2017 by Contributors
* Exposre of pass functions. * Exposre of pass functions.
* \file c_api_pass.cc * \file api_pass.cc
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./c_api_registry.h" #include <tvm/api_registry.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_pass_Simplify) TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) { if (args[0].IsNodeType<Stmt>()) {
*ret = Simplify(args.at(0).operator Stmt()); *ret = Simplify(args[0].operator Stmt());
} else { } else {
*ret = Simplify(args.at(0).operator Expr()); *ret = Simplify(args[0].operator Expr());
} }
}); });
TVM_REGISTER_API(_pass_Equal) TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) { if (args[0].IsNodeType<Stmt>()) {
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
} else { } else {
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
} }
}); });
// make from two arguments // make from two arguments
#define REGISTER_PASS1(PassName) \ #define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args.at(0)); \ *ret = PassName(args[0]); \
}) \ }) \
#define REGISTER_PASS2(PassName) \ #define REGISTER_PASS2(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args.at(0), args.at(1)); \ *ret = PassName(args[0], args[1]); \
}) \ }) \
#define REGISTER_PASS4(PassName) \ #define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \ *ret = PassName(args[0], args[1], args[2], args[3]); \
}) \ }) \
REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(ConvertSSA);
......
/*!
* Copyright (c) 2017 by Contributors
* \file api_registry.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
#include <memory>
namespace tvm {
struct APIManager {
std::unordered_map<std::string, std::unique_ptr<APIRegistry> > fmap;
static APIManager* Global() {
static APIManager inst;
return &inst;
}
};
APIRegistry& APIRegistry::__REGISTER__(const std::string& name) { // NOLINT(*)
APIManager* m = APIManager::Global();
CHECK(!m->fmap.count(name))
<< "API function " << name << " has already been registered";
std::unique_ptr<APIRegistry> p(new APIRegistry());
p->name_ = name;
m->fmap[name] = std::move(p);
return *(m->fmap[name]);
}
APIRegistry& APIRegistry::set_body(PackedFunc f) { // NOLINT(*)
PackedFunc::RegisterGlobal(name_, f);
return *this;
}
} // namespace tvm
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2017 by Contributors
* Implementation of API functions related to schedule pass. * Implementation of API functions related to schedule pass.
* \file c_api_lang.cc * \file api_schedule.cc
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/schedule_pass.h> #include <tvm/schedule_pass.h>
#include "./c_api_registry.h" #include <tvm/api_registry.h>
#include "../schedule/graph.h" #include "../schedule/graph.h"
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
#define REGISTER_SCHEDULE_PASS1(PassName) \ #define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \ TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args.at(0)); \ *ret = PassName(args[0]); \
}) \ }) \
#define REGISTER_SCHEDULE_PASS2(PassName) \ #define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \ TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args.at(0), args.at(1)); \ *ret = PassName(args[0], args[1]); \
}) \ }) \
......
...@@ -3,9 +3,16 @@ ...@@ -3,9 +3,16 @@
* Implementation of C API * Implementation of C API
* \file c_api.cc * \file c_api.cc
*/ */
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/c_api.h> #include <tvm/c_api.h>
#include "./c_api_common.h" #include <tvm/api_registry.h>
#include "./c_api_registry.h" #include <vector>
#include <string>
#include <exception>
#include "../runtime/runtime_base.h"
/*! \brief entry to to easily hold returning information */ /*! \brief entry to to easily hold returning information */
struct TVMAPIThreadLocalEntry { struct TVMAPIThreadLocalEntry {
...@@ -13,16 +20,8 @@ struct TVMAPIThreadLocalEntry { ...@@ -13,16 +20,8 @@ struct TVMAPIThreadLocalEntry {
std::vector<std::string> ret_vec_str; std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */ /*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp; std::vector<const char *> ret_vec_charp;
/*! \brief argument stack */ /*! \brief result holder for retruning string */
std::vector<tvm::APIVariantValue> arg_stack; std::string ret_str;
/*! \brief return value */
tvm::APIVariantValue ret_value;
// clear calling stack
inline void Clear() {
arg_stack.clear();
ret_value.sptr.reset();
}
inline void SetReturn(TVMValue* ret_val, int* ret_type_code);
}; };
using namespace tvm; using namespace tvm;
...@@ -34,7 +33,7 @@ using TVMAPINode = std::shared_ptr<Node>; ...@@ -34,7 +33,7 @@ using TVMAPINode = std::shared_ptr<Node>;
struct APIAttrGetter : public AttrVisitor { struct APIAttrGetter : public AttrVisitor {
std::string skey; std::string skey;
APIVariantValue* ret; TVMRetValue* ret;
bool found_node_ref{false}; bool found_node_ref{false};
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
...@@ -97,95 +96,6 @@ struct APIAttrDir : public AttrVisitor { ...@@ -97,95 +96,6 @@ struct APIAttrDir : public AttrVisitor {
} }
}; };
int TVMListAPIFuncNames(int *out_size,
const char*** out_array) {
API_BEGIN();
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<APIFuncReg>::ListAllNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END();
}
int TVMGetAPIFuncHandle(const char* fname,
APIFuncHandle* out) {
API_BEGIN();
const APIFuncReg* reg = dmlc::Registry<APIFuncReg>::Find(fname);
CHECK(reg != nullptr) << "cannot find function " << fname;
*out = (APIFuncHandle)reg;
API_END();
}
int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
const auto *op = static_cast<const APIFuncReg *>(handle);
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
*real_name = op->name.c_str();
*description = op->description.c_str();
*num_doc_args = static_cast<int>(op->arguments.size());
if (return_type) *return_type = nullptr;
ret->ret_vec_charp.clear();
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
}
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
API_END();
}
int TVMAPIPushStack(TVMValue arg,
int type_code) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->arg_stack.resize(ret->arg_stack.size() + 1);
APIVariantValue& v = ret->arg_stack.back();
v.type_code = type_code;
switch (type_code) {
case kInt: case kUInt: case kFloat: case kNull: {
v.v_union = arg; break;
}
case kStr: {
v.str = arg.v_str; break;
}
case kNodeHandle: {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); break;
}
default: LOG(FATAL) << "TVM API cannot take type " << TVMTypeCode2Str(type_code);
}
API_END_HANDLE_ERROR(ret->Clear());
}
int TVMAPIFuncCall(APIFuncHandle handle,
TVMValue* ret_val,
int* ret_type_code) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
const auto *op = static_cast<const APIFuncReg *>(handle);
op->body(ret->arg_stack, &(ret->ret_value));
ret->SetReturn(ret_val, ret_type_code);
ret->arg_stack.clear();
API_END_HANDLE_ERROR(ret->Clear());
}
int TVMNodeFree(NodeHandle handle) { int TVMNodeFree(NodeHandle handle) {
API_BEGIN(); API_BEGIN();
...@@ -198,12 +108,11 @@ int TVMNodeGetAttr(NodeHandle handle, ...@@ -198,12 +108,11 @@ int TVMNodeGetAttr(NodeHandle handle,
TVMValue* ret_val, TVMValue* ret_val,
int* ret_type_code, int* ret_type_code,
int* ret_success) { int* ret_success) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
ret->ret_value.type_code = kNull; TVMRetValue rv;
APIAttrGetter getter; APIAttrGetter getter;
getter.skey = key; getter.skey = key;
getter.ret = &(ret->ret_value); getter.ret = &rv;
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
if (getter.skey == "type_key") { if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key(); ret_val->v_str = (*tnode)->type_key();
...@@ -211,15 +120,17 @@ int TVMNodeGetAttr(NodeHandle handle, ...@@ -211,15 +120,17 @@ int TVMNodeGetAttr(NodeHandle handle,
*ret_success = 1; *ret_success = 1;
} else { } else {
(*tnode)->VisitAttrs(&getter); (*tnode)->VisitAttrs(&getter);
if (ret->ret_value.type_code != kNull) { *ret_success = getter.found_node_ref || rv.type_code() != kNull;
ret->SetReturn(ret_val, ret_type_code); if (rv.type_code() == kStr) {
*ret_success = 1; TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
} else { } else {
*ret_success = getter.found_node_ref ? 1 : 0; rv.MoveToCHost(ret_val, ret_type_code);
*ret_type_code = kNull;
} }
} }
API_END_HANDLE_ERROR(ret->Clear()); API_END();
} }
int TVMNodeListAttrNames(NodeHandle handle, int TVMNodeListAttrNames(NodeHandle handle,
...@@ -240,21 +151,3 @@ int TVMNodeListAttrNames(NodeHandle handle, ...@@ -240,21 +151,3 @@ int TVMNodeListAttrNames(NodeHandle handle,
*out_size = static_cast<int>(ret->ret_vec_str.size()); *out_size = static_cast<int>(ret->ret_vec_str.size());
API_END(); API_END();
} }
inline void TVMAPIThreadLocalEntry::SetReturn(TVMValue* ret_val,
int* ret_type_code) {
APIVariantValue& rv = ret_value;
*ret_type_code = rv.type_code;
if (rv.type_code == kNodeHandle) {
if (rv.sptr.get() != nullptr) {
ret_val->v_handle = new TVMAPINode(std::move(rv.sptr));
} else {
ret_val->v_handle = nullptr;
}
} else if (rv.type_code == kFuncHandle) {
ret_val->v_handle = new runtime::PackedFunc::FType(std::move(rv.func));
} else {
*ret_val = rv.v_union;
}
}
...@@ -42,5 +42,78 @@ inline Type String2Type(std::string s) { ...@@ -42,5 +42,78 @@ inline Type String2Type(std::string s) {
return Type(code, bits, lanes); return Type(code, bits, lanes);
} }
inline const char* TVMTypeCode2Str(int type_code) {
switch (type_code) {
case kInt: return "int";
case kFloat: return "float";
case kStr: return "str";
case kHandle: return "Handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
template<typename T>
struct NodeTypeChecker {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}
} // namespace tvm } // namespace tvm
#endif // TVM_BASE_COMMON_H_ #endif // TVM_BASE_COMMON_H_
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_common.h
* \brief Common fields of all C APIs
*/
#ifndef TVM_C_API_C_API_COMMON_H_
#define TVM_C_API_C_API_COMMON_H_
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/c_api.h>
#include <vector>
#include <string>
#include <exception>
#include "./c_api_registry.h"
#include "../runtime/runtime_base.h"
#endif // TVM_C_API_C_API_COMMON_H_
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions
* \file c_api_impl.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include "./c_api_registry.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::APIFuncReg);
} // namespace dmlc
namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
std::ostringstream os;
os << args.at(0).operator NodeRef();
*ret = os.str();
})
.add_argument("expr", "Node", "expression to be printed");
TVM_REGISTER_API(_raw_ptr)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
*ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
})
.add_argument("src", "NodeBase", "the node base");
TVM_REGISTER_API(_save_json)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = SaveJSON(args.at(0));
})
.add_argument("src", "json_str", "the node ");
TVM_REGISTER_API(_load_json)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = NodeRef(LoadJSON_(args.at(0)));
})
.add_argument("src", "NodeBase", "the node");
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to Higher DSL build.
* \file c_api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
#include "./c_api_registry.h"
namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.at(0).type_code == kInt) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_code == kFloat) {
*ret = make_const(args.at(1), args.at(0).operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
})
.add_argument("src", "Number", "source number")
.add_argument("dtype", "str", "data type");
TVM_REGISTER_API(_str)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ir::StringImm::make(args.at(0));
});
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_code = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_Map)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK_EQ(args.size() % 2, 0U);
MapNode::ContainerType data;
for (size_t i = 0; i < args.size(); i += 2) {
CHECK(args.at(i).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
CHECK(args.at(i + 1).type_code == kNodeHandle)
<< "need content of array to be NodeBase";
data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr));
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
ret->type_code = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_MapSize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
});
TVM_REGISTER_API(_MapGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto it = n->data.find(args.at(1).sptr);
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
ret->sptr = (*it).second;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(_MapCount)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
CHECK(args.at(1).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.count(args.at(1).sptr));
});
TVM_REGISTER_API(_MapItems)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_code == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
ret->sptr = rkvs;
ret->type_code = kNodeHandle;
});
TVM_REGISTER_API(Range)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 1) {
*ret = Range(0, args.at(0));
} else {
*ret = Range(args.at(0), args.at(1));
}
})
.describe("create a domain range")
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "extent of the range");
TVM_REGISTER_API(_Buffer)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = BufferNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
});
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3));
});
TVM_REGISTER_API(_TensorEqual)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = args.at(0).operator Tensor() == args.at(1).operator Tensor();
});
TVM_REGISTER_API(_TensorHash)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = static_cast<int64_t>(
std::hash<Tensor>()(args.at(0).operator Tensor()));
});
TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_OpGetOutput)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = args.at(0).operator Operation().output(
args.at(1).operator int64_t());
});
TVM_REGISTER_API(_IterVar)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = IterVar(args.at(0), args.at(1), args.at(2));
});
TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0).operator Array<Operation>());
});
TVM_REGISTER_API(_StageSetScope)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Stage()
.set_scope(args.at(1));
});
TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar outer, inner;
args.at(0).operator Stage()
.split(args.at(1), &outer, &inner, args.at(2));
*ret = Array<IterVar>({outer, inner});
});
TVM_REGISTER_API(_StageSplitByOuter)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar inner;
args.at(0).operator Stage()
.split(args.at(1), args.at(2), &inner, args.at(3));
*ret = inner;
});
TVM_REGISTER_API(_StageFuse)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar fused;
args.at(0).operator Stage()
.split(args.at(1), args.at(2), &fused);
*ret = fused;
});
TVM_REGISTER_API(_StageComputeAt)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Stage()
.compute_at(args.at(1), args.at(2));
});
TVM_REGISTER_API(_StageComputeInline)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Stage()
.compute_inline();
});
TVM_REGISTER_API(_StageComputeRoot)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Stage()
.compute_root();
});
TVM_REGISTER_API(_StageReorder)
.set_body([](const ArgStack& args, RetValue *ret) {
args.at(0).operator Stage()
.reorder(args.at(1));
});
TVM_REGISTER_API(_StageTile)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args.at(0).operator Stage()
.tile(args.at(1), args.at(2), &x_outer, &y_outer,
&x_inner, &y_inner, args.at(3), args.at(4));
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_registry.h
* \brief Quick registry for C API.
*/
#ifndef TVM_C_API_C_API_REGISTRY_H_
#define TVM_C_API_C_API_REGISTRY_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/c_api.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <limits>
#include <string>
#include <vector>
#include "../base/common.h"
namespace tvm {
inline const char* TVMTypeCode2Str(int type_code) {
switch (type_code) {
case kInt: return "int";
case kFloat: return "float";
case kStr: return "str";
case kHandle: return "Handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
template<typename T>
struct NodeTypeChecker {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}
/*! \brief Variant container for API calls */
class APIVariantValue {
public:
/*! \brief the type id */
int type_code{kNull};
/*! \brief shared pointer container */
std::shared_ptr<Node> sptr;
/*! \brief string container */
std::string str;
/*! \brief the variant holder */
TVMValue v_union;
/*! \brief std::function */
runtime::PackedFunc::FType func;
// constructor
APIVariantValue() {
}
// clear value
inline void Clear() {
}
// assign op
inline APIVariantValue& operator=(double value) {
type_code = kFloat;
v_union.v_float64 = value;
return *this;
}
inline APIVariantValue& operator=(std::nullptr_t value) {
type_code = kHandle;
v_union.v_handle = value;
return *this;
}
inline APIVariantValue& operator=(int64_t value) {
type_code = kInt;
v_union.v_int64 = value;
return *this;
}
inline APIVariantValue& operator=(bool value) {
type_code = kInt;
v_union.v_int64 = value;
return *this;
}
inline APIVariantValue& operator=(std::string value) {
type_code = kStr;
str = std::move(value);
v_union.v_str = str.c_str();
return *this;
}
inline APIVariantValue& operator=(const NodeRef& ref) {
if (ref.node_.get() == nullptr) {
type_code = kNull;
} else {
type_code = kNodeHandle;
this->sptr = ref.node_;
}
return *this;
}
inline APIVariantValue& operator=(const runtime::PackedFunc& f) {
type_code = kFuncHandle;
this->func = f.body();
return *this;
}
inline APIVariantValue& operator=(const Type& value) {
return operator=(Type2String(value));
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
if (type_code == kNull) return T();
CHECK_EQ(type_code, kNodeHandle);
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>();
return T(sptr);
}
inline operator Expr() const {
if (type_code == kNull) {
return Expr();
}
if (type_code == kInt) return Expr(operator int());
if (type_code == kFloat) {
return Expr(static_cast<float>(operator double()));
}
CHECK_EQ(type_code, kNodeHandle);
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "did not pass in Expr in a place need Expr";
return Expr(sptr);
}
}
inline operator double() const {
CHECK_EQ(type_code, kFloat);
return v_union.v_float64;
}
inline operator int64_t() const {
CHECK_EQ(type_code, kInt);
return v_union.v_int64;
}
inline operator uint64_t() const {
CHECK_EQ(type_code, kInt);
return v_union.v_int64;
}
inline operator int() const {
CHECK_EQ(type_code, kInt);
CHECK_LE(v_union.v_int64,
std::numeric_limits<int>::max());
return v_union.v_int64;
}
inline operator bool() const {
CHECK_EQ(type_code, kInt)
<< "expect boolean(int) but get "
<< TVMTypeCode2Str(type_code);
return v_union.v_int64 != 0;
}
inline operator std::string() const {
CHECK_EQ(type_code, kStr)
<< "expect Str but get "
<< TVMTypeCode2Str(type_code);
return str;
}
inline operator Type() const {
return String2Type(operator std::string());
}
inline operator runtime::PackedFunc() const {
CHECK_EQ(type_code, kFuncHandle);
return runtime::PackedFunc(func);
}
};
// common defintiion of API function.
using APIFunc = std::function<
void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>;
/*!
* \brief Registry entry for DataIterator factory functions.
*/
struct APIFuncReg
: public dmlc::FunctionRegEntryBase<APIFuncReg,
APIFunc> {
};
#define TVM_REGISTER_API(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::APIFuncReg, APIFuncReg, TypeName) \
} // namespace tvm
#endif // TVM_C_API_C_API_REGISTRY_H_
...@@ -12,19 +12,22 @@ using namespace ir; ...@@ -12,19 +12,22 @@ using namespace ir;
runtime::PackedFunc BuildStackVM(LoweredFunc func) { runtime::PackedFunc BuildStackVM(LoweredFunc func) {
StackVM vm = codegen::CodeGenStackVM().Compile(func); StackVM vm = codegen::CodeGenStackVM().Compile(func);
auto f = [vm](const TVMValue* args, const int* type_codes, int num_args) { using runtime::TVMArgs;
LOG(INFO) << "Run stack VM"; using runtime::TVMRetValue;
auto f = [vm](TVMArgs args, TVMRetValue* rv) {
StackVM::State* s = StackVM::ThreadLocalState(); StackVM::State* s = StackVM::ThreadLocalState();
s->sp = 0; s->sp = 0;
s->pc = 0; s->pc = 0;
if (s->heap.size() < vm.heap_size) { if (s->heap.size() < vm.heap_size) {
s->heap.resize(vm.heap_size); s->heap.resize(vm.heap_size);
} }
s->heap[0].v_handle = (void*)args; // NOLINT(*) s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
s->heap[1].v_handle = (void*)type_codes; // NOLINT(*) s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*)
s->heap[2].v_int64 = num_args; s->heap[2].v_int64 = args.num_args;
vm.Run(s); vm.Run(s);
}; };
return runtime::PackedFunc(f); return runtime::PackedFunc(f);
} }
...@@ -118,6 +121,9 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) { ...@@ -118,6 +121,9 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) {
auto it = fun_idmap_.find(name); auto it = fun_idmap_.find(name);
if (it != fun_idmap_.end()) return it->second; if (it != fun_idmap_.end()) return it->second;
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
PackedFunc f = PackedFunc::GetGlobal(name); PackedFunc f = PackedFunc::GetGlobal(name);
auto extern_f = [f](const TVMValue* args, int num_args) { auto extern_f = [f](const TVMValue* args, int num_args) {
CHECK_EQ(num_args % 2, 0); CHECK_EQ(num_args % 2, 0);
...@@ -128,7 +134,8 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) { ...@@ -128,7 +134,8 @@ int CodeGenStackVM::GetGlobalFuncID(std::string name) {
int code = (tcode >> (8 * 3)) & 255; int code = (tcode >> (8 * 3)) & 255;
type_codes[i] = code; type_codes[i] = code;
} }
f.CallPacked(args, &type_codes[0], num_args); TVMRetValue rv;
f.CallPacked(TVMArgs(args, &type_codes[0], num_args), &rv);
TVMValue r; r.v_int64 = 0; TVMValue r; r.v_int64 = 0;
return r; return r;
}; };
......
...@@ -136,7 +136,6 @@ class HostDeviceSplitter : public IRMutator { ...@@ -136,7 +136,6 @@ class HostDeviceSplitter : public IRMutator {
public: public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == "thread_extent") { if (op->type_key == "thread_extent") {
LOG(INFO) << "??";
IterVar iv(op->node.node_); IterVar iv(op->node.node_);
return SplitDeviceFunc(s); return SplitDeviceFunc(s);
} }
......
...@@ -302,7 +302,8 @@ void StackVM::Run(State* s) const { ...@@ -302,7 +302,8 @@ void StackVM::Run(State* s) const {
STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break; STACK_VM_TVM_LOAD_ARG(tc == kFloat, "float"); break;
} }
case TVM_LOAD_ARG_HANDLE: { case TVM_LOAD_ARG_HANDLE: {
STACK_VM_TVM_LOAD_ARG(tc == kHandle || tc == kNull, "handle"); break; STACK_VM_TVM_LOAD_ARG(
tc == kHandle || tc == kNull || tc == kArrayHandle, "handle"); break;
} }
case TVM_ARRAY_GET_DATA: { case TVM_ARRAY_GET_DATA: {
STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break; STACK_VM_TVM_ARRARY_GET(v_handle, void*, data); break;
...@@ -317,7 +318,7 @@ void StackVM::Run(State* s) const { ...@@ -317,7 +318,7 @@ void StackVM::Run(State* s) const {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break; STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
} }
case TVM_ARRAY_GET_TYPE_CODE: { case TVM_ARRAY_GET_TYPE_CODE: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.type_code); break; STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break;
} }
case TVM_ARRAY_GET_TYPE_BITS: { case TVM_ARRAY_GET_TYPE_BITS: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break; STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.bits); break;
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
* \file c_runtime_api.cc * \file c_runtime_api.cc
* \brief Device specific implementations * \brief Device specific implementations
*/ */
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <algorithm> #include <algorithm>
#include <string>
#include "./runtime_base.h" #include "./runtime_base.h"
#include "./device_api.h" #include "./device_api.h"
...@@ -37,7 +39,7 @@ inline void TVMArrayFree_(TVMArray* arr) { ...@@ -37,7 +39,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
inline void VerifyType(TVMType dtype) { inline void VerifyType(TVMType dtype) {
CHECK_GE(dtype.lanes, 1U); CHECK_GE(dtype.lanes, 1U);
if (dtype.type_code == kFloat) { if (dtype.code == kFloat) {
CHECK_EQ(dtype.bits % 32U, 0U); CHECK_EQ(dtype.bits % 32U, 0U);
} else { } else {
CHECK_EQ(dtype.bits % 8U, 0U); CHECK_EQ(dtype.bits % 8U, 0U);
...@@ -65,6 +67,12 @@ inline size_t GetDataAlignment(TVMArray* arr) { ...@@ -65,6 +67,12 @@ inline size_t GetDataAlignment(TVMArray* arr) {
using namespace tvm::runtime; using namespace tvm::runtime;
struct TVMRuntimeEntry {
std::string ret_str;
};
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
int TVMDeviceInit(int dev_mask, int TVMDeviceInit(int dev_mask,
const char** option_keys, const char** option_keys,
const char** option_vals, const char** option_vals,
...@@ -177,10 +185,31 @@ int TVMFuncFree(TVMFunctionHandle func) { ...@@ -177,10 +185,31 @@ int TVMFuncFree(TVMFunctionHandle func) {
int TVMFuncCall(TVMFunctionHandle func, int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args, TVMValue* args,
int* arg_type_codes, int* arg_type_codes,
int num_args) { int num_args,
TVMValue* ret_val,
int* ret_type_code) {
API_BEGIN(); API_BEGIN();
TVMRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked( (*static_cast<const PackedFunc*>(func)).CallPacked(
args, arg_type_codes, num_args); TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
API_END();
}
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue value,
int type_code) {
API_BEGIN();
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value, type_code);
API_END(); API_END();
} }
...@@ -191,22 +220,18 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -191,22 +220,18 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
API_BEGIN(); API_BEGIN();
if (fin == nullptr) { if (fin == nullptr) {
*out = new PackedFunc( *out = new PackedFunc(
[func, resource_handle](const TVMValue* args, [func, resource_handle](TVMArgs args, TVMRetValue* rv) {
const int* type_codes, func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
int num_args) { args.num_args, rv, resource_handle);
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
num_args, resource_handle);
}); });
} else { } else {
// wrap it in a shared_ptr, with fin as deleter. // wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda went out of scope. // so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin); std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc( *out = new PackedFunc(
[func, rpack](const TVMValue* args, [func, rpack](TVMArgs args, TVMRetValue* rv) {
const int* type_codes, func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
int num_args) { args.num_args, rv, rpack.get());
func((TVMValue*)args, (int*)type_codes, // NOLINT(*)
num_args, rpack.get());
}); });
} }
API_END(); API_END();
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief The global registry of packed function. * \brief The global registry of packed function.
*/ */
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
...@@ -58,6 +59,18 @@ std::vector<std::string> PackedFunc::ListGlobalNames() { ...@@ -58,6 +59,18 @@ std::vector<std::string> PackedFunc::ListGlobalNames() {
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
/*! \brief entry to to easily hold returning information */
struct TVMFuncThreadLocalEntry {
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
};
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) { int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
using tvm::runtime::PackedFunc; using tvm::runtime::PackedFunc;
API_BEGIN(); API_BEGIN();
...@@ -68,6 +81,22 @@ int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) { ...@@ -68,6 +81,22 @@ int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
using tvm::runtime::PackedFunc; using tvm::runtime::PackedFunc;
API_BEGIN(); API_BEGIN();
*out = new PackedFunc(PackedFunc::GetGlobal(name)); const PackedFunc& f = PackedFunc::GetGlobal(name);
*out = (TVMFunctionHandle)(&f); // NOLINT(*)
API_END();
}
int TVMFuncListGlobalNames(int *out_size,
const char*** out_array) {
using tvm::runtime::PackedFunc;
API_BEGIN();
TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get();
ret->ret_vec_str = PackedFunc::ListGlobalNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END(); API_END();
} }
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/tvm.h>
#include <tvm/ir.h>
TEST(PackedFunc, Basic) { TEST(PackedFunc, Basic) {
using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
int x = 0; int x = 0;
void* handle = &x; void* handle = &x;
TVMArray a; TVMArray a;
PackedFunc([&](const TVMValue* args, const int* type_codes, int num_args) { Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK(num_args == 3); CHECK(args.num_args == 3);
CHECK(args[0].v_float64 == 1.0); CHECK(args.values[0].v_float64 == 1.0);
CHECK(type_codes[0] == kFloat); CHECK(args.type_codes[0] == kFloat);
CHECK(args[1].v_handle == &a); CHECK(args.values[1].v_handle == &a);
CHECK(type_codes[1] == kHandle); CHECK(args.type_codes[1] == kArrayHandle);
CHECK(args[2].v_handle == &x); CHECK(args.values[2].v_handle == &x);
CHECK(type_codes[2] == kHandle); CHECK(args.type_codes[2] == kHandle);
*rv = Var("a");
})(1.0, &a, handle); })(1.0, &a, handle);
CHECK(v->name_hint == "a");
} }
TEST(PackedFunc, Node) {
using namespace tvm;
using namespace tvm::runtime;
Var x;
Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK(args.num_args == 1);
CHECK(args.type_codes[0] == kNodeHandle);
Var b = args[0];
CHECK(x.same_as(b));
*rv = b;
})(x);
CHECK(t.same_as(x));
}
TEST(PackedFunc, str) {
using namespace tvm;
using namespace tvm::runtime;
PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK(args.num_args == 1);
std::string x = args[0];
CHECK(x == "hello");
*rv = x;
})("hello");
}
TEST(PackedFunc, func) {
using namespace tvm;
using namespace tvm::runtime;
PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) {
*rv = args[0].operator int() + 1;
});
// function as arguments
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
PackedFunc f = args[0];
// TVMArgValue -> Arguments as function
*rv = f(args[1]).operator int();
})(addone, 1);
CHECK_EQ(r0, 2);
int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
// TVMArgValue -> TVMRetValue
*rv = args[1];
})(2, 100);
CHECK_EQ(r1, 100);
int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
// re-assignment
*rv = args[0];
// TVMRetValue -> Function argument
*rv = addone(args[0].operator PackedFunc()(args[1], 1));
})(addone, 100);
CHECK_EQ(r2, 102);
}
TEST(PackedFunc, Expr) {
using namespace tvm;
using namespace tvm::runtime;
// automatic conversion of int to expr
PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
Expr x = args[0];
*rv = x.as<tvm::ir::IntImm>()->value + 1;
});
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
PackedFunc f = args[0];
// TVMArgValue -> Arguments as function
*rv = f(args[1]).operator int();
})(addone, 1);
CHECK_EQ(r0, 2);
}
TEST(PackedFunc, Type) {
using namespace tvm;
using namespace tvm::runtime;
auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Type x = args[0];
*rv = x;
});
auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
*rv = args[0];
});
CHECK(get_type("int32").operator Type() == Int(32));
CHECK(get_type("float").operator Type() == Float(32));
CHECK(get_type2("float32x2").operator Type() == Float(32, 2));
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -2,7 +2,8 @@ import tvm ...@@ -2,7 +2,8 @@ import tvm
def test_const(): def test_const():
x = tvm.const(1) x = tvm.const(1)
assert x.dtype == 'int32' print(x.dtype)
assert x.dtype == tvm.int32
assert isinstance(x, tvm.expr.IntImm) assert isinstance(x, tvm.expr.IntImm)
def test_const_saveload_json(): def test_const_saveload_json():
......
...@@ -17,10 +17,22 @@ def test_get_global(): ...@@ -17,10 +17,22 @@ def test_get_global():
@tvm.register_func @tvm.register_func
def my_packed_func(*args): def my_packed_func(*args):
assert(tuple(args) == targs) assert(tuple(args) == targs)
return 10
# get it out from global function table # get it out from global function table
f = tvm.get_global_func("my_packed_func") f = tvm.get_global_func("my_packed_func")
assert isinstance(f, tvm.nd.Function) assert isinstance(f, tvm.nd.Function)
f(*targs) y = f(*targs)
assert y == 10
def test_return_func():
def addy(y):
def add(x):
return tvm.convert(x + y)
return add
myf = tvm.convert(addy)
f = myf(10)
assert f(11).value == 21
def test_convert(): def test_convert():
...@@ -38,3 +50,4 @@ if __name__ == "__main__": ...@@ -38,3 +50,4 @@ if __name__ == "__main__":
test_function() test_function()
test_convert() test_convert()
test_get_global() test_get_global()
test_return_func()
...@@ -38,10 +38,10 @@ fi ...@@ -38,10 +38,10 @@ fi
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
make all || exit -1 make all || exit -1
if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
python -m nose tests/python/ || exit -1 python -m nose -v tests/python/ || exit -1
python3 -m nose tests/python/ || exit -1 python3 -m nose -v tests/python/ || exit -1
else else
nosetests tests/python/ || exit -1 nosetests -v tests/python/ || exit -1
nosetests3 tests/python/ || exit -1 nosetests3 -v tests/python/ || exit -1
fi fi
fi fi
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment