Commit b6dc7826 by Haichen Shen Committed by Jared Roesch

[Relay][VM] Port VM, VM compiler, and Object into python (#3391)

* tmp

* Port vm and object to python

* clean up

* update vm build module

* update

* x

* tweak

* cleanup

* update

* fix rebase

* Rename to VMCompiler

* fix
parent afd4b3e4
......@@ -103,7 +103,7 @@ typedef enum {
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kObject = 14U,
kObjectCell = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
......@@ -176,6 +176,8 @@ typedef void* TVMRetValueHandle;
* can be NULL, which indicates the default one.
*/
typedef void* TVMStreamHandle;
/*! \brief Handle to Object. */
typedef void* TVMObjectHandle;
/*!
* \brief Used for implementing C API function.
......@@ -545,6 +547,15 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
TVMStreamHandle src,
TVMStreamHandle dst);
/*!
* \brief Get the tag from an object.
*
* \param obj The object handle.
* \param tag The tag of object.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
......
......@@ -324,6 +324,7 @@ class Object {
Object() : ptr_() {}
Object(const Object& obj) : ptr_(obj.ptr_) {}
ObjectCell* operator->() { return this->ptr_.operator->(); }
const ObjectCell* operator->() const { return this->ptr_.operator->(); }
/*! \brief Construct a tensor object. */
static Object Tensor(const NDArray& data);
......
......@@ -491,7 +491,7 @@ class TVMPODValue_ {
}
operator Object() const {
if (type_code_ == kNull) return Object();
TVM_CHECK_TYPE_CODE(type_code_, kObject);
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
return Object(static_cast<ObjectCell*>(value_.v_handle));
}
operator TVMContext() const {
......@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
}
TVMRetValue& operator=(Object other) {
this->Clear();
type_code_ = kObject;
type_code_ = kObjectCell;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
return *this;
......@@ -861,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kObject: {
case kObjectCell: {
*this = other.operator Object();
break;
}
......@@ -912,7 +912,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kObject: {
case kObjectCell: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef();
break;
}
......@@ -945,7 +945,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObject: return "Object";
case kObjectCell: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......
......@@ -143,7 +143,7 @@ struct Instruction {
/*! \brief The registers containing the arguments. */
RegName* invoke_args_registers;
};
struct /* Const Operands */ {
struct /* LoadConst Operands */ {
/* \brief The index into the constant pool. */
Index const_index;
};
......@@ -308,14 +308,14 @@ struct Instruction {
struct VMFunction {
/*! \brief The function's name. */
std::string name;
/*! \brief The number of function parameters. */
Index params;
/*! \brief The function parameter names. */
std::vector<std::string> params;
/*! \brief The instructions representing the function. */
std::vector<Instruction> instructions;
/*! \brief The size of the frame for this function */
Index register_file_size;
VMFunction(const std::string& name, Index params,
VMFunction(const std::string& name, std::vector<std::string> params,
const std::vector<Instruction>& instructions,
Index register_file_size)
: name(name),
......@@ -370,7 +370,15 @@ struct VMFrame {
* multiple threads, or serialized them to disk or over the
* wire.
*/
struct VirtualMachine {
class VirtualMachine : public runtime::ModuleNode {
public:
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final {
return "VirtualMachine";
}
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
......@@ -442,7 +450,7 @@ struct VirtualMachine {
/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map_;
std::unordered_map<std::string, Index> global_map;
private:
/*! \brief Invoke a global setting up the VM state to execute.
......
......@@ -37,6 +37,7 @@ from . import node as _node
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
......@@ -162,9 +163,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
elif isinstance(arg, ObjectBase):
elif isinstance(arg, _CLASS_OBJECT):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT
type_codes[i] = TypeCode.OBJECT_CELL
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
......@@ -236,6 +237,7 @@ def _return_module(x):
handle = ModuleHandle(handle)
return _CLASS_MODULE(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
......@@ -243,18 +245,11 @@ def _handle_return_func(x):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
RETURN_SWITCH[TypeCode.OBJECT] = lambda x: _CLASS_OBJECT(x.v_handle)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH
ObjectHandle = ctypes.c_void_p
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
def _register_object(index, cls):
"""register object class"""
OBJECT_TYPE[index] = cls
def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tag = ctypes.c_int()
check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag)))
cls = OBJECT_TYPE.get(tag.value, ObjectBase)
obj = cls(handle)
return obj
RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
......@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kObject = 14
kObjectCell = 14
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
......@@ -130,6 +130,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
int TVMGetObjectTag(ObjectHandle obj, int* tag)
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
......
......@@ -19,3 +19,4 @@ include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
include "./vmobj.pxi"
......@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kObject or
tcode == kObjectCell or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
......@@ -160,7 +160,7 @@ cdef inline int make_arg(object arg,
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObject
tcode[0] = kObjectCell
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
......@@ -212,8 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode == kObject:
return _CLASS_OBJECT(ctypes_handle(value.v_handle))
elif tcode == kObjectCell:
return make_ret_object(value.v_handle)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
......@@ -310,27 +310,6 @@ cdef class FunctionBase:
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Maps object type to its constructor"""
OBJECT_TYPE = []
def _register_object(int index, object cls):
"""register node class"""
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
cdef int tag
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMGetObjectTag(chandle, &tag))
if tag < len(object_type):
cls = object_type[tag]
if cls is not None:
obj = cls(handle)
else:
obj = ObjectBase(handle)
else:
obj = ObjectBase(handle)
return obj
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
......@@ -22,6 +22,7 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from . import vmobj as _vmobj
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -30,28 +31,19 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module, _set_class_object
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module, _set_class_object
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module, _set_class_object
from ._ctypes.function import ObjectBase as _ObjectBase
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
class Object(_ObjectBase):
# TODO(@jroesch): Eventually add back introspection functionality.
pass
_set_class_object(Object)
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
......
......@@ -42,7 +42,7 @@ class TypeCode(object):
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
OBJECT = 14
OBJECT_CELL = 14
EXT_BEGIN = 15
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import sys
from .base import _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.vmobj import ObjectBase as _ObjectBase
from ._ctypes.vmobj import _register_object
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 0
CLOSURE = 1
DATATYPE = 2
class Object(_ObjectBase):
"""The VM Object used in Relay virtual machine."""
def register_object(cls):
_register_object(cls.tag, cls)
return cls
_set_class_object(Object)
......@@ -33,6 +33,8 @@ from . import parser
from . import debug
from . import param_dict
from . import feature
from .backend import vm
from .backend import vmobj
# Root operators
from .op import Op
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The VM Object FFI namespace."""
from tvm._ffi.function import _init_api
_init_api("_vmobj", __name__)
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
"""
The Relay Virtual Vachine.
......@@ -23,55 +23,38 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np
import tvm
from tvm._ffi.function import Object
from .. import transform
from ..backend.interpreter import Executor
from ..expr import GlobalVar, Expr
from . import _vm
Object = Object
def optimize(mod):
"""Perform several optimizations on a module before executing it in the
Relay virtual machine.
Parameters
----------
mod : tvm.relay.Module
The module to optimize.
Returns
-------
ret : tvm.relay.Module
The optimized module.
"""
main_func = mod["main"]
opt_passes = []
if not main_func.params and isinstance(main_func.body, GlobalVar):
opt_passes.append(transform.EtaExpand())
opt_passes = opt_passes + [
transform.SimplifyInference(),
transform.FuseOps(),
transform.InferType()
]
seq = transform.Sequential(opt_passes)
return seq(mod)
from . import vmobj as _obj
from .interpreter import Executor
def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def _convert(arg, cargs):
if isinstance(arg, np.ndarray):
tensor = _vm._Tensor(tvm.nd.array(arg))
cargs.append(tensor)
elif isinstance(arg, tvm.nd.NDArray):
tensor = _vm._Tensor(arg)
cargs.append(tensor)
elif isinstance(arg, tuple):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg))
elif isinstance(arg, (tuple, list)):
field_args = []
for field in arg:
_convert(field, field_args)
cargs.append(_vm._Tuple(*field_args))
cargs.append(_obj.tuple_object(field_args))
else:
raise "unsupported type"
......@@ -82,28 +65,97 @@ def convert(args):
return cargs
def _eval_vm(mod, ctx, *args):
"""
Evaluate a module on a given context with the provided arguments.
Parameters
----------
mod: relay.Module
The module to optimize, will execute its entry_func.
ctx: tvm.Context
The TVM context to execute on.
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
"""
mod = optimize(mod)
args = list(args)
assert isinstance(args, list)
cargs = convert(args)
class VirtualMachine(object):
"""Relay VM runtime."""
def __init__(self, mod):
self.mod = mod
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
def init(self, ctx):
"""Initialize the context in the VM.
Parameters
----------
ctx : :py:class:`TVMContext`
The runtime context to run the code on.
"""
args = [ctx.device_type, ctx.device_id]
self._init(*args)
def invoke(self, func_name, *args):
"""Invoke a function.
Parameters
----------
func_name : str
The name of the function.
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
Returns
-------
result : Object
The output.
"""
cargs = convert(args)
return self._invoke(func_name, *cargs)
def run(self, *args):
"""Run the main function.
Parameters
----------
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
Returns
-------
result : Object
The output.
"""
return self.invoke("main", *args)
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
def compile(self, mod, target=None, target_host=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
target = _update_target(target)
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
result = _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs)
return result
class VMExecutor(Executor):
"""
......@@ -126,19 +178,21 @@ class VMExecutor(Executor):
The target option to build the function.
"""
def __init__(self, mod, ctx, target):
if mod is None:
raise RuntimeError("Must provide module to get VM executor.")
self.mod = mod
self.ctx = ctx
self.target = target
compiler = VMCompiler()
self.vm = compiler.compile(mod, target)
self.vm.init(ctx)
def _make_executor(self, expr=None):
expr = expr if expr else self.mod
assert expr, "either expr or self.mod should be not null."
if isinstance(expr, Expr):
self.mod["main"] = expr
assert expr is None
main = self.mod["main"]
def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
return _eval_vm(self.mod, self.ctx, *args)
return self.vm.run(*args)
return _vm_wrapper
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
import numpy as _np
from tvm._ffi.vmobj import Object, ObjectTag, register_object
from tvm import ndarray as _nd
from . import _vmobj
# TODO(@icemelon9): Add ClosureObject
@register_object
class TensorObject(Object):
"""Tensor object."""
tag = ObjectTag.TENSOR
def __init__(self, handle):
"""Constructs a Tensor object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : TensorObject
A tensor object.
"""
super(TensorObject, self).__init__(handle)
self.data = _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
"""
return self.data.asnumpy()
@register_object
class DatatypeObject(Object):
"""Datatype object."""
tag = ObjectTag.DATATYPE
def __init__(self, handle):
"""Constructs a Datatype object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : DatatypeObject
A Datatype object.
"""
super(DatatypeObject, self).__init__(handle)
self.tag = _vmobj.GetDatatypeTag(self)
num_fields = _vmobj.GetDatatypeNumberOfFields(self)
self.fields = []
for i in range(num_fields):
self.fields.append(_vmobj.GetDatatypeFields(self, i))
def __getitem__(self, idx):
return self.fields[idx]
def __len__(self):
return self.num_fields
def __iter__(self):
return iter(self.fields)
# TODO(icemelon9): Add closure object
def tensor_object(arr, ctx=_nd.cpu(0)):
"""Create a tensor object from source arr.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : TensorObject
The created object.
"""
if isinstance(arr, _np.ndarray):
tensor = _vmobj.Tensor(_nd.array(arr, ctx))
elif isinstance(arr, _nd.NDArray):
tensor = _vmobj.Tensor(arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
return tensor
def tuple_object(fields):
"""Create a datatype object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Tuple(*fields)
def datatype_object(tag, fields):
"""Create a datatype object from tag and source fields.
Parameters
----------
tag : int
The tag of datatype.
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Datatype(tag, *fields)
......@@ -529,6 +529,18 @@ def CanonicalizeCast():
return _transform.CanonicalizeCast()
def LambdaLift():
"""
Lift the closure to global function.
Returns
-------
ret : tvm.relay.Pass
The registered pass that lifts the lambda function.
"""
return _transform.LambdaLift()
def PrintIR():
"""
Print the IR for a module to help debugging.
......
......@@ -24,6 +24,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <memory>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/vm.cc
* \brief The Relay virtual machine.
*/
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/module.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/analysis.h>
namespace tvm {
namespace relay {
namespace vm {
runtime::vm::VirtualMachine CompileModule(const Module& mod);
using tvm::runtime::Object;
using tvm::runtime::ObjectTag;
using tvm::runtime::vm::VirtualMachine;
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
auto vm = CompileModule(module);
vm.Init(ctxs);
return vm;
}
Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<Object>& vm_args) {
VirtualMachine vm = FromModule(module, ctxs);
// TODO(zhiics): This measurement is for temporary usage. Remove it later. We
// need to introduce a better profiling method.
#if ENABLE_PROFILING
DLOG(INFO) << "Entry function is main." << std::endl;
auto start = std::chrono::high_resolution_clock::now();
#endif // ENABLE_PROFILING
Object res = vm.Invoke("main", vm_args);
#if ENABLE_PROFILING
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
LOG(INFO) << "Inference time: " << duration << "ms\n";
#endif // ENABLE_PROFILING
return res;
}
Value VMToValue(const relay::Module& module, Object obj) {
CHECK(module.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
const auto& data_type = obj.AsDatatype();
tvm::Array<Value> fields;
for (size_t i = 0; i < data_type->fields.size(); ++i) {
fields.push_back(VMToValue(module, data_type->fields[i]));
}
return ConstructorValueNode::make(data_type->tag, fields);
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
return Value();
}
}
TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Object::Tensor(args[0]);
});
TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Tuple(fields);
});
template <typename T>
std::string ToString(const T& t) {
std::stringstream s;
s << t;
return s.str();
}
TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
Object obj = args[0];
*ret = ToString(obj->tag);
});
TVM_REGISTER_API("relay._vm._Datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Datatype(tag, fields);
});
TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef to_compile = args[0];
TVMContext ctx;
int dev_type = args[1];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[2];
Module module;
if (to_compile.as<FunctionNode>()) {
Function to_compile = args[0];
module = ModuleNode::FromExpr(to_compile);
} else if (to_compile.as<ModuleNode>()) {
module = args[0];
} else {
LOG(FATAL) << "expected function or module";
}
std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
vm_args.push_back(obj);
}
auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
*ret = VMToValue(module, result);
});
} // namespace vm
} // namespace relay
} // namespace tvm
......@@ -25,7 +25,10 @@
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <iostream>
#include "../runtime_base.h"
namespace tvm {
namespace runtime {
......@@ -87,5 +90,69 @@ NDArray ToNDArray(const Object& obj) {
return tensor->data;
}
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsTensor();
*rv = cell->data;
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->tag);
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->fields.size());
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
int idx = args[1];
auto cell = obj.AsDatatype();
CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx];
});
TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Object::Tensor(args[0]);
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
}
*rv = Object::Tuple(fields);
});
TVM_REGISTER_GLOBAL("_vmobj.Datatype")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*rv = Object::Datatype(tag, fields);
});
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN();
*tag = static_cast<int>(static_cast<ObjectCell*>(handle)->tag);
API_END();
}
......@@ -41,7 +41,7 @@ class PooledAllocator final : public Allocator {
static constexpr size_t kDefaultPageSize = 4096;
explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize)
: Allocator(), page_size_(page_size), used_memory_(0) {}
: Allocator(), page_size_(page_size), used_memory_(0), ctx_(ctx) {}
~PooledAllocator() { ReleaseAll(); }
......
......@@ -32,8 +32,8 @@
#include <stdexcept>
#include <vector>
#include "../../runtime/vm/memory_manager.h"
#include "../../runtime/vm/naive_allocator.h"
#include "memory_manager.h"
#include "naive_allocator.h"
using namespace tvm::runtime;
......@@ -554,6 +554,37 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os;
}
PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
std::vector<Object> func_args;
for (int i = 1; i < args.size(); ++i) {
Object obj = args[i];
func_args.push_back(obj);
}
*rv = this->Invoke(func_name, func_args);
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size() % 2, 0);
std::vector<TVMContext> contexts;
for (int i = 0; i < args.size() / 2; ++i) {
TVMContext ctx;
int device_type = args[i * 2];
ctx.device_type = DLDeviceType(device_type);
ctx.device_id = args[i * 2 + 1];
contexts.push_back(ctx);
}
this->Init(contexts);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
frames.push_back(frame);
......@@ -573,11 +604,11 @@ Index VirtualMachine::PopFrame() {
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) {
DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
PushFrame(func.params, this->pc + 1, func);
PushFrame(func.params.size(), this->pc + 1, func);
for (size_t i = 0; i < args.size(); ++i) {
WriteRegister(i, args[i]);
}
DLOG(INFO) << "func.params= " << func.params;
DLOG(INFO) << "func.params= " << func.params.size();
code = func.instructions.data();
pc = 0;
......@@ -594,7 +625,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
}
Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) {
auto func_index = this->global_map_[name];
auto func_index = this->global_map[name];
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args);
}
......@@ -719,12 +750,12 @@ void VirtualMachine::Run() {
auto object = ReadRegister(instr.closure);
const auto& closure = object.AsClosure();
std::vector<Object> args;
for (Index i = 0; i < instr.closure_args_num; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
for (auto free_var : closure->free_vars) {
args.push_back(free_var);
}
for (Index i = 0; i < instr.closure_args_num; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
InvokeGlobal(this->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst;
goto main_loop;
......
......@@ -92,9 +92,9 @@ def test_squeezenet():
def test_inception_v3():
image_shape = (1, 3, 299, 299)
image_shape = (3, 299, 299)
mod, params = testing.inception_v3.get_workload(image_shape=image_shape)
benchmark_execution(mod, params, data_shape=image_shape)
benchmark_execution(mod, params, data_shape=(1, 3, 299, 299))
def test_dqn():
......@@ -107,7 +107,7 @@ def test_dqn():
def test_dcgan():
image_shape = (1, 100)
mod, params = testing.dcgan.get_workload(batch_size=1)
benchmark_execution(mod, params, data_shape=image_shape)
benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 3, 64, 64))
def test_mobilenet():
......@@ -126,8 +126,7 @@ if __name__ == '__main__':
test_squeezenet()
test_mobilenet()
test_densenet()
# The following networks fail
# test_inception_v3()
# test_mlp()
# test_dqn()
# test_dcgan()
test_inception_v3()
test_mlp()
test_dqn()
test_dcgan()
......@@ -23,40 +23,44 @@ from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
def veval(f, *args, ctx=tvm.cpu()):
def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
if len(args) == 0:
return ex.evaluate(f)
else:
return ex.evaluate(f)(*args)
mod = relay.Module()
mod["main"] = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(tvm.cpu())
return vm.run(*args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
ex = relay.create_executor('vm', mod=mod, ctx=ctx)
if len(args) == 0:
return ex.evaluate()
else:
return ex.evaluate()(*args)
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(tvm.cpu())
ret = vm.run(*args)
return ret
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
result = []
for f in o.fields:
for f in o:
result.extend(vmobj_to_list(f))
return result
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def test_split():
x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple()
z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
f = relay.Function([x], z)
f = relay.Function([x], y)
x_data = np.random.rand(12,).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
ref_res = np.split(x_data, 3, axis=0)
for i in range(3):
tvm.testing.assert_allclose(res[i].asnumpy(), ref_res[i])
def test_split_no_fuse():
x = relay.var('x', shape=(12,))
......@@ -68,9 +72,8 @@ def test_split_no_fuse():
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
def test_id():
x = relay.var('x', shape=(10, 10))
x = relay.var('x', shape=(10, 10), dtype='float64')
f = relay.Function([x], x)
x_data = np.random.rand(10, 10).astype('float64')
res = veval(f, x_data)
......@@ -209,7 +212,7 @@ def test_list_constructor():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
obj = vmobj_to_list(result)
tvm.testing.assert_allclose(obj, np.array([3,2,1]))
......@@ -275,7 +278,7 @@ def test_compose():
mod["main"] = f
x_data = np.array(np.random.rand()).astype('float32')
result = veval(mod)(x_data)
result = veval(mod, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
......@@ -296,7 +299,7 @@ def test_list_hd():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 3)
@raises(Exception)
......@@ -312,7 +315,7 @@ def test_list_tl_empty_list():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
print(result)
def test_list_tl():
......@@ -332,7 +335,7 @@ def test_list_tl():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
def test_list_nth():
......@@ -351,7 +354,7 @@ def test_list_nth():
f = relay.Function([], nth(l, relay.const(i)))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), expected[i])
def test_list_update():
......@@ -375,7 +378,7 @@ def test_list_update():
f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
def test_list_length():
......@@ -397,7 +400,7 @@ def test_list_length():
f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 10)
def test_list_map():
......@@ -415,7 +418,7 @@ def test_list_map():
f = relay.Function([], map(add_one_func, l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
def test_list_foldl():
......@@ -433,7 +436,7 @@ def test_list_foldl():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldl(rev_dup_func, nil(), l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
def test_list_foldr():
......@@ -451,7 +454,7 @@ def test_list_foldr():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldr(identity_func, nil(), l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
def test_list_sum():
......@@ -465,7 +468,7 @@ def test_list_sum():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], sum(l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 6)
def test_list_filter():
......@@ -485,7 +488,7 @@ def test_list_filter():
cons(relay.const(1), nil())))))
f = relay.Function([], filter(greater_than_one, l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
def test_closure():
......@@ -513,6 +516,10 @@ if __name__ == "__main__":
test_split()
test_split_no_fuse()
test_list_constructor()
test_let_tensor()
test_let_scalar()
test_compose()
test_list_hd()
test_list_tl_empty_list()
test_list_tl()
test_list_nth()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment