Commit 4242b9cf by Tianqi Chen Committed by GitHub

[API/JIT] Enable registerable global function, introduce StackVM intepreter (#25)

parent 01a7ce0c
......@@ -10,7 +10,7 @@
#include "./base.h"
#include "./expr.h"
#include "./module.h"
#include "./runtime/runtime.h"
#include "./runtime/packed_func.h"
namespace tvm {
......
......@@ -81,11 +81,13 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pseudo code
*
* bool tvm_print(VType value) {
* LOG(INFO) << value;
* int tvm_call_global(name, TVMValue* args) {
* PackedFunc f = PackedFunc::GetGlobal(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_print = "tvm_print";
constexpr const char* tvm_call_global = "tvm_call_global";
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
......
......@@ -20,9 +20,6 @@ namespace tvm {
// Internal node container of lowered function.
class LoweredFuncNode;
// Internal node container of module.
class ModuleNode;
/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
......
......@@ -161,7 +161,7 @@ TVM_DLL const char *TVMGetLastError(void);
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \param out_code 1: success, 0: already initialized
* \return Whether the function is successful.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceInit(int dev_mask,
const char** option_keys,
......@@ -188,7 +188,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx,
* \param dtype The array data type.
* \param ctx The ctx this array sits on.
* \param out The output handle.
* \return Whether the function is successful.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
......@@ -198,6 +198,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
......@@ -206,6 +207,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
......@@ -214,13 +216,14 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
* \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Free the function when it is no longer needed.
* \param func The function handle
* \return whether
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
......@@ -239,6 +242,57 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* type_codes,
int num_args);
/*!
* \brief C type of packed function.
*
* \param args The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param resource_handle The additional resource handle from the front-end.
*/
typedef void (*TVMPackedCFunc)(
TVMValue* args, int* type_codes, int num_args, void* resource_handle);
/*!
* \brief C callback to free the resource handle in C packed function.
* \param resource_handle The additional resource handle from the front-end.
*/
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);
/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by TVM API, until the function is no longer used.
*
* \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL.
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out);
/*!
* \brief Register the function to runtime's global table.
*
* The registered function then can be pulled by the backend by the name.
* The implementation checks that the name is not registered yet, so
* registering the same name twice is reported as a failure.
*
* \param name The name of the function.
* \param f The function to be registered.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);
/*!
* \brief Get a global function.
*
* The implementation checks that a function is registered under the name,
* so looking up an unknown name is reported as a failure.
*
* \param name The name of the function.
* \param out the result function pointer. It receives a newly allocated
* handle that the caller releases with TVMFuncFree.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
} // TVM_EXTERN_C
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file runtime.h
* \file packed_func.h
* \brief Runtime related c++ class.
*/
#ifndef TVM_RUNTIME_RUNTIME_H_
#define TVM_RUNTIME_RUNTIME_H_
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_
#include <functional>
#include <tuple>
#include <vector>
#include <string>
#include "./c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief Packed function is a runtime function
* whose argument type_codes are erased by packed format.
* \brief Packed function is a type-erased function.
* The arguments are passed by packed format.
*
* This is an useful unified interface to call generated functions.
* This is an useful unified interface to call generated functions,
* It is the unified function function type of TVM.
* It corresponds to TVMFunctionHandle in C runtime API.
*/
class PackedFunc {
public:
/*! \brief The internal std::function */
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
/*! \brief default constructor */
PackedFunc() {}
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
/*!
* \brief invoke the packed function by directly passing in arguments.
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
* \return The first return value.
*/
template<typename... Args>
inline void operator()(Args&& ...args) const;
......@@ -41,9 +49,25 @@ class PackedFunc {
*/
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
/*! \return the internal body function */
inline FType body() const {
return body_;
}
inline FType body() const;
/*!
* \brief Register f into the global function table.
* \param name The name of the function.
* \param f The function to be registered.
* \return Reference to the registered function.
* \note The returned reference is valid until the end of the program
*/
static const PackedFunc& RegisterGlobal(const std::string& name, PackedFunc f);
/*!
* \brief Get the global function by name.
* \param name The name of the function.
* \return reference to the registered function.
*/
static const PackedFunc& GetGlobal(const std::string& name);
/*!
* \brief Get the names of currently registered global function.
*/
static std::vector<std::string> ListGlobalNames();
private:
/*! \brief internal container of packed function */
......@@ -56,6 +80,10 @@ inline void PackedFunc::CallPacked(
body_(args, type_codes, num_args);
}
// Out-of-line definition of the body() accessor:
// returns the stored std::function by value (i.e. a copy of body_).
inline PackedFunc::FType PackedFunc::body() const {
return body_;
}
template<bool stop, std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_ {
static inline void run(const std::tuple<Args...>& args, F f) {
......@@ -124,4 +152,4 @@ inline void PackedFunc::operator()(Args&& ...args) const {
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RUNTIME_H_
#endif // TVM_RUNTIME_PACKED_FUNC_H_
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring, too-many-return-statements
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
......@@ -13,7 +13,7 @@ from .._base import c_str, py_str, string_types
from .._base import check_call, ctypes2docstring
from .. import _api_internal
from . import _runtime_api
from ._types import TVMValue, TypeCode
from ._types import TVMValue, TypeCode, TVMPackedCFunc, TVMCFuncFinalizer
# type definitions
APIFuncHandle = ctypes.c_void_p
......@@ -57,6 +57,13 @@ def _return_func(x):
return _runtime_api._function_cls(handle)
def _return_handle(x):
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
return handle
RET_SWITCH = {
TypeCode.NULL: lambda x: None,
TypeCode.INT: lambda x: x.v_int64,
......@@ -66,6 +73,15 @@ RET_SWITCH = {
TypeCode.FUNC_HANDLE: _return_func
}
# Converters applied to each packed TVMValue before it is handed to a Python
# callback (see convert_to_tvm_func); keyed by the argument's type code.
PACK_ARG_SWITCH = {
    TypeCode.NULL: lambda x: None,
    TypeCode.INT: lambda x: x.v_int64,
    TypeCode.FLOAT: lambda x: x.v_float64,
    TypeCode.STR: lambda x: py_str(x.v_str),
    # Map directly to the helper so it is actually invoked on the value.
    # `lambda x: _return_handle` would hand the callback the helper function
    # object instead of the handle.
    TypeCode.HANDLE: _return_handle,
}
class SliceBase(object):
"""base class of slice object"""
pass
......@@ -159,10 +175,53 @@ def const(value, dtype=None):
return _api_internal._const(value, dtype)
def _ctypes_free_resource(rhandle):
"""callback to free resources when it it not needed."""
pyobj = ctypes.cast(rhandle, ctypes.py_object)
ctypes.pythonapi.Py_DecRef(pyobj)
# Global callback that is always alive.
# It wraps _ctypes_free_resource as a C finalizer and is passed to
# TVMFuncCreateFromCFunc by convert_to_tvm_func. The extra Py_IncRef pins the
# callback object so it is not collected while native code may still call it.
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))
def convert_to_tvm_func(pyfunc):
    """Wrap a python callable into a TVM function.

    Parameters
    ----------
    pyfunc : python function
        The python function to be converted.

    Returns
    -------
    tvmfunc: tvm.nd.Function
        The converted tvm function.
    """
    target = pyfunc

    def cfun(args, type_codes, num_args, _):
        """ ctypes function """
        if isinstance(num_args, ctypes.c_int):
            num_args = num_args.value
        unpacked = []
        for idx in range(num_args):
            # pick the converter matching the type code of this argument
            unpacked.append(PACK_ARG_SWITCH[type_codes[idx]](args[idx]))
        target(*unpacked)

    out_handle = FunctionHandle()
    c_callback = TVMPackedCFunc(cfun)
    # NOTE: the python-api is used to add a reference on the ctypes callback,
    # which is also passed along as the resource handle.
    # TVM_FREE_PYOBJ will be called after it is no longer needed.
    keep_alive = ctypes.py_object(c_callback)
    ctypes.pythonapi.Py_IncRef(keep_alive)
    check_call(_LIB.TVMFuncCreateFromCFunc(
        c_callback, keep_alive, TVM_FREE_PYOBJ, ctypes.byref(out_handle)))
    return _runtime_api._function_cls(out_handle)
def convert(value):
"""Convert a value to expression."""
if isinstance(value, Number):
if isinstance(value, (NodeBase, _runtime_api.FunctionBase)):
return value
elif isinstance(value, Number):
return const(value)
elif isinstance(value, string_types):
return _api_internal._str(value)
elif isinstance(value, (list, tuple)):
value = [convert(x) for x in value]
return _api_internal._Array(*value)
......@@ -176,10 +235,11 @@ def convert(value):
return _api_internal._Map(*vlist)
elif isinstance(value, SliceBase):
return value.tensor(*value.indices)
elif callable(value):
return convert_to_tvm_func(value)
else:
if not isinstance(value, NodeBase):
raise ValueError("don't know how to handle type %s" % type(value))
return value
raise ValueError("don't know how to handle type %s" % type(value))
return value
def _push_arg(arg):
......@@ -270,6 +330,59 @@ def register_node(type_key=None):
NODE_TYPE[cls.__name__] = cls
return cls
def register_func(func_name, f=None):
    """Register global function

    Can be called directly (``register_func("name", f)``), used as a bare
    decorator (``@register_func``, the function's ``__name__`` becomes the
    registered name) or as a decorator factory (``@register_func("name")``).

    Parameters
    ----------
    func_name : str or function
        The function name, or the function itself when used as a bare decorator.

    f : function, optional
        The function to be registered.

    Returns
    -------
    fregister : function
        The function that was registered when it is available,
        otherwise a decorator that registers the function it is applied to.

    Raises
    ------
    ValueError
        If the function name is not a string.
    """
    if callable(func_name):
        f = func_name
        func_name = f.__name__

    if not isinstance(func_name, str):
        raise ValueError("expect string function name")

    def register(myf):
        """internal register function"""
        tvm_func = myf
        if not isinstance(tvm_func, _runtime_api.FunctionBase):
            tvm_func = convert_to_tvm_func(tvm_func)
        check_call(_LIB.TVMFuncRegisterGlobal(
            c_str(func_name), tvm_func.handle))
        # Hand the function back so that decorator usage does not rebind
        # the decorated name to None.
        return myf

    if f is not None:
        return register(f)
    return register
def get_global_func(name):
    """Look up a function in the global function table.

    Parameters
    ----------
    name : str
        The name of the global function

    Returns
    -------
    func : tvm.nd.Function
        The function to be returned.
    """
    out_handle = FunctionHandle()
    status = _LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(out_handle))
    check_call(status)
    return _runtime_api._function_cls(out_handle)
def _init_api_module(root_namespace):
"""List and add all the functions to current module."""
plist = ctypes.POINTER(ctypes.c_char_p)()
......
......@@ -70,3 +70,16 @@ class TVMType(ctypes.Structure):
if self.lanes != 1:
x += "x%d" % self.lanes
return x
# ctypes prototype of a packed C callback. It returns nothing and takes, in
# order: pointer to the argument values, pointer to the int type codes,
# the number of arguments, and an opaque resource handle.
TVMPackedCFunc = ctypes.CFUNCTYPE(
None,
ctypes.POINTER(TVMValue),
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
ctypes.c_void_p)
# ctypes prototype of the finalizer callback: returns nothing and takes the
# opaque resource handle.
TVMCFuncFinalizer = ctypes.CFUNCTYPE(
None,
ctypes.c_void_p)
# pylint: disable=protected-access, no-member, invalid-name
# pylint: disable=redefined-builtin, undefined-variable
# pylint: disable=redefined-builtin, undefined-variable, unused-import
"""Functions defined in TVM."""
from __future__ import absolute_import as _abs
from numbers import Integral as _Integral
from ._ctypes._api import _init_api_module, convert
from ._ctypes._api import _init_api_module, convert, register_func, get_global_func
from . import _api_internal
from . import make as _make
from . import expr as _expr
......
......@@ -52,5 +52,10 @@ TVM_REGISTER_API(_codegen_DummyHelloFunction)
*ret = runtime::PackedFunc(DummyHelloFunction);
});
// API entry `_codegen_BuildStackVM`: forwards argument 0 to BuildStackVM
// and stores the result in the return value.
TVM_REGISTER_API(_codegen_BuildStackVM)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = BuildStackVM(args.at(0));
});
} // namespace codegen
} // namespace tvm
......@@ -46,7 +46,8 @@ TVM_REGISTER_API(_make_Call)
args.at(1),
args.at(2),
static_cast<Call::CallType>(args.at(3).operator int()),
args.at(4));
args.at(4),
args.at(5));
});
TVM_REGISTER_API(_make_Allocate)
......
......@@ -4,6 +4,7 @@
* \file c_api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/buffer.h>
#include <tvm/schedule.h>
......@@ -27,6 +28,13 @@ TVM_REGISTER_API(_const)
.add_argument("src", "Number", "source number")
.add_argument("dtype", "str", "data type");
// API entry `_str`: builds an ir::StringImm from argument 0
// and stores it in the return value.
TVM_REGISTER_API(_str)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ir::StringImm::make(args.at(0));
});
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
......
......@@ -9,7 +9,7 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/c_api.h>
#include <tvm/runtime/runtime.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <limits>
#include <string>
......
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_stack_vm.h
* \brief Codegen into Simple Stack VM.
*/
#ifndef TVM_CODEGEN_CODEGEN_STACK_VM_H_
#define TVM_CODEGEN_CODEGEN_STACK_VM_H_
#include <tvm/ir.h>
#include <tvm/module.h>
#include <tvm/codegen.h>
#include <string>
#include <unordered_map>
#include "../jit/stack_vm.h"
namespace tvm {
namespace codegen {
using jit::StackVM;
/*!
* \brief A base class to generate a stack VM.
* This module is used to generate host wrapper
* into device function when only device JIT is available.
*
* NOTE(review): the class declares virtual member functions but no virtual
* destructor; confirm that it is never deleted through a base pointer.
*/
class CodeGenStackVM {
public:
/*!
* \brief Generate a stack VM for the given function.
* \param f The function to be compiled
* \return The generated stack VM.
* \note Only call compile once,
* create a new codegen object each time.
*/
StackVM Compile(LoweredFunc f);
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
void Push(const Expr& n);
/*!
* \brief Push the opcode to the code.
* \param opcode The code to be pushed.
*/
void PushOp(StackVM::OpCode opcode);
/*!
* \brief Push the opcode and operand to the code.
* \param opcode The opcode.
* \param operand The operand to be pushed.
* \return operand_index, indicating location of operand
*/
int64_t PushOp(StackVM::OpCode opcode, int operand);
/*!
* \brief Set the operand stored at operand_index, e.g. a relative jump offset.
* \param operand_index The index returned by PushOp.
* \param operand The operand to be set.
*/
void SetOperand(int64_t operand_index, int64_t operand);
/*! \return The current program pointer, i.e. the number of code entries emitted so far */
int64_t GetPC() const {
return static_cast<int64_t>(vm_.code.size());
}
/*!
* \brief Get string id in vm
* \param key The string to get id.
* \return the id of the string.
*/
int GetStrID(const std::string& key);
/*!
* \brief Get the id of a global function in the vm.
* \param name The name of the global function.
* \return the id of the function.
*/
int GetGlobalFuncID(std::string name);
/*!
* \brief Allocate a heap index for a newly defined var.
* \param v The variable.
* \return the heap index of the var.
*/
int AllocVarID(const Variable* v);
/*!
* \brief Get the heap index of a variable.
* \param v The variable.
* \return the heap index of the var.
*/
int GetVarID(const Variable* v) const;
// overloadable functions
virtual void Push_(const ir::Load* op);
virtual void Push_(const ir::Store* op);
virtual void Push_(const ir::Allocate* op);
virtual void Push_(const ir::Call* op);
virtual void HandleUnknownCall(const ir::Call* op);
/*! \brief functor type taking an IR node and the codegen object */
using FType = IRFunctor<void(const NodeRef&, CodeGenStackVM *)>;
// vtable holding the per-node functors (presumably used by Push to dispatch; confirm in the .cc)
static FType& vtable(); // NOLINT(*)
private:
/*! \brief debug flag, off by default */
bool debug_{false};
/*! \brief The vm to be generated */
StackVM vm_;
/*! \brief id of each variable */
std::unordered_map<const Variable*, int> var_idmap_;
/*! \brief id of each string */
std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each function */
std::unordered_map<std::string, int> fun_idmap_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_STACK_VM_H_
/*!
* Copyright (c) 2016 by Contributors
* \file stack_vm.h
* \brief A simple stack-based virtual machine.
*
* This can be used to interpret host side code
* to setup calls into device functions
* when only JIT for device is available(via NVRTC or OpenCL).
*/
#ifndef TVM_JIT_STACK_VM_H_
#define TVM_JIT_STACK_VM_H_
#include <tvm/base.h>
#include <tvm/runtime/c_runtime_api.h>
#include <string>
#include <vector>
namespace tvm {
namespace jit {
/*!
* \brief A simple stack-based virtual machine.
*/
class StackVM {
public:
/*!
* \brief The opcode of stack vm
* \note Notation
* - sp Stack pointer
* - pc Program pointer
*/
enum OpCode {
// integer ops
ADD_I64,
SUB_I64,
MUL_I64,
DIV_I64,
MOD_I64,
EQ_I64,
LT_I64,
LE_I64,
// floating ops
ADD_F64,
SUB_F64,
MUL_F64,
DIV_F64,
EQ_F64,
LT_F64,
LE_F64,
// load operation
ADDR_LOAD_UINT32,
ADDR_LOAD_INT32,
ADDR_LOAD_INT64,
ADDR_LOAD_FP64,
ADDR_LOAD_HANDLE,
// store operations
// *(stack[sp - 1].v_handle) = stack[sp].v_int64
// sp = sp - 2;
ADDR_STORE_INT64,
/*!
* \brief Quick routine to load uint32 from constant offset.
* \code
* stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int];
* pc = pc + 2;
* \endcode
*/
ARRAY_LOAD_UINT32,
// logical ops
NOT,
/*!
* \brief Add address by an offset.
* \code
* stack[sp - 1].v_handle = ((char*)stack[sp - 1].v_handle + stack[sp].v_int64);
* sp = sp - 1;
* \endcode
*/
ADDR_ADD,
/*!
* \brief push integer fetched from next pc position into stack
* \code
* stack[sp + 1].v_int64 = code[pc + 1].v_int;
* pc = pc + 2;
* sp = sp + 1;
* \endcode
*/
PUSH_I64,
/*!
* \brief push a value given relative index on the stack
* \code
* stack[sp + 1] = stack[sp + code[pc + 1].v_int];
* pc = pc + 2;
* sp = sp + 1;
* \endcode
*/
PUSH_VALUE,
/*!
* \brief Load data from heap to top of stack
* \code
* stack[sp + 1] = heap[code[pc + 1].v_int];
* pc = pc + 2;
* sp = sp + 1;
* \endcode
*/
LOAD_HEAP,
/*!
* \brief Store data to heap
* \code
* heap[code[pc + 1].v_int] = stack[sp];
* sp = sp - 1;
* \endcode
*/
STORE_HEAP,
/*! \brief pop value from top of the stack */
POP,
/*!
* \brief select based on operands.
* \code
* stack[sp - 2] = stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]
* sp = sp - 2;
* \endcode
*/
SELECT,
/*!
* \brief call an extern function
* \code
* num_args = stack[sp].v_int64;
* call_fid = code[pc + 1].v_int;
* f = extern_func[call_fid];
* stack[sp - num_args] = f(&stack[sp - num_args], num_args);
* sp = sp - num_args;
* \endcode
*/
CALL_EXTERN,
/*!
* \brief Assert condition is true.
* \code
* CHECK(stack[sp]) << str_data[code[pc + 1].v_int];
* sp = sp - 1;
* \endcode
*/
ASSERT,
/*!
* \brief Relative Jump if the condition is true,
* Does not change the stack status.
* \code
* if (stack[sp]) {
* pc += code[pc + 1].v_int
* } else {
* pc = pc + 2;
* }
* \endcode
*/
RJUMP_IF_TRUE,
/*!
* \brief Relative Jump if the condition is false,
* Does not change the stack status.
* \code
* if (!stack[sp]) {
* pc += code[pc + 1].v_int
* } else {
* pc = pc + 2;
* }
* \endcode
*/
RJUMP_IF_FALSE,
/*!
* \brief Relative jump to a location.
* \code
* pc += code[pc + 1].v_int;
* \endcode
*/
RJUMP,
/*!
* \brief debug instruction.
* \code
* CHECK_EQ(sp, code[pc + 1].v_int);
* pc += 2;
* \endcode
*/
ASSERT_SP,
// Intrinsics for API function,
TVM_LOAD_ARG_INT64,
TVM_LOAD_ARG_FP64,
TVM_LOAD_ARG_HANDLE,
TVM_ARRAY_GET_DATA,
TVM_ARRAY_GET_SHAPE,
TVM_ARRAY_GET_STRIDES,
TVM_ARRAY_GET_NDIM,
TVM_ARRAY_GET_TYPE_CODE,
TVM_ARRAY_GET_TYPE_BITS,
TVM_ARRAY_GET_TYPE_LANES
};
/*! \brief The code structure */
union Code {
OpCode op_code;
int v_int;
};
/*! \brief The state object of StackVM */
struct State {
/*! \brief The execution stack */
std::vector<TVMValue> stack;
/*! \brief The global heap space */
std::vector<TVMValue> heap;
/*! \brief stack pointer */
int64_t sp{0};
/*! \brief program counter */
int64_t pc{0};
};
/*! \brief execute the stack vm with given state */
void Run(State* state) const;
/*!
* \brief Print instruction at location pc
* \param os The ostream
* \param pc The pc
* \return the pc to next instruction.
*/
int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*)
/*! \brief Get thread local state of the stack VM */
static State* ThreadLocalState();
/*! \brief extern function that will mutate the state */
using ExternFunc = std::function<TVMValue (const TVMValue* args, int num_args)>;
/*! \brief The instructions */
std::vector<Code> code;
/*! \brief constant error messages */
std::vector<std::string> str_data;
/*! \brief Extern functions */
std::vector<ExternFunc> extern_func;
/*! \brief name of each heap id*/
std::vector<std::string> heap_id_name;
/*! \brief The memory size needed */
size_t heap_size{0};
/*! \brief The stack size required */
size_t stack_size{1024};
/*!
 * \brief Convert an I64 opcode to its F64 counterpart.
 *
 * MOD_I64 has no floating point counterpart and any opcode outside the
 * I64 arithmetic/compare group is not convertible; both are reported
 * through LOG(FATAL).
 *
 * \param code The op code.
 * \return the F64 op code.
 */
static OpCode CodeI64ToF64(OpCode code) {
  switch (code) {
    case ADD_I64: return ADD_F64;
    case SUB_I64: return SUB_F64;
    case MUL_I64: return MUL_F64;
    case DIV_I64: return DIV_F64;
    case EQ_I64: return EQ_F64;
    case LT_I64: return LT_F64;
    case LE_I64: return LE_F64;
    case MOD_I64:
      LOG(FATAL) << "cannot handle mod for float";
      // Explicit return: do not fall through into the default case.
      return ADD_F64;
    default:
      LOG(FATAL) << "cannot handle op " << code;
      return ADD_F64;
  }
}
/*!
 * \brief Pick the ADDR_LOAD_* opcode matching type t.
 *
 * Only single-lane types are accepted. Supported are handles,
 * int32, int64, uint32 and float64; anything else is reported
 * through LOG(FATAL).
 *
 * \param t the type code.
 * \return The load opcode
 */
static OpCode GetLoad(Type t) {
  CHECK_EQ(t.lanes(), 1);
  if (t.is_handle()) {
    return ADDR_LOAD_HANDLE;
  }
  if (t.is_int()) {
    if (t.bits() == 32) return ADDR_LOAD_INT32;
    if (t.bits() == 64) return ADDR_LOAD_INT64;
  } else if (t.is_uint()) {
    if (t.bits() == 32) return ADDR_LOAD_UINT32;
  } else if (t.is_float()) {
    if (t.bits() == 64) return ADDR_LOAD_FP64;
  }
  LOG(FATAL) << "Cannot load type " << t;
  return ADDR_LOAD_FP64;
}
/*!
 * \brief Get store opcode for type t
 *
 * Only single-lane int64 is supported; any other type is reported
 * through LOG(FATAL).
 *
 * \param t the type code.
 * \return The store opcode
 */
static OpCode GetStore(Type t) {
  CHECK_EQ(t.lanes(), 1);
  if (t.is_int() && t.bits() == 64) {
    return ADDR_STORE_INT64;
  }
  LOG(FATAL) << "Cannot store type " << t;
  // Fallback after the fatal log: keep it a store opcode rather than
  // the load opcode that was returned here before.
  return ADDR_STORE_INT64;
}
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
};
} // namespace jit
} // namespace tvm
#endif // TVM_JIT_STACK_VM_H_
......@@ -4,7 +4,7 @@
* \brief Device specific implementations
*/
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/runtime.h>
#include <tvm/runtime/packed_func.h>
#include <algorithm>
#include "./runtime_base.h"
#include "./device_api.h"
......@@ -170,7 +170,7 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc::FType*>(func);
delete static_cast<PackedFunc*>(func);
API_END();
}
......@@ -179,7 +179,35 @@ int TVMFuncCall(TVMFunctionHandle func,
int* arg_type_codes,
int num_args) {
API_BEGIN();
(*static_cast<const PackedFunc::FType*>(func))(
(*static_cast<const PackedFunc*>(func)).CallPacked(
args, arg_type_codes, num_args);
API_END();
}
// Wrap a C callback (plus its optional resource handle) into a newly
// allocated PackedFunc and return it through `out`.
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                           void* resource_handle,
                           TVMPackedCFuncFinalizer fin,
                           TVMFunctionHandle *out) {
  API_BEGIN();
  if (fin != nullptr) {
    // Tie the resource to a shared_ptr whose deleter is the finalizer,
    // so fin runs once the last copy of the lambda is destroyed.
    std::shared_ptr<void> guarded(resource_handle, fin);
    *out = new PackedFunc(
        [func, guarded](const TVMValue* args,
                        const int* type_codes,
                        int num_args) {
          func(const_cast<TVMValue*>(args),
               const_cast<int*>(type_codes),
               num_args, guarded.get());
        });
  } else {
    // No finalizer: the raw handle is simply forwarded on each call.
    *out = new PackedFunc(
        [func, resource_handle](const TVMValue* args,
                                const int* type_codes,
                                int num_args) {
          func(const_cast<TVMValue*>(args),
               const_cast<int*>(type_codes),
               num_args, resource_handle);
        });
  }
  API_END();
}
/*!
* Copyright (c) 2017 by Contributors
* \file packed_func_registry.cc
* \brief The global registry of packed function.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map>
#include <memory>
#include "./runtime_base.h"
namespace tvm {
namespace runtime {
// Table of globally registered packed functions, keyed by name.
struct PackedFuncRegistry {
// map storing the functions.
// We deliberately use raw pointers here.
// This is because PackedFunc can contain callbacks into the host language (python)
// and the resource can become invalid because of the indeterministic order of destruction.
// The resources will only be recycled during program exit.
std::unordered_map<std::string, PackedFunc*> fmap;
// Access the single shared registry, held in a function-local static.
static PackedFuncRegistry* Global() {
static PackedFuncRegistry inst;
return &inst;
}
};
// Store a heap-allocated copy of f under name; the name must not be taken yet.
const PackedFunc& PackedFunc::RegisterGlobal(
    const std::string& name, PackedFunc f) {
  auto& table = PackedFuncRegistry::Global()->fmap;
  CHECK(table.find(name) == table.end())
      << "Global PackedFunc " << name << " is already registered";
  // The copy is stored as a raw pointer in the registry table.
  PackedFunc* stored = new PackedFunc(f);
  table[name] = stored;
  return *stored;
}
// Look up a registered function by name; the name must be registered.
const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
  const auto& table = PackedFuncRegistry::Global()->fmap;
  const auto entry = table.find(name);
  CHECK(entry != table.end())
      << "Global PackedFunc " << name << " is not registered";
  return *entry->second;
}
std::vector<std::string> PackedFunc::ListGlobalNames() {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
std::vector<std::string> keys;
keys.reserve(r->fmap.size());
for (const auto &kv : r->fmap) {
keys.push_back(kv.first);
}
return keys;
}
} // namespace runtime
} // namespace tvm
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
using tvm::runtime::PackedFunc;
API_BEGIN();
PackedFunc::RegisterGlobal(name, *static_cast<PackedFunc*>(f));
API_END();
}
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
using tvm::runtime::PackedFunc;
API_BEGIN();
*out = new PackedFunc(PackedFunc::GetGlobal(name));
API_END();
}
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/runtime.h>
#include <tvm/runtime/packed_func.h>
TEST(PackedFunc, Basic) {
using namespace tvm::runtime;
......
import tvm
import numpy as np
def tvm_call_global(*args):
args = tvm.convert(args)
return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0)
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
def tvm_call_back_get_shape(shape0):
assert shape0 == a.shape[0]
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "print_shape", [Ab], 1)
print(fapi.body)
f = tvm.codegen.BuildStackVM(fapi)
f(a)
@tvm.register_func
def tvm_stack_vm_print(*x):
print(x)
def test_stack_vm_loop():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.Block(
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 1,
i + 1),
tvm.make.Evaluate(tvm_call_global("tvm_stack_vm_print", i))))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "ramp", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
def test_stack_vm_cond():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.IfThenElse(
tvm.make.EQ(i, 4),
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 1, i + 1),
tvm.make.Store(Ab.ptr,
tvm.make.Load(dtype, Ab.ptr, i) + 2, i + 1)))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "test", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
y = np.arange(a.shape[0]) * 2
y[5:] -= 1
np.testing.assert_equal(a.asnumpy(), y)
# Run every test in this file when executed as a script,
# not only the conditional one.
if __name__ == "__main__":
    test_stack_vm_basic()
    test_stack_vm_loop()
    test_stack_vm_cond()
import tvm
import numpy as np
def test_function():
ctx = tvm.cpu(0)
x = np.random.randint(0, 10, size=(3, 4))
......@@ -13,5 +11,30 @@ def test_function():
f(y, 10)
def test_get_global():
    """Register a python callback globally, fetch it by name and call it."""
    expected = (10, 10.0, "hello")

    # put the callback into the global function table
    @tvm.register_func
    def my_packed_func(*received):
        assert tuple(received) == expected

    # pull it back out of the global function table by name
    fetched = tvm.get_global_func("my_packed_func")
    assert isinstance(fetched, tvm.nd.Function)
    fetched(*expected)
def test_convert():
    """Convert a python function to a tvm function and call it."""
    expected = (10, 10.0, "hello", 10)

    def myfunc(*received):
        assert tuple(received) == expected

    converted = tvm.convert(myfunc)
    assert isinstance(converted, tvm.nd.Function)
    converted(*expected)
if __name__ == "__main__":
test_function()
test_convert()
test_get_global()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment