Commit b6dc7826 by Haichen Shen, committed by Jared Roesch

[Relay][VM] Port VM, VM compiler, and Object into python (#3391)

* tmp

* Port vm and object to python

* clean up

* update vm build module

* update

* x

* tweak

* cleanup

* update

* fix rebase

* Rename to VMCompiler

* fix
parent afd4b3e4
@@ -103,7 +103,7 @@ typedef enum {
  kStr = 11U,
  kBytes = 12U,
  kNDArrayContainer = 13U,
-  kObject = 14U,
+  kObjectCell = 14U,
  // Extension codes for other frameworks to integrate TVM PackedFunc.
  // To make sure each framework's id do not conflict, use first and
  // last sections to mark ranges.
@@ -176,6 +176,8 @@ typedef void* TVMRetValueHandle;
 * can be NULL, which indicates the default one.
 */
typedef void* TVMStreamHandle;
+/*! \brief Handle to Object. */
+typedef void* TVMObjectHandle;
/*!
 * \brief Used for implementing C API function.
@@ -545,6 +547,15 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
                                       TVMStreamHandle src,
                                       TVMStreamHandle dst);
+/*!
+ * \brief Get the tag from an object.
+ *
+ * \param obj The object handle.
+ * \param tag The tag of object.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
#ifdef __cplusplus
}  // TVM_EXTERN_C
#endif
...
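For illustration, the new C entry point can be exercised directly through ctypes; this is a minimal sketch, assuming the runtime shared library path and that `obj_handle` is a valid `TVMObjectHandle` obtained from a VM call (both are assumptions, not part of the patch):

```python
import ctypes

# Load the TVM runtime shared library (path is an assumption for illustration).
_LIB = ctypes.CDLL("libtvm.so")

def get_object_tag(obj_handle):
    """Fetch a VM object's tag via the new TVMGetObjectTag C API."""
    tag = ctypes.c_int()
    # Per the doc comment above: returns 0 on success, -1 on failure.
    if _LIB.TVMGetObjectTag(obj_handle, ctypes.byref(tag)) != 0:
        raise RuntimeError("TVMGetObjectTag failed")
    return tag.value
```

This is essentially what `_return_object` in the new `python/tvm/_ffi/_ctypes/vmobj.py` below does.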
@@ -324,6 +324,7 @@ class Object {
  Object() : ptr_() {}
  Object(const Object& obj) : ptr_(obj.ptr_) {}
  ObjectCell* operator->() { return this->ptr_.operator->(); }
+  const ObjectCell* operator->() const { return this->ptr_.operator->(); }
  /*! \brief Construct a tensor object. */
  static Object Tensor(const NDArray& data);
...
@@ -491,7 +491,7 @@ class TVMPODValue_ {
  }
  operator Object() const {
    if (type_code_ == kNull) return Object();
-    TVM_CHECK_TYPE_CODE(type_code_, kObject);
+    TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
    return Object(static_cast<ObjectCell*>(value_.v_handle));
  }
  operator TVMContext() const {
@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
  }
  TVMRetValue& operator=(Object other) {
    this->Clear();
-    type_code_ = kObject;
+    type_code_ = kObjectCell;
    value_.v_handle = other.ptr_.data_;
    other.ptr_.data_ = nullptr;
    return *this;
@@ -861,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ {
            kNodeHandle, *other.template ptr<NodePtr<Node> >());
        break;
      }
-      case kObject: {
+      case kObjectCell: {
        *this = other.operator Object();
        break;
      }
@@ -912,7 +912,7 @@ class TVMRetValue : public TVMPODValue_ {
        static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
        break;
      }
-      case kObject: {
+      case kObjectCell: {
        static_cast<ObjectCell*>(value_.v_handle)->DecRef();
        break;
      }
@@ -945,7 +945,7 @@ inline const char* TypeCode2Str(int type_code) {
    case kFuncHandle: return "FunctionHandle";
    case kModuleHandle: return "ModuleHandle";
    case kNDArrayContainer: return "NDArrayContainer";
-    case kObject: return "Object";
+    case kObjectCell: return "ObjectCell";
    default: LOG(FATAL) << "unknown type_code="
                        << static_cast<int>(type_code); return "";
  }
...
@@ -143,7 +143,7 @@ struct Instruction {
      /*! \brief The registers containing the arguments. */
      RegName* invoke_args_registers;
    };
-    struct /* Const Operands */ {
+    struct /* LoadConst Operands */ {
      /* \brief The index into the constant pool. */
      Index const_index;
    };
@@ -308,14 +308,14 @@ struct Instruction {
struct VMFunction {
  /*! \brief The function's name. */
  std::string name;
-  /*! \brief The number of function parameters. */
-  Index params;
+  /*! \brief The function parameter names. */
+  std::vector<std::string> params;
  /*! \brief The instructions representing the function. */
  std::vector<Instruction> instructions;
  /*! \brief The size of the frame for this function */
  Index register_file_size;
-  VMFunction(const std::string& name, Index params,
+  VMFunction(const std::string& name, std::vector<std::string> params,
             const std::vector<Instruction>& instructions,
             Index register_file_size)
    : name(name),
@@ -370,7 +370,15 @@ struct VMFrame {
 * multiple threads, or serialized them to disk or over the
 * wire.
 */
-struct VirtualMachine {
+class VirtualMachine : public runtime::ModuleNode {
+ public:
+  PackedFunc GetFunction(const std::string& name,
+                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+  const char* type_key() const final {
+    return "VirtualMachine";
+  }
  /*! \brief The virtual machine's packed function table. */
  std::vector<PackedFunc> packed_funcs;
  /*! \brief The virtual machine's function table. */
@@ -442,7 +450,7 @@ struct VirtualMachine {
  /*! \brief A map from globals (as strings) to their index in the function map.
  */
-  std::unordered_map<std::string, Index> global_map_;
+  std::unordered_map<std::string, Index> global_map;
 private:
  /*! \brief Invoke a global setting up the VM state to execute.
...
@@ -37,6 +37,7 @@ from . import node as _node
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
+ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p

def _ctypes_free_resource(rhandle):
@@ -162,9 +163,9 @@ def _make_tvm_args(args, temp_args):
            values[i].v_handle = arg.handle
            type_codes[i] = TypeCode.FUNC_HANDLE
            temp_args.append(arg)
-        elif isinstance(arg, ObjectBase):
+        elif isinstance(arg, _CLASS_OBJECT):
            values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.OBJECT
+            type_codes[i] = TypeCode.OBJECT_CELL
        else:
            raise TypeError("Don't know how to handle type %s" % type(arg))
    return values, type_codes, num_args
@@ -236,6 +237,7 @@ def _return_module(x):
        handle = ModuleHandle(handle)
    return _CLASS_MODULE(handle)

def _handle_return_func(x):
    """Return function"""
    handle = x.v_handle
@@ -243,18 +245,11 @@ def _handle_return_func(x):
        handle = FunctionHandle(handle)
    return _CLASS_FUNCTION(handle, False)

-class ObjectBase(object):
-    __slots__ = ["handle"]
-
-    def __init__(self, handle):
-        self.handle = handle

# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
-RETURN_SWITCH[TypeCode.OBJECT] = lambda x: _CLASS_OBJECT(x.v_handle)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
    _handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH
ObjectHandle = ctypes.c_void_p
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
def _register_object(index, cls):
"""register object class"""
OBJECT_TYPE[index] = cls
def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tag = ctypes.c_int()
check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag)))
cls = OBJECT_TYPE.get(tag.value, ObjectBase)
obj = cls(handle)
return obj
RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
    kStr = 11
    kBytes = 12
    kNDArrayContainer = 13
-    kObject = 14
+    kObjectCell = 14
    kExtBegin = 15

cdef extern from "tvm/runtime/c_runtime_api.h":
@@ -130,6 +130,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
    int TVMArrayToDLPack(DLTensorHandle arr_from,
                         DLManagedTensor** out)
    void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
+    int TVMGetObjectTag(ObjectHandle obj, int* tag)

cdef extern from "tvm/c_dsl_api.h":
    int TVMNodeFree(NodeHandle handle)
...
@@ -19,3 +19,4 @@ include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
+include "./vmobj.pxi"
@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
        if (tcode == kNodeHandle or
            tcode == kFuncHandle or
            tcode == kModuleHandle or
-            tcode == kObject or
+            tcode == kObjectCell or
            tcode > kExtBegin):
            CALL(TVMCbArgToReturn(&value, tcode))
@@ -160,7 +160,7 @@ cdef inline int make_arg(object arg,
        tcode[0] = kModuleHandle
    elif isinstance(arg, _CLASS_OBJECT):
        value[0].v_handle = c_handle(arg.handle)
-        tcode[0] = kObject
+        tcode[0] = kObjectCell
    elif isinstance(arg, FunctionBase):
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
@@ -212,8 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
        fobj = _CLASS_FUNCTION(None, False)
        (<FunctionBase>fobj).chandle = value.v_handle
        return fobj
-    elif tcode == kObject:
-        return _CLASS_OBJECT(ctypes_handle(value.v_handle))
+    elif tcode == kObjectCell:
+        return make_ret_object(value.v_handle)
    elif tcode in _TVM_EXT_RET:
        return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
@@ -310,27 +310,6 @@ cdef class FunctionBase:
        FuncCall(self.chandle, args, &ret_val, &ret_tcode)
        return make_ret(ret_val, ret_tcode)

-cdef class ObjectBase:
-    cdef ObjectHandle chandle
-    cdef inline _set_handle(self, handle):
-        if handle is None:
-            self.chandle = NULL
-        else:
-            self.chandle = c_handle(handle)
-    property handle:
-        def __get__(self):
-            if self.chandle == NULL:
-                return None
-            else:
-                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
-        def __set__(self, value):
-            self._set_handle(value)
-    def __init__(self, handle):
-        self._set_handle(handle)

_CLASS_FUNCTION = None
_CLASS_MODULE = None
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Maps object type to its constructor"""
OBJECT_TYPE = []
def _register_object(int index, object cls):
"""register node class"""
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
cdef int tag
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMGetObjectTag(chandle, &tag))
if tag < len(object_type):
cls = object_type[tag]
if cls is not None:
obj = cls(handle)
else:
obj = ObjectBase(handle)
else:
obj = ObjectBase(handle)
return obj
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
@@ -22,6 +22,7 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
+from . import vmobj as _vmobj

IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -30,28 +31,19 @@ try:
    if _FFI_MODE == "ctypes":
        raise ImportError()
    if sys.version_info >= (3, 0):
-        from ._cy3.core import _set_class_function, _set_class_module, _set_class_object
+        from ._cy3.core import _set_class_function, _set_class_module
        from ._cy3.core import FunctionBase as _FunctionBase
-        from ._cy3.core import ObjectBase as _ObjectBase
        from ._cy3.core import convert_to_tvm_func
    else:
-        from ._cy2.core import _set_class_function, _set_class_module, _set_class_object
+        from ._cy2.core import _set_class_function, _set_class_module
        from ._cy2.core import FunctionBase as _FunctionBase
-        from ._cy2.core import ObjectBase as _ObjectBase
        from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
    # pylint: disable=wrong-import-position
-    from ._ctypes.function import _set_class_function, _set_class_module, _set_class_object
-    from ._ctypes.function import ObjectBase as _ObjectBase
+    from ._ctypes.function import _set_class_function, _set_class_module
    from ._ctypes.function import FunctionBase as _FunctionBase
    from ._ctypes.function import convert_to_tvm_func

-class Object(_ObjectBase):
-    # TODO(@jroesch): Eventually add back introspection functionality.
-    pass
-
-_set_class_object(Object)

FunctionHandle = ctypes.c_void_p

class Function(_FunctionBase):
...
@@ -42,7 +42,7 @@ class TypeCode(object):
    STR = 11
    BYTES = 12
    NDARRAY_CONTAINER = 13
-    OBJECT = 14
+    OBJECT_CELL = 14
    EXT_BEGIN = 15
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import sys
from .base import _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.vmobj import ObjectBase as _ObjectBase
from ._ctypes.vmobj import _register_object
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 0
CLOSURE = 1
DATATYPE = 2
class Object(_ObjectBase):
"""The VM Object used in Relay virtual machine."""
def register_object(cls):
_register_object(cls.tag, cls)
return cls
_set_class_object(Object)
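As a sketch of how the registry above is meant to be used (the `ClosureObject` class and its binding to `ObjectTag.CLOSURE` are illustrative assumptions; the real registrations live in `tvm/relay/backend/vmobj.py` further down):

```python
@register_object
class ClosureObject(Object):
    """Hypothetical object class bound to the CLOSURE tag for illustration."""
    tag = ObjectTag.CLOSURE
```

With this in place, `_return_object` can map any object whose tag is `CLOSURE` back to `ClosureObject` when crossing the FFI boundary.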
@@ -33,6 +33,8 @@ from . import parser
from . import debug
from . import param_dict
from . import feature
+from .backend import vm
+from .backend import vmobj

# Root operators
from .op import Op
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The VM Object FFI namespace."""
from tvm._ffi.function import _init_api
_init_api("_vmobj", __name__)
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
"""
The Relay Virtual Machine.
@@ -23,55 +23,38 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np

import tvm
-from tvm._ffi.function import Object
-from .. import transform
-from ..backend.interpreter import Executor
-from ..expr import GlobalVar, Expr
from . import _vm
+from . import vmobj as _obj
+from .interpreter import Executor

-Object = Object

-def optimize(mod):
-    """Perform several optimizations on a module before executing it in the
-    Relay virtual machine.
-
-    Parameters
-    ----------
-    mod : tvm.relay.Module
-        The module to optimize.
-
-    Returns
-    -------
-    ret : tvm.relay.Module
-        The optimized module.
-    """
-    main_func = mod["main"]
-
-    opt_passes = []
-    if not main_func.params and isinstance(main_func.body, GlobalVar):
-        opt_passes.append(transform.EtaExpand())
-
-    opt_passes = opt_passes + [
-        transform.SimplifyInference(),
-        transform.FuseOps(),
-        transform.InferType()
-    ]
-
-    seq = transform.Sequential(opt_passes)
-    return seq(mod)
+def _update_target(target):
+    target = target if target else tvm.target.current_target()
+    if target is None:
+        raise ValueError("Target is not set in env or passed as argument.")
+
+    tgts = {}
+    if isinstance(target, (str, tvm.target.Target)):
+        dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
+        tgts[dev_type] = tvm.target.create(target)
+    elif isinstance(target, dict):
+        for dev, tgt in target.items():
+            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
+            tgts[dev_type] = tvm.target.create(tgt)
+    else:
+        raise TypeError("target is expected to be str, tvm.target.Target, " +
+                        "or dict of str to str/tvm.target.Target, but received " +
+                        "{}".format(type(target)))
+    return tgts

def _convert(arg, cargs):
-    if isinstance(arg, np.ndarray):
-        tensor = _vm._Tensor(tvm.nd.array(arg))
-        cargs.append(tensor)
-    elif isinstance(arg, tvm.nd.NDArray):
-        tensor = _vm._Tensor(arg)
-        cargs.append(tensor)
-    elif isinstance(arg, tuple):
+    if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
+        cargs.append(_obj.tensor_object(arg))
+    elif isinstance(arg, (tuple, list)):
        field_args = []
        for field in arg:
            _convert(field, field_args)
-        cargs.append(_vm._Tuple(*field_args))
+        cargs.append(_obj.tuple_object(field_args))
    else:
raise "unsupported type" raise "unsupported type"
@@ -82,28 +65,97 @@ def convert(args):
    return cargs

-def _eval_vm(mod, ctx, *args):
-    """
-    Evaluate a module on a given context with the provided arguments.
-
-    Parameters
-    ----------
-    mod: relay.Module
-        The module to optimize, will execute its entry_func.
-
-    ctx: tvm.Context
-        The TVM context to execute on.
-
-    args: List[tvm.NDArray, np.ndarray]
-        The arguments to evaluate.
-    """
-    mod = optimize(mod)
-    args = list(args)
-    assert isinstance(args, list)
-    cargs = convert(args)
+class VirtualMachine(object):
+    """Relay VM runtime."""
+    def __init__(self, mod):
+        self.mod = mod
+        self._init = self.mod["init"]
+        self._invoke = self.mod["invoke"]
+
+    def init(self, ctx):
+        """Initialize the context in the VM.
+
+        Parameters
+        ----------
+        ctx : :py:class:`TVMContext`
+            The runtime context to run the code on.
+        """
+        args = [ctx.device_type, ctx.device_id]
+        self._init(*args)
+
+    def invoke(self, func_name, *args):
+        """Invoke a function.
+
+        Parameters
+        ----------
+        func_name : str
+            The name of the function.
+
+        args : list[NDArray] or list[np.ndarray]
+            The arguments to the function.
+
+        Returns
+        -------
+        result : Object
+            The output.
+        """
+        cargs = convert(args)
+        return self._invoke(func_name, *cargs)
+
+    def run(self, *args):
+        """Run the main function.
+
+        Parameters
+        ----------
+        args : list[NDArray] or list[np.ndarray]
+            The arguments to the function.
+
+        Returns
+        -------
+        result : Object
+            The output.
+        """
+        return self.invoke("main", *args)
+
+class VMCompiler(object):
+    """Build Relay module to run on VM runtime."""
+    def __init__(self):
+        self.mod = _vm._VMCompiler()
+        self._compile = self.mod["compile"]
+        self._get_vm = self.mod["get_vm"]
+
+    def compile(self, mod, target=None, target_host=None):
+        """
+        Parameters
+        ----------
+        mod : relay.Module
+            The Relay module to build.
+
+        target : str, :any:`tvm.target.Target`, or dict of str (i.e.
+            device/context name) to str/tvm.target.Target, optional
+            For heterogeneous compilation, it is a dictionary indicating context
+            to target mapping. For homogeneous compilation, it is a build target.
+
+        target_host : str or :any:`tvm.target.Target`, optional
+            Host compilation target, if target is device.
+            When TVM compiles device-specific programs such as CUDA,
+            we also need host (CPU) side code to interact with the driver
+            to set up the dimensions and parameters correctly.
+            target_host is used to specify the host-side codegen target.
+            By default, llvm is used if it is enabled,
+            otherwise a stackvm interpreter is used.
+
+        Returns
+        -------
+        vm : VirtualMachine
+            The VM runtime.
+        """
+        target = _update_target(target)
+        self._compile(mod, target, target_host)
+        return VirtualMachine(self._get_vm())

-    result = _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs)
-    return result
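A minimal end-to-end sketch of the new `VMCompiler`/`VirtualMachine` Python API added above; the module contents and input data are illustrative assumptions, not part of the patch:

```python
import numpy as np
import tvm
from tvm import relay

# Build a trivial Relay module to compile (illustrative).
x = relay.var("x", shape=(2, 2))
mod = relay.Module()
mod["main"] = relay.Function([x], x + x)

# Compile to the VM, initialize a context, and run.
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target="llvm")
vm.init(tvm.cpu())

x_data = np.random.rand(2, 2).astype("float32")
res = vm.run(x_data)  # equivalent to vm.invoke("main", x_data)
print(res.asnumpy())
```

Note that `convert` wraps the numpy inputs into VM objects before dispatching to the `invoke` packed function.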
class VMExecutor(Executor):
    """
@@ -126,19 +178,21 @@ class VMExecutor(Executor):
        The target option to build the function.
    """
    def __init__(self, mod, ctx, target):
+        if mod is None:
+            raise RuntimeError("Must provide module to get VM executor.")
        self.mod = mod
        self.ctx = ctx
        self.target = target
+        compiler = VMCompiler()
+        self.vm = compiler.compile(mod, target)
+        self.vm.init(ctx)

    def _make_executor(self, expr=None):
-        expr = expr if expr else self.mod
-        assert expr, "either expr or self.mod should be not null."
-        if isinstance(expr, Expr):
-            self.mod["main"] = expr
+        assert expr is None
        main = self.mod["main"]

        def _vm_wrapper(*args, **kwargs):
            args = self._convert_args(main, args, kwargs)
-            return _eval_vm(self.mod, self.ctx, *args)
+            return self.vm.run(*args)

        return _vm_wrapper
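The executor path still works through the standard `relay.create_executor` API; a hedged sketch under the new constraints (the module `mod` and input `x_data` are assumptions, and the expression argument must now be left as None since the executor compiles `mod["main"]` eagerly):

```python
ex = relay.create_executor('vm', mod=mod, ctx=tvm.cpu(), target='llvm')
run = ex.evaluate()   # "main" is taken from mod; passing an expr now asserts
out = run(x_data)
```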
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
import numpy as _np
from tvm._ffi.vmobj import Object, ObjectTag, register_object
from tvm import ndarray as _nd
from . import _vmobj
# TODO(@icemelon9): Add ClosureObject
@register_object
class TensorObject(Object):
"""Tensor object."""
tag = ObjectTag.TENSOR
def __init__(self, handle):
"""Constructs a Tensor object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : TensorObject
A tensor object.
"""
super(TensorObject, self).__init__(handle)
self.data = _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
"""
return self.data.asnumpy()
@register_object
class DatatypeObject(Object):
"""Datatype object."""
tag = ObjectTag.DATATYPE
def __init__(self, handle):
"""Constructs a Datatype object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : DatatypeObject
A Datatype object.
"""
super(DatatypeObject, self).__init__(handle)
self.tag = _vmobj.GetDatatypeTag(self)
num_fields = _vmobj.GetDatatypeNumberOfFields(self)
self.fields = []
for i in range(num_fields):
self.fields.append(_vmobj.GetDatatypeFields(self, i))
def __getitem__(self, idx):
return self.fields[idx]
def __len__(self):
        return len(self.fields)
def __iter__(self):
return iter(self.fields)
# TODO(icemelon9): Add closure object
def tensor_object(arr, ctx=_nd.cpu(0)):
"""Create a tensor object from source arr.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : TensorObject
The created object.
"""
if isinstance(arr, _np.ndarray):
tensor = _vmobj.Tensor(_nd.array(arr, ctx))
elif isinstance(arr, _nd.NDArray):
tensor = _vmobj.Tensor(arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
return tensor
def tuple_object(fields):
"""Create a datatype object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Tuple(*fields)
def datatype_object(tag, fields):
"""Create a datatype object from tag and source fields.
Parameters
----------
tag : int
The tag of datatype.
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Datatype(tag, *fields)
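A short usage sketch of the constructors above (the array values are illustrative):

```python
import numpy as np
from tvm.relay.backend import vmobj

# Wrap a numpy array into a TensorObject on the default CPU context.
t = vmobj.tensor_object(np.arange(4, dtype="float32"))
print(t.asnumpy())  # [0. 1. 2. 3.]

# Pack objects into a DatatypeObject carrying the tuple tag.
tup = vmobj.tuple_object([t, t])
print(len(tup), tup[0].asnumpy())
```

The tag stored on the returned object is what `_return_object`/`make_ret_object` consult to pick the right Python class when the object crosses back over the FFI.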
@@ -529,6 +529,18 @@ def CanonicalizeCast():
    return _transform.CanonicalizeCast()

+def LambdaLift():
+    """
+    Lift closures to global functions.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that lifts lambda functions.
+    """
+    return _transform.LambdaLift()

def PrintIR():
    """
    Print the IR for a module to help debugging.
...
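For illustration, the new pass composes like any other module pass (assuming `mod` is a `relay.Module` containing local closures):

```python
from tvm import relay

# Lift local closures in `mod` into global functions, e.g. before VM compilation.
mod = relay.transform.LambdaLift()(mod)
```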
@@ -24,6 +24,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <memory>
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/vm.cc
* \brief The Relay virtual machine.
*/
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/module.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/analysis.h>
namespace tvm {
namespace relay {
namespace vm {
runtime::vm::VirtualMachine CompileModule(const Module& mod);
using tvm::runtime::Object;
using tvm::runtime::ObjectTag;
using tvm::runtime::vm::VirtualMachine;
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
auto vm = CompileModule(module);
vm.Init(ctxs);
return vm;
}
Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<Object>& vm_args) {
VirtualMachine vm = FromModule(module, ctxs);
// TODO(zhiics): This measurement is for temporary usage. Remove it later. We
// need to introduce a better profiling method.
#if ENABLE_PROFILING
DLOG(INFO) << "Entry function is main." << std::endl;
auto start = std::chrono::high_resolution_clock::now();
#endif // ENABLE_PROFILING
Object res = vm.Invoke("main", vm_args);
#if ENABLE_PROFILING
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
LOG(INFO) << "Inference time: " << duration << "ms\n";
#endif // ENABLE_PROFILING
return res;
}
Value VMToValue(const relay::Module& module, Object obj) {
CHECK(module.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
const auto& data_type = obj.AsDatatype();
tvm::Array<Value> fields;
for (size_t i = 0; i < data_type->fields.size(); ++i) {
fields.push_back(VMToValue(module, data_type->fields[i]));
}
return ConstructorValueNode::make(data_type->tag, fields);
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
return Value();
}
}
TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Object::Tensor(args[0]);
});
TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Tuple(fields);
});
template <typename T>
std::string ToString(const T& t) {
std::stringstream s;
s << t;
return s.str();
}
TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
Object obj = args[0];
*ret = ToString(obj->tag);
});
TVM_REGISTER_API("relay._vm._Datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Datatype(tag, fields);
});
TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef to_compile = args[0];
TVMContext ctx;
int dev_type = args[1];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[2];
Module module;
if (to_compile.as<FunctionNode>()) {
Function to_compile = args[0];
module = ModuleNode::FromExpr(to_compile);
} else if (to_compile.as<ModuleNode>()) {
module = args[0];
} else {
LOG(FATAL) << "expected function or module";
}
std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
vm_args.push_back(obj);
}
auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
*ret = VMToValue(module, result);
});
} // namespace vm
} // namespace relay
} // namespace tvm
@@ -25,7 +25,10 @@
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/c_runtime_api.h>
#include <iostream>
+#include "../runtime_base.h"

namespace tvm {
namespace runtime {
@@ -87,5 +90,69 @@ NDArray ToNDArray(const Object& obj) {
  return tensor->data;
}

+TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  Object obj = args[0];
+  auto cell = obj.AsTensor();
+  *rv = cell->data;
+});
+
+TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  Object obj = args[0];
+  auto cell = obj.AsDatatype();
+  *rv = static_cast<int>(cell->tag);
+});
+
+TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  Object obj = args[0];
+  auto cell = obj.AsDatatype();
+  *rv = static_cast<int>(cell->fields.size());
+});
+
+TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  Object obj = args[0];
+  int idx = args[1];
+  auto cell = obj.AsDatatype();
+  CHECK_LT(idx, cell->fields.size());
+  *rv = cell->fields[idx];
+});
+
+TVM_REGISTER_GLOBAL("_vmobj.Tensor")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = Object::Tensor(args[0]);
+});
+
+TVM_REGISTER_GLOBAL("_vmobj.Tuple")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::vector<Object> fields;
+  for (auto i = 0; i < args.size(); ++i) {
+    fields.push_back(args[i]);
+  }
+  *rv = Object::Tuple(fields);
+});
+
+TVM_REGISTER_GLOBAL("_vmobj.Datatype")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  int itag = args[0];
+  size_t tag = static_cast<size_t>(itag);
+  std::vector<Object> fields;
+  for (int i = 1; i < args.size(); i++) {
+    fields.push_back(args[i]);
+  }
+  *rv = Object::Datatype(tag, fields);
+});
+
}  // namespace runtime
}  // namespace tvm

+using namespace tvm::runtime;
+
+int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
+  API_BEGIN();
+  *tag = static_cast<int>(static_cast<ObjectCell*>(handle)->tag);
+  API_END();
+}
@@ -41,7 +41,7 @@ class PooledAllocator final : public Allocator {
  static constexpr size_t kDefaultPageSize = 4096;

  explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize)
-      : Allocator(), page_size_(page_size), used_memory_(0) {}
+      : Allocator(), page_size_(page_size), used_memory_(0), ctx_(ctx) {}

  ~PooledAllocator() { ReleaseAll(); }
...
@@ -32,8 +32,8 @@
#include <stdexcept>
#include <vector>

-#include "../../runtime/vm/memory_manager.h"
-#include "../../runtime/vm/naive_allocator.h"
+#include "memory_manager.h"
+#include "naive_allocator.h"

using namespace tvm::runtime;
@@ -554,6 +554,37 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
  return os;
}

+PackedFunc VirtualMachine::GetFunction(const std::string& name,
+                                       const std::shared_ptr<ModuleNode>& sptr_to_self) {
+  if (name == "invoke") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::string func_name = args[0];
+      std::vector<Object> func_args;
+      for (int i = 1; i < args.size(); ++i) {
+        Object obj = args[i];
+        func_args.push_back(obj);
+      }
+      *rv = this->Invoke(func_name, func_args);
+    });
+  } else if (name == "init") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.size() % 2, 0);
+      std::vector<TVMContext> contexts;
+      for (int i = 0; i < args.size() / 2; ++i) {
+        TVMContext ctx;
+        int device_type = args[i * 2];
+        ctx.device_type = DLDeviceType(device_type);
+        ctx.device_id = args[i * 2 + 1];
+        contexts.push_back(ctx);
+      }
+      this->Init(contexts);
+    });
+  } else {
+    LOG(FATAL) << "Unknown packed function: " << name;
+    return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
+  }
+}
+
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
  auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
  frames.push_back(frame);
@@ -573,11 +604,11 @@ Index VirtualMachine::PopFrame() {
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) {
  DLOG(INFO) << "Invoking global " << func.name << " " << args.size();

-  PushFrame(func.params, this->pc + 1, func);
+  PushFrame(func.params.size(), this->pc + 1, func);
  for (size_t i = 0; i < args.size(); ++i) {
    WriteRegister(i, args[i]);
  }
-  DLOG(INFO) << "func.params= " << func.params;
+  DLOG(INFO) << "func.params= " << func.params.size();
  code = func.instructions.data();
  pc = 0;
@@ -594,7 +625,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
}

Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) {
-  auto func_index = this->global_map_[name];
+  auto func_index = this->global_map[name];
  DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
  return Invoke(this->functions[func_index], args);
}
@@ -719,12 +750,12 @@ void VirtualMachine::Run() {
        auto object = ReadRegister(instr.closure);
        const auto& closure = object.AsClosure();
        std::vector<Object> args;
-        for (Index i = 0; i < instr.closure_args_num; ++i) {
-          args.push_back(ReadRegister(instr.closure_args[i]));
-        }
        for (auto free_var : closure->free_vars) {
          args.push_back(free_var);
        }
+        for (Index i = 0; i < instr.closure_args_num; ++i) {
+          args.push_back(ReadRegister(instr.closure_args[i]));
+        }
        InvokeGlobal(this->functions[closure->func_index], args);
        frames.back().caller_return_register = instr.dst;
        goto main_loop;
...
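On the Python side, these two packed functions are exactly what `VirtualMachine.init` and `VirtualMachine.invoke` in `relay/backend/vm.py` look up. A hedged sketch of driving them directly (here `vm` is assumed to come from `VMCompiler.compile` and `obj` to be an already-converted VM object):

```python
init = vm.mod["init"]
invoke = vm.mod["invoke"]
# "init" consumes (device_type, device_id) pairs, as the CHECK_EQ above enforces.
init(tvm.cpu().device_type, tvm.cpu().device_id)
# "invoke" takes the function name followed by VM objects.
out = invoke("main", obj)
```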
@@ -92,9 +92,9 @@ def test_squeezenet():

def test_inception_v3():
-    image_shape = (1, 3, 299, 299)
+    image_shape = (3, 299, 299)
    mod, params = testing.inception_v3.get_workload(image_shape=image_shape)
-    benchmark_execution(mod, params, data_shape=image_shape)
+    benchmark_execution(mod, params, data_shape=(1, 3, 299, 299))

def test_dqn():
@@ -107,7 +107,7 @@ def test_dqn():

def test_dcgan():
    image_shape = (1, 100)
    mod, params = testing.dcgan.get_workload(batch_size=1)
-    benchmark_execution(mod, params, data_shape=image_shape)
+    benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 3, 64, 64))

def test_mobilenet():
@@ -126,8 +126,7 @@ if __name__ == '__main__':
    test_squeezenet()
    test_mobilenet()
    test_densenet()
-    # The following networks fail
-    # test_inception_v3()
-    # test_mlp()
-    # test_dqn()
-    # test_dcgan()
+    test_inception_v3()
+    test_mlp()
+    test_dqn()
+    test_dcgan()
@@ -23,40 +23,44 @@ from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude

-def veval(f, *args, ctx=tvm.cpu()):
+def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
    if isinstance(f, relay.Expr):
-        ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
-        if len(args) == 0:
-            return ex.evaluate(f)
-        else:
-            return ex.evaluate(f)(*args)
+        mod = relay.Module()
+        mod["main"] = f
+        compiler = relay.vm.VMCompiler()
+        vm = compiler.compile(mod, target)
+        vm.init(tvm.cpu())
+        return vm.run(*args)
    else:
        assert isinstance(f, relay.Module), "expected expression or module"
        mod = f
-        ex = relay.create_executor('vm', mod=mod, ctx=ctx)
-        if len(args) == 0:
-            return ex.evaluate()
-        else:
-            return ex.evaluate()(*args)
+        compiler = relay.vm.VMCompiler()
+        vm = compiler.compile(mod, target)
+        vm.init(tvm.cpu())
+        ret = vm.run(*args)
+        return ret

def vmobj_to_list(o):
-    if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
-        return [o.data.asnumpy().tolist()]
-    if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
+    if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
+        return [o.asnumpy().tolist()]
+    elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
        result = []
-        for f in o.fields:
+        for f in o:
            result.extend(vmobj_to_list(f))
        return result
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))

def test_split():
    x = relay.var('x', shape=(12,))
    y = relay.split(x, 3, axis=0).astuple()
-    z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
-    f = relay.Function([x], z)
+    f = relay.Function([x], y)
    x_data = np.random.rand(12,).astype('float32')
    res = veval(f, x_data)
-    tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+    ref_res = np.split(x_data, 3, axis=0)
+    for i in range(3):
+        tvm.testing.assert_allclose(res[i].asnumpy(), ref_res[i])
def test_split_no_fuse():
    x = relay.var('x', shape=(12,))
@@ -68,9 +72,8 @@ def test_split_no_fuse():
    res = veval(f, x_data)
    tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])

def test_id():
-    x = relay.var('x', shape=(10, 10))
+    x = relay.var('x', shape=(10, 10), dtype='float64')
    f = relay.Function([x], x)
    x_data = np.random.rand(10, 10).astype('float64')
    res = veval(f, x_data)
@@ -209,7 +212,7 @@ def test_list_constructor():
    mod["main"] = f

-    result = veval(mod)()
+    result = veval(mod)
    obj = vmobj_to_list(result)
    tvm.testing.assert_allclose(obj, np.array([3,2,1]))
@@ -275,7 +278,7 @@ def test_compose():
    mod["main"] = f

    x_data = np.array(np.random.rand()).astype('float32')
-    result = veval(mod)(x_data)
+    result = veval(mod, x_data)
    tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
@@ -296,7 +299,7 @@ def test_list_hd():
    mod["main"] = f

-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(result.asnumpy(), 3)

@raises(Exception)
@@ -312,7 +315,7 @@ def test_list_tl_empty_list():
    mod["main"] = f

-    result = veval(mod)()
+    result = veval(mod)
    print(result)

def test_list_tl():
@@ -332,7 +335,7 @@ def test_list_tl():
    mod["main"] = f

-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))

def test_list_nth():
@@ -351,7 +354,7 @@ def test_list_nth():
        f = relay.Function([], nth(l, relay.const(i)))
        mod["main"] = f
-        result = veval(mod)()
+        result = veval(mod)
        tvm.testing.assert_allclose(result.asnumpy(), expected[i])

def test_list_update():
@@ -375,7 +378,7 @@ def test_list_update():
    f = relay.Function([], l)
    mod["main"] = f
-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))

def test_list_length():
@@ -397,7 +400,7 @@ def test_list_length():
    f = relay.Function([], l)
    mod["main"] = f
-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(result.asnumpy(), 10)

def test_list_map():
@@ -415,7 +418,7 @@ def test_list_map():
    f = relay.Function([], map(add_one_func, l))
    mod["main"] = f
-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))

def test_list_foldl():
@@ -433,7 +436,7 @@ def test_list_foldl():
    l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
    f = relay.Function([], foldl(rev_dup_func, nil(), l))
    mod["main"] = f
-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))

def test_list_foldr():
@@ -451,7 +454,7 @@ def test_list_foldr():
    l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
    f = relay.Function([], foldr(identity_func, nil(), l))
    mod["main"] = f
-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))

def test_list_sum():
@@ -465,7 +468,7 @@ def test_list_sum():
    l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
    f = relay.Function([], sum(l))
    mod["main"] = f
-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(result.asnumpy(), 6)

def test_list_filter():
@@ -485,7 +488,7 @@ def test_list_filter():
                cons(relay.const(1), nil())))))
    f = relay.Function([], filter(greater_than_one, l))
    mod["main"] = f
-    result = veval(mod)()
+    result = veval(mod)
    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))

def test_closure():
@@ -513,6 +516,10 @@ if __name__ == "__main__":
    test_split()
    test_split_no_fuse()
    test_list_constructor()
+    test_let_tensor()
+    test_let_scalar()
+    test_compose()
+    test_list_hd()
    test_list_tl_empty_list()
    test_list_tl()
    test_list_nth()
...