Commit b6dc7826 by Haichen Shen, committed by Jared Roesch

[Relay][VM] Port VM, VM compiler, and Object into python (#3391)

* tmp

* Port vm and object to python

* clean up

* update vm build module

* update

* x

* tweak

* cleanup

* update

* fix rebase

* Rename to VMCompiler

* fix
Parent: afd4b3e4
@@ -103,7 +103,7 @@ typedef enum {
   kStr = 11U,
   kBytes = 12U,
   kNDArrayContainer = 13U,
-  kObject = 14U,
+  kObjectCell = 14U,
   // Extension codes for other frameworks to integrate TVM PackedFunc.
   // To make sure each framework's id do not conflict, use first and
   // last sections to mark ranges.
@@ -176,6 +176,8 @@ typedef void* TVMRetValueHandle;
  * can be NULL, which indicates the default one.
  */
 typedef void* TVMStreamHandle;
+/*! \brief Handle to Object. */
+typedef void* TVMObjectHandle;
 /*!
  * \brief Used for implementing C API function.
@@ -545,6 +547,15 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
                                        TVMStreamHandle src,
                                        TVMStreamHandle dst);
+/*!
+ * \brief Get the tag from an object.
+ *
+ * \param obj The object handle.
+ * \param tag The tag of object.
+ * \return 0 when success, -1 when failure happens
+ */
+TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
+
 #ifdef __cplusplus
 }  // TVM_EXTERN_C
 #endif
......
@@ -324,6 +324,7 @@ class Object {
   Object() : ptr_() {}
   Object(const Object& obj) : ptr_(obj.ptr_) {}
   ObjectCell* operator->() { return this->ptr_.operator->(); }
+  const ObjectCell* operator->() const { return this->ptr_.operator->(); }
   /*! \brief Construct a tensor object. */
   static Object Tensor(const NDArray& data);
......
@@ -491,7 +491,7 @@ class TVMPODValue_ {
   }
   operator Object() const {
     if (type_code_ == kNull) return Object();
-    TVM_CHECK_TYPE_CODE(type_code_, kObject);
+    TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
     return Object(static_cast<ObjectCell*>(value_.v_handle));
   }
   operator TVMContext() const {
@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
   }
   TVMRetValue& operator=(Object other) {
     this->Clear();
-    type_code_ = kObject;
+    type_code_ = kObjectCell;
     value_.v_handle = other.ptr_.data_;
     other.ptr_.data_ = nullptr;
     return *this;
@@ -861,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ {
         kNodeHandle, *other.template ptr<NodePtr<Node> >());
       break;
     }
-    case kObject: {
+    case kObjectCell: {
       *this = other.operator Object();
       break;
     }
@@ -912,7 +912,7 @@ class TVMRetValue : public TVMPODValue_ {
       static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
       break;
     }
-    case kObject: {
+    case kObjectCell: {
       static_cast<ObjectCell*>(value_.v_handle)->DecRef();
       break;
     }
@@ -945,7 +945,7 @@ inline const char* TypeCode2Str(int type_code) {
     case kFuncHandle: return "FunctionHandle";
     case kModuleHandle: return "ModuleHandle";
     case kNDArrayContainer: return "NDArrayContainer";
-    case kObject: return "Object";
+    case kObjectCell: return "ObjectCell";
     default: LOG(FATAL) << "unknown type_code="
                         << static_cast<int>(type_code); return "";
   }
......
@@ -143,7 +143,7 @@ struct Instruction {
       /*! \brief The registers containing the arguments. */
       RegName* invoke_args_registers;
     };
-    struct /* Const Operands */ {
+    struct /* LoadConst Operands */ {
       /* \brief The index into the constant pool. */
       Index const_index;
     };
@@ -308,14 +308,14 @@ struct Instruction {
 struct VMFunction {
   /*! \brief The function's name. */
   std::string name;
-  /*! \brief The number of function parameters. */
-  Index params;
+  /*! \brief The function parameter names. */
+  std::vector<std::string> params;
   /*! \brief The instructions representing the function. */
   std::vector<Instruction> instructions;
   /*! \brief The size of the frame for this function */
   Index register_file_size;
-  VMFunction(const std::string& name, Index params,
+  VMFunction(const std::string& name, std::vector<std::string> params,
              const std::vector<Instruction>& instructions,
              Index register_file_size)
     : name(name),
@@ -370,7 +370,15 @@ struct VMFrame {
  * multiple threads, or serialized them to disk or over the
  * wire.
  */
-struct VirtualMachine {
+class VirtualMachine : public runtime::ModuleNode {
+ public:
+  PackedFunc GetFunction(const std::string& name,
+                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+
+  const char* type_key() const final {
+    return "VirtualMachine";
+  }
+
   /*! \brief The virtual machine's packed function table. */
   std::vector<PackedFunc> packed_funcs;
   /*! \brief The virtual machine's function table. */
@@ -442,7 +450,7 @@ struct VirtualMachine {
   /*! \brief A map from globals (as strings) to their index in the function map.
    */
-  std::unordered_map<std::string, Index> global_map_;
+  std::unordered_map<std::string, Index> global_map;

  private:
   /*! \brief Invoke a global setting up the VM state to execute.
......
@@ -37,6 +37,7 @@ from . import node as _node
 FunctionHandle = ctypes.c_void_p
 ModuleHandle = ctypes.c_void_p
+ObjectHandle = ctypes.c_void_p
 TVMRetValueHandle = ctypes.c_void_p

 def _ctypes_free_resource(rhandle):
@@ -162,9 +163,9 @@ def _make_tvm_args(args, temp_args):
             values[i].v_handle = arg.handle
             type_codes[i] = TypeCode.FUNC_HANDLE
             temp_args.append(arg)
-        elif isinstance(arg, ObjectBase):
+        elif isinstance(arg, _CLASS_OBJECT):
             values[i].v_handle = arg.handle
-            type_codes[i] = TypeCode.OBJECT
+            type_codes[i] = TypeCode.OBJECT_CELL
         else:
             raise TypeError("Don't know how to handle type %s" % type(arg))
     return values, type_codes, num_args
@@ -236,6 +237,7 @@ def _return_module(x):
         handle = ModuleHandle(handle)
     return _CLASS_MODULE(handle)

 def _handle_return_func(x):
     """Return function"""
     handle = x.v_handle
@@ -243,18 +245,11 @@ def _handle_return_func(x):
         handle = FunctionHandle(handle)
     return _CLASS_FUNCTION(handle, False)
-class ObjectBase(object):
-    __slots__ = ["handle"]
-
-    def __init__(self, handle):
-        self.handle = handle
-
 # setup return handle for function type
 _node.__init_by_constructor__ = __init_handle_by_constructor__
 RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
 RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
 RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
-RETURN_SWITCH[TypeCode.OBJECT] = lambda x: _CLASS_OBJECT(x.v_handle)
 C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
     _handle_return_func, TypeCode.FUNC_HANDLE)
 C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH
ObjectHandle = ctypes.c_void_p
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
def _register_object(index, cls):
"""register object class"""
OBJECT_TYPE[index] = cls
def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tag = ctypes.c_int()
check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag)))
cls = OBJECT_TYPE.get(tag.value, ObjectBase)
obj = cls(handle)
return obj
RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
     kStr = 11
     kBytes = 12
     kNDArrayContainer = 13
-    kObject = 14
+    kObjectCell = 14
     kExtBegin = 15

 cdef extern from "tvm/runtime/c_runtime_api.h":
@@ -130,6 +130,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
     int TVMArrayToDLPack(DLTensorHandle arr_from,
                          DLManagedTensor** out)
     void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
+    int TVMGetObjectTag(ObjectHandle obj, int* tag)

 cdef extern from "tvm/c_dsl_api.h":
     int TVMNodeFree(NodeHandle handle)
......
@@ -19,3 +19,4 @@ include "./base.pxi"
 include "./node.pxi"
 include "./function.pxi"
 include "./ndarray.pxi"
+include "./vmobj.pxi"
@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
         if (tcode == kNodeHandle or
             tcode == kFuncHandle or
             tcode == kModuleHandle or
-            tcode == kObject or
+            tcode == kObjectCell or
             tcode > kExtBegin):
             CALL(TVMCbArgToReturn(&value, tcode))
@@ -160,7 +160,7 @@ cdef inline int make_arg(object arg,
         tcode[0] = kModuleHandle
     elif isinstance(arg, _CLASS_OBJECT):
         value[0].v_handle = c_handle(arg.handle)
-        tcode[0] = kObject
+        tcode[0] = kObjectCell
     elif isinstance(arg, FunctionBase):
         value[0].v_handle = (<FunctionBase>arg).chandle
         tcode[0] = kFuncHandle
@@ -212,8 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
         fobj = _CLASS_FUNCTION(None, False)
         (<FunctionBase>fobj).chandle = value.v_handle
         return fobj
-    elif tcode == kObject:
-        return _CLASS_OBJECT(ctypes_handle(value.v_handle))
+    elif tcode == kObjectCell:
+        return make_ret_object(value.v_handle)
     elif tcode in _TVM_EXT_RET:
         return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
@@ -310,27 +310,6 @@ cdef class FunctionBase:
         FuncCall(self.chandle, args, &ret_val, &ret_tcode)
         return make_ret(ret_val, ret_tcode)
-
-cdef class ObjectBase:
-    cdef ObjectHandle chandle
-
-    cdef inline _set_handle(self, handle):
-        if handle is None:
-            self.chandle = NULL
-        else:
-            self.chandle = c_handle(handle)
-
-    property handle:
-        def __get__(self):
-            if self.chandle == NULL:
-                return None
-            else:
-                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
-        def __set__(self, value):
-            self._set_handle(value)
-
-    def __init__(self, handle):
-        self._set_handle(handle)
-
 _CLASS_FUNCTION = None
 _CLASS_MODULE = None
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Maps object type to its constructor"""
OBJECT_TYPE = []
def _register_object(int index, object cls):
"""register node class"""
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
cdef int tag
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMGetObjectTag(chandle, &tag))
if tag < len(object_type):
cls = object_type[tag]
if cls is not None:
obj = cls(handle)
else:
obj = ObjectBase(handle)
else:
obj = ObjectBase(handle)
return obj
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
@@ -22,6 +22,7 @@ from __future__ import absolute_import
 import sys
 import ctypes
 from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
+from . import vmobj as _vmobj

 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -30,28 +31,19 @@ try:
     if _FFI_MODE == "ctypes":
         raise ImportError()
     if sys.version_info >= (3, 0):
-        from ._cy3.core import _set_class_function, _set_class_module, _set_class_object
+        from ._cy3.core import _set_class_function, _set_class_module
         from ._cy3.core import FunctionBase as _FunctionBase
-        from ._cy3.core import ObjectBase as _ObjectBase
         from ._cy3.core import convert_to_tvm_func
     else:
-        from ._cy2.core import _set_class_function, _set_class_module, _set_class_object
+        from ._cy2.core import _set_class_function, _set_class_module
         from ._cy2.core import FunctionBase as _FunctionBase
-        from ._cy2.core import ObjectBase as _ObjectBase
         from ._cy2.core import convert_to_tvm_func
 except IMPORT_EXCEPT:
     # pylint: disable=wrong-import-position
-    from ._ctypes.function import _set_class_function, _set_class_module, _set_class_object
-    from ._ctypes.function import ObjectBase as _ObjectBase
+    from ._ctypes.function import _set_class_function, _set_class_module
     from ._ctypes.function import FunctionBase as _FunctionBase
     from ._ctypes.function import convert_to_tvm_func

-class Object(_ObjectBase):
-    # TODO(@jroesch): Eventually add back introspection functionality.
-    pass
-
-_set_class_object(Object)
-
 FunctionHandle = ctypes.c_void_p

 class Function(_FunctionBase):
......
@@ -42,7 +42,7 @@ class TypeCode(object):
     STR = 11
     BYTES = 12
     NDARRAY_CONTAINER = 13
-    OBJECT = 14
+    OBJECT_CELL = 14
     EXT_BEGIN = 15
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import sys
from .base import _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.vmobj import ObjectBase as _ObjectBase
from ._ctypes.vmobj import _register_object
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 0
CLOSURE = 1
DATATYPE = 2
class Object(_ObjectBase):
"""The VM Object used in Relay virtual machine."""
def register_object(cls):
_register_object(cls.tag, cls)
return cls
_set_class_object(Object)
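
For illustration, a minimal sketch of how a wrapper class hooks into this registry; the class name is hypothetical, while Object, ObjectTag, and register_object are the definitions above. register_object reads the class-level tag attribute and records the class so the FFI return path can construct it for runtime objects carrying that tag:

# Hypothetical wrapper registered for tag 0 (ObjectTag.TENSOR); the real
# TensorObject in relay/backend/vmobj.py registers itself the same way.
from tvm._ffi.vmobj import Object, ObjectTag, register_object

@register_object
class MyTensorWrapper(Object):
    """Hypothetical example class; `tag` picks which runtime objects map to it."""
    tag = ObjectTag.TENSOR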
@@ -33,6 +33,8 @@ from . import parser
 from . import debug
 from . import param_dict
 from . import feature
+from .backend import vm
+from .backend import vmobj

 # Root operators
 from .op import Op
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The VM Object FFI namespace."""
from tvm._ffi.function import _init_api
_init_api("_vmobj", __name__)
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
 """
 The Relay Virtual Machine.
@@ -23,55 +23,38 @@ Implements a Python interface to compiling and executing on the Relay VM.
 import numpy as np
 import tvm
-from tvm._ffi.function import Object
-from .. import transform
-from ..backend.interpreter import Executor
-from ..expr import GlobalVar, Expr
 from . import _vm
+from . import vmobj as _obj
+from .interpreter import Executor

-Object = Object
-
-def optimize(mod):
-    """Perform several optimizations on a module before executing it in the
-    Relay virtual machine.
-
-    Parameters
-    ----------
-    mod : tvm.relay.Module
-        The module to optimize.
-
-    Returns
-    -------
-    ret : tvm.relay.Module
-        The optimized module.
-    """
-    main_func = mod["main"]
-
-    opt_passes = []
-    if not main_func.params and isinstance(main_func.body, GlobalVar):
-        opt_passes.append(transform.EtaExpand())
-
-    opt_passes = opt_passes + [
-        transform.SimplifyInference(),
-        transform.FuseOps(),
-        transform.InferType()
-    ]
-
-    seq = transform.Sequential(opt_passes)
-    return seq(mod)
+def _update_target(target):
+    target = target if target else tvm.target.current_target()
+    if target is None:
+        raise ValueError("Target is not set in env or passed as argument.")
+
+    tgts = {}
+    if isinstance(target, (str, tvm.target.Target)):
+        dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
+        tgts[dev_type] = tvm.target.create(target)
+    elif isinstance(target, dict):
+        for dev, tgt in target.items():
+            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
+            tgts[dev_type] = tvm.target.create(tgt)
+    else:
+        raise TypeError("target is expected to be str, tvm.target.Target, " +
+                        "or dict of str to str/tvm.target.Target, but received " +
+                        "{}".format(type(target)))
+    return tgts

 def _convert(arg, cargs):
-    if isinstance(arg, np.ndarray):
-        tensor = _vm._Tensor(tvm.nd.array(arg))
-        cargs.append(tensor)
-    elif isinstance(arg, tvm.nd.NDArray):
-        tensor = _vm._Tensor(arg)
-        cargs.append(tensor)
-    elif isinstance(arg, tuple):
+    if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
+        cargs.append(_obj.tensor_object(arg))
+    elif isinstance(arg, (tuple, list)):
         field_args = []
         for field in arg:
             _convert(field, field_args)
-        cargs.append(_vm._Tuple(*field_args))
+        cargs.append(_obj.tuple_object(field_args))
     else:
         raise "unsupported type"
@@ -82,28 +65,97 @@ def convert(args):
     return cargs
-def _eval_vm(mod, ctx, *args):
-    """
-    Evaluate a module on a given context with the provided arguments.
-
-    Parameters
-    ----------
-    mod: relay.Module
-        The module to optimize, will execute its entry_func.
-
-    ctx: tvm.Context
-        The TVM context to execute on.
-
-    args: List[tvm.NDArray, np.ndarray]
-        The arguments to evaluate.
-    """
-    mod = optimize(mod)
-    args = list(args)
-    assert isinstance(args, list)
-    cargs = convert(args)
-
+class VirtualMachine(object):
+    """Relay VM runtime."""
+    def __init__(self, mod):
+        self.mod = mod
+        self._init = self.mod["init"]
+        self._invoke = self.mod["invoke"]
+
+    def init(self, ctx):
+        """Initialize the context in the VM.
+
+        Parameters
+        ----------
+        ctx : :py:class:`TVMContext`
+            The runtime context to run the code on.
+        """
+        args = [ctx.device_type, ctx.device_id]
+        self._init(*args)
def invoke(self, func_name, *args):
"""Invoke a function.
Parameters
----------
func_name : str
The name of the function.
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
Returns
-------
result : Object
The output.
"""
cargs = convert(args)
return self._invoke(func_name, *cargs)
def run(self, *args):
"""Run the main function.
Parameters
----------
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
Returns
-------
result : Object
The output.
"""
return self.invoke("main", *args)
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
def compile(self, mod, target=None, target_host=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if the target is a device.
When TVM compiles device-specific programs such as CUDA,
we also need host (CPU) side code to interact with the driver
to set up the dimensions and parameters correctly.
target_host is used to specify the host-side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
target = _update_target(target)
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
-    result = _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs)
-    return result
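
A minimal end-to-end sketch of the two classes above, assuming the module is reachable as relay.vm (relay/__init__.py imports it in this commit); the identity function is purely illustrative:

# Sketch: compile a Relay module with VMCompiler, then run it on the
# VirtualMachine wrapper defined above.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2))
mod = relay.Module.from_expr(relay.Function([x], x))

compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target="llvm")
vm.init(tvm.cpu())
res = vm.run(np.random.rand(2, 2).astype("float32"))
print(res.asnumpy())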
 class VMExecutor(Executor):
     """
@@ -126,19 +178,21 @@ class VMExecutor(Executor):
         The target option to build the function.
     """
     def __init__(self, mod, ctx, target):
+        if mod is None:
+            raise RuntimeError("Must provide module to get VM executor.")
         self.mod = mod
         self.ctx = ctx
         self.target = target
+        compiler = VMCompiler()
+        self.vm = compiler.compile(mod, target)
+        self.vm.init(ctx)

     def _make_executor(self, expr=None):
-        expr = expr if expr else self.mod
-        assert expr, "either expr or self.mod should be not null."
-        if isinstance(expr, Expr):
-            self.mod["main"] = expr
+        assert expr is None
         main = self.mod["main"]

         def _vm_wrapper(*args, **kwargs):
             args = self._convert_args(main, args, kwargs)
-            return _eval_vm(self.mod, self.ctx, *args)
+            return self.vm.run(*args)

         return _vm_wrapper
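
For completeness, a hedged sketch of driving the executor directly; evaluate() with no expression returns the wrapper produced by _make_executor above, and the add function is illustrative:

# Sketch: evaluate a module's "main" through the VM-backed executor.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.backend.vm import VMExecutor

x = relay.var("x", shape=(4,))
mod = relay.Module.from_expr(relay.Function([x], x + x))

executor = VMExecutor(mod, tvm.cpu(), "llvm")
f = executor.evaluate()               # wraps vm.run via _vm_wrapper
print(f(np.arange(4, dtype="float32")).asnumpy())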
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
import numpy as _np
from tvm._ffi.vmobj import Object, ObjectTag, register_object
from tvm import ndarray as _nd
from . import _vmobj
# TODO(@icemelon9): Add ClosureObject
@register_object
class TensorObject(Object):
"""Tensor object."""
tag = ObjectTag.TENSOR
def __init__(self, handle):
"""Constructs a Tensor object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : TensorObject
A tensor object.
"""
super(TensorObject, self).__init__(handle)
self.data = _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
"""
return self.data.asnumpy()
@register_object
class DatatypeObject(Object):
"""Datatype object."""
tag = ObjectTag.DATATYPE
def __init__(self, handle):
"""Constructs a Datatype object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : DatatypeObject
A Datatype object.
"""
super(DatatypeObject, self).__init__(handle)
self.tag = _vmobj.GetDatatypeTag(self)
num_fields = _vmobj.GetDatatypeNumberOfFields(self)
self.fields = []
for i in range(num_fields):
self.fields.append(_vmobj.GetDatatypeFields(self, i))
def __getitem__(self, idx):
return self.fields[idx]
def __len__(self):
return len(self.fields)
def __iter__(self):
return iter(self.fields)
# TODO(icemelon9): Add closure object
def tensor_object(arr, ctx=_nd.cpu(0)):
"""Create a tensor object from source arr.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : TensorObject
The created object.
"""
if isinstance(arr, _np.ndarray):
tensor = _vmobj.Tensor(_nd.array(arr, ctx))
elif isinstance(arr, _nd.NDArray):
tensor = _vmobj.Tensor(arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
return tensor
def tuple_object(fields):
"""Create a datatype object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Tuple(*fields)
def datatype_object(tag, fields):
"""Create a datatype object from tag and source fields.
Parameters
----------
tag : int
The tag of datatype.
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Datatype(tag, *fields)
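
A short usage sketch of the helpers above; the array values are illustrative:

# Sketch: build VM objects from host data and read tensor fields back.
import numpy as np
from tvm.relay.backend import vmobj

t = vmobj.tensor_object(np.ones((2, 3), dtype="float32"))
print(t.asnumpy().shape)             # (2, 3)

tup = vmobj.tuple_object([t, t])
print(tup[0].asnumpy().sum())        # 6.0; fields come back as TensorObjects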
@@ -529,6 +529,18 @@ def CanonicalizeCast():
     return _transform.CanonicalizeCast()
def LambdaLift():
"""
Lift the closure to global function.
Returns
-------
ret : tvm.relay.Pass
The registered pass that lifts the lambda function.
"""
return _transform.LambdaLift()
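
For illustration, a hedged sketch of applying the new pass; the module below is made up, and InferType is run first only as a conservative assumption about what the pass expects:

# Sketch: lift the nested closure in `outer` into its own global function,
# which is the form the VM compiler expects before code generation.
import tvm
from tvm import relay

x = relay.var("x", shape=(2,))
y = relay.var("y", shape=(2,))
inner = relay.Function([y], x + y)              # closes over x
outer = relay.Function([x], relay.Call(inner, [x]))
mod = relay.Module.from_expr(outer)

seq = relay.transform.Sequential([relay.transform.InferType(),
                                  relay.transform.LambdaLift()])
mod = seq(mod)
print(mod)                                      # the closure now appears as a new GlobalVar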
 def PrintIR():
     """
     Print the IR for a module to help debugging.
......
@@ -24,6 +24,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/build_module.h>
 #include <tvm/runtime/device_api.h>
+#include <tvm/runtime/vm.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/transform.h>
 #include <memory>
......
@@ -63,6 +63,7 @@ using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
 using GlobalMap = NodeMap<GlobalVar, Index>;
 using ConstMap = NodeMap<Constant, Index>;
 using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
+using TargetsMap = Map<tvm::Integer, tvm::Target>;

 struct VMCompilerContext {
   // The module context for the compilation
@@ -153,47 +154,134 @@ struct AccessField : MatchValue {
   ~AccessField() {}
 };
struct VMCompiler; /*!
* \brief Condition in a decision tree
*/
struct ConditionNode {
virtual ~ConditionNode() {}
};
using ConditionNodePtr = std::shared_ptr<ConditionNode>;
/*! /*!
* \brief Compile a pattern match expression * \brief A var binding condition
* It first converts the pattern match expression into a desicision tree, the condition
* could be object comparison or variable binding. If any of the condition fails in a clause,
* the decision tree switches to check the conditions of next clause and so on. If no clause
* matches the value, a fatal node is inserted.
*
* After the decision tree is built, we convert it into bytecodes using If/Goto.
*/ */
void CompileMatch(Match match, VMCompiler* compiler); struct VarBinding : ConditionNode {
Var var;
MatchValuePtr val;
struct VMCompiler : ExprFunctor<void(const Expr& expr)> { VarBinding(Var var, MatchValuePtr val)
/*! \brief Store the expression a variable points to. */ : var(var), val(val) {}
std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map;
std::vector<Instruction> instructions; ~VarBinding() {}
};
// var -> register num /*!
std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map; * \brief Compare the tag of the object
*/
struct TagCompare : ConditionNode {
/*! \brief The object to be examined */
MatchValuePtr obj;
size_t last_register; /*! \brief The expected tag */
int target_tag;
// Total number of virtual registers allocated TagCompare(MatchValuePtr obj, size_t target)
size_t registers_num; : obj(obj), target_tag(target) {
CompileEngine engine; }
/*! \brief Global shared meta data */ ~TagCompare() {}
VMCompilerContext* context; };
using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;
VMCompiler(VMCompilerContext* context) TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
: instructions(), Pattern pattern,
var_register_map(), TreeNodePtr then_branch,
last_register(0), TreeNodePtr else_branch) {
registers_num(0), if (pattern.as<PatternWildcardNode>()) {
engine(CompileEngine::Global()), // We ignore wildcard binding since it's not producing new vars
context(context) return then_branch;
{} } else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pat = pattern.as<PatternConstructorNode>();
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
size_t NewRegister() { return registers_num++; } size_t field_index = 0;
for (auto& p : pattern->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
}
}
TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data,
Clause clause,
TreeNodePtr else_branch) {
return BuildDecisionTreeFromPattern(data, clause->lhs,
TreeLeafNode::Make(clause->rhs), else_branch);
}
TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
// When nothing matches, the VM throws fatal error
TreeNodePtr else_branch = TreeLeafFatalNode::Make();
// Start from the last clause
for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
}
return else_branch;
}
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets)
: last_register_(0),
registers_num_(0),
engine_(CompileEngine::Global()),
context_(context),
targets_(targets) {}
VMFunction Compile(const GlobalVar& var, const Function& func) {
size_t i = 0;
// We then assign register num to the free variables
for (auto param : func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map_.insert({param, arg_register});
params_.push_back(param->name_hint());
++i;
}
if (IsClosure(func)) {
Function inner_func = Downcast<Function>(func->body);
for (auto param : inner_func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map_.insert({param, arg_register});
params_.push_back(param->name_hint());
++i;
}
this->VisitExpr(inner_func->body);
} else {
this->VisitExpr(func->body);
}
instructions_.push_back(Instruction::Ret(last_register_));
return VMFunction(var->name_hint, params_, instructions_, registers_num_);
}
protected:
size_t NewRegister() { return registers_num_++; }
   inline void Emit(const Instruction& instr) {
     DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
@@ -210,10 +298,10 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
       case Opcode::AllocClosure:
       case Opcode::Move:
       case Opcode::InvokeClosure:
-        last_register = instr.dst;
+        last_register_ = instr.dst;
         break;
       case Opcode::InvokePacked:
-        last_register = instr.packed_args[instr.arity - 1];
+        last_register_ = instr.packed_args[instr.arity - 1];
         break;
       case Opcode::If:
       case Opcode::Ret:
@@ -221,21 +309,21 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
       case Opcode::Fatal:
         break;
     }
-    instructions.push_back(instr);
+    instructions_.push_back(instr);
   }
void VisitExpr_(const ConstantNode* const_node) { void VisitExpr_(const ConstantNode* const_node) {
auto rconst = GetRef<Constant>(const_node); auto rconst = GetRef<Constant>(const_node);
auto it = this->context->const_map.find(rconst); auto it = this->context_->const_map.find(rconst);
CHECK(it != this->context->const_map.end()); CHECK(it != this->context_->const_map.end());
Emit(Instruction::LoadConst(it->second, NewRegister())); Emit(Instruction::LoadConst(it->second, NewRegister()));
} }
void VisitExpr_(const VarNode* var_node) { void VisitExpr_(const VarNode* var_node) {
auto var = GetRef<Var>(var_node); auto var = GetRef<Var>(var_node);
auto reg_it = this->var_register_map.find(var); auto reg_it = this->var_register_map_.find(var);
CHECK(reg_it != this->var_register_map.end()); CHECK(reg_it != this->var_register_map_.end());
last_register = reg_it->second; last_register_ = reg_it->second;
} }
void VisitExpr_(const TupleNode* tuple_node) { void VisitExpr_(const TupleNode* tuple_node) {
...@@ -244,7 +332,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -244,7 +332,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
for (auto& field : tuple->fields) { for (auto& field : tuple->fields) {
this->VisitExpr(field); this->VisitExpr(field);
fields_registers.push_back(last_register); fields_registers.push_back(last_register_);
} }
// TODO(@jroesch): use correct tag // TODO(@jroesch): use correct tag
...@@ -259,29 +347,28 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -259,29 +347,28 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
auto match = GetRef<Match>(match_node); auto match = GetRef<Match>(match_node);
this->VisitExpr(match->data); this->VisitExpr(match->data);
CompileMatch(match, this); CompileMatch(match);
} }
void VisitExpr_(const LetNode* let_node) { void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << let_node->value; DLOG(INFO) << let_node->value;
this->VisitExpr(let_node->value); this->VisitExpr(let_node->value);
DLOG(INFO) << this->last_register; var_register_map_.insert({let_node->var, this->last_register_});
var_register_map.insert({let_node->var, this->last_register});
this->VisitExpr(let_node->body); this->VisitExpr(let_node->body);
} }
void VisitExpr_(const TupleGetItemNode* get_node) { void VisitExpr_(const TupleGetItemNode* get_node) {
auto get = GetRef<TupleGetItem>(get_node); auto get = GetRef<TupleGetItem>(get_node);
this->VisitExpr(get->tuple); this->VisitExpr(get->tuple);
auto tuple_register = last_register; auto tuple_register = last_register_;
Emit(Instruction::GetField(tuple_register, get->index, NewRegister())); Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
} }
void VisitExpr_(const GlobalVarNode* gvar) { void VisitExpr_(const GlobalVarNode* gvar) {
auto var = GetRef<GlobalVar>(gvar); auto var = GetRef<GlobalVar>(gvar);
auto func = this->context->module->Lookup(var); auto func = context_->module->Lookup(var);
auto it = this->context->global_map.find(var); auto it = context_->global_map.find(var);
CHECK(it != this->context->global_map.end()); CHECK(it != context_->global_map.end());
// Allocate closure with zero free vars // Allocate closure with zero free vars
Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister())); Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
} }
...@@ -289,30 +376,30 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -289,30 +376,30 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
void VisitExpr_(const IfNode* if_node) { void VisitExpr_(const IfNode* if_node) {
this->VisitExpr(if_node->cond); this->VisitExpr(if_node->cond);
size_t test_register = last_register; size_t test_register = last_register_;
this->Emit(Instruction::LoadConsti(1, NewRegister())); this->Emit(Instruction::LoadConsti(1, NewRegister()));
auto after_cond = this->instructions.size(); auto after_cond = instructions_.size();
auto target_register = this->last_register; auto target_register = last_register_;
this->Emit(Instruction::If(test_register, target_register, 0, 0)); this->Emit(Instruction::If(test_register, target_register, 0, 0));
this->VisitExpr(if_node->true_branch); this->VisitExpr(if_node->true_branch);
size_t true_register = last_register; size_t true_register = last_register_;
Emit(Instruction::Goto(0)); Emit(Instruction::Goto(0));
// Finally store how many instructions there are in the // Finally store how many instructions there are in the
// true branch. // true branch.
auto after_true = this->instructions.size(); auto after_true = this->instructions_.size();
this->VisitExpr(if_node->false_branch); this->VisitExpr(if_node->false_branch);
size_t false_register = last_register; size_t false_register = last_register_;
// In else-branch, override the then-branch register // In else-branch, override the then-branch register
Emit(Instruction::Move(false_register, true_register)); Emit(Instruction::Move(false_register, true_register));
// Compute the total number of instructions // Compute the total number of instructions
// after generating false. // after generating false.
auto after_false = this->instructions.size(); auto after_false = this->instructions_.size();
// Now we will compute the jump targets in order // Now we will compute the jump targets in order
// to properly patch the instruction with the // to properly patch the instruction with the
...@@ -322,13 +409,13 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -322,13 +409,13 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// we patch up the if instruction, and goto. // we patch up the if instruction, and goto.
auto true_offset = 1; auto true_offset = 1;
auto false_offset = after_true - after_cond; auto false_offset = after_true - after_cond;
this->instructions[after_cond].if_op.true_offset = true_offset; instructions_[after_cond].if_op.true_offset = true_offset;
this->instructions[after_cond].if_op.false_offset = false_offset; instructions_[after_cond].if_op.false_offset = false_offset;
// Patch the Goto. // Patch the Goto.
this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1; this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;
this->last_register = true_register; this->last_register_ = true_register;
} }
Instruction AllocTensorFromType(const TensorTypeNode* ttype) { Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
...@@ -399,18 +486,27 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -399,18 +486,27 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// Next generate the invoke instruction. // Next generate the invoke instruction.
CHECK(func->IsPrimitive()); CHECK(func->IsPrimitive());
auto target = Target::Create("llvm"); Target target;
if (targets_.size() == 1) {
// homogeneous execution.
for (auto kv : targets_) {
target = kv.second;
}
} else {
// heterogeneous execution.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
}
auto key = CCacheKeyNode::make(func, target); auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine->Lower(key); auto cfunc = engine_->Lower(key);
// TODO(jroesch): support lowered funcs for multiple targets // TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1); CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1; auto op_index = -1;
if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) { if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = this->context->lowered_funcs.size(); op_index = context_->lowered_funcs.size();
this->context->lowered_funcs.push_back(cfunc->funcs[0]); context_->lowered_funcs.push_back(cfunc->funcs[0]);
this->context->seen_funcs[cfunc->funcs[0]] = op_index; context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else { } else {
op_index = this->context->seen_funcs[cfunc->funcs[0]]; op_index = context_->seen_funcs[cfunc->funcs[0]];
} }
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs)); Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
...@@ -430,7 +526,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -430,7 +526,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
for (auto arg : call_node->args) { for (auto arg : call_node->args) {
this->VisitExpr(arg); this->VisitExpr(arg);
args_registers.push_back(last_register); args_registers.push_back(last_register_);
} }
Expr op = call_node->op; Expr op = call_node->op;
...@@ -440,12 +536,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -440,12 +536,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type()); EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type());
} else if (auto global_node = op.as<GlobalVarNode>()) { } else if (auto global_node = op.as<GlobalVarNode>()) {
auto global = GetRef<GlobalVar>(global_node); auto global = GetRef<GlobalVar>(global_node);
auto it = this->context->global_map.find(global); auto it = context_->global_map.find(global);
CHECK(it != this->context->global_map.end()); CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second; << " with func_index=" << it->second;
auto func = this->context->module->Lookup(global); auto func = context_->module->Lookup(global);
if (IsClosure(func)) { if (IsClosure(func)) {
auto arity = func->params.size(); auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
...@@ -458,7 +554,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -458,7 +554,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
NewRegister())); NewRegister()));
} else if (auto var_node = op.as<VarNode>()) { } else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node)); VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister())); Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
} else { } else {
LOG(FATAL) << "unsupported case in vm compiler: " << op; LOG(FATAL) << "unsupported case in vm compiler: " << op;
} }
...@@ -472,317 +568,257 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -472,317 +568,257 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
} }
} }
void CompileClosure(const Function& func) { /*!
// We first layout the function arguments. * \brief Compile a match value
auto inner_func = Downcast<Function>(func->body); * Generate byte code that compute the value specificed in val
*
size_t i = 0; * \return The register number assigned for the final value
for (auto param : inner_func->params) { */
auto arg_register = NewRegister(); RegName CompileMatchValue(MatchValuePtr val) {
CHECK_EQ(i, arg_register); if (std::dynamic_pointer_cast<RegisterValue>(val)) {
var_register_map.insert({param, arg_register}); auto r = std::dynamic_pointer_cast<RegisterValue>(val);
i++; return r->rergister_num;
} } else {
auto path = std::dynamic_pointer_cast<AccessField>(val);
// We then assign register num to the free variables auto p = CompileMatchValue(path->parent);
for (auto param : func->params) { Emit(Instruction::GetField(p, path->index, NewRegister()));
auto arg_register = NewRegister(); path->reg = last_register_;
CHECK_EQ(i, arg_register); return path->reg;
var_register_map.insert({param, arg_register});
i++;
} }
// We will now process the body like normal.
this->VisitExpr(inner_func->body);
} }
void Compile(const Function& func) { void CompileTreeNode(TreeNodePtr tree) {
// We need to generate code specially for lifted closures. if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
if (IsClosure(func)) { auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
CompileClosure(func); VisitExpr(node->body);
return; } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
} Emit(Instruction::Fatal());
} else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
for (size_t i = 0; i < func->params.size(); ++i) { auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
auto arg_register = NewRegister(); if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
CHECK_EQ(arg_register, i); // For Tag compariton, generate branches
var_register_map.insert({func->params[i], arg_register}); auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
auto r = CompileMatchValue(cond->obj);
Emit(Instruction::GetTag(r, NewRegister()));
auto operand1 = last_register_;
Emit(Instruction::LoadConsti(cond->target_tag, NewRegister()));
auto operand2 = last_register_;
Emit(Instruction::If(operand1, operand2, 1, 0));
auto cond_offset = instructions_.size() - 1;
CompileTreeNode(node->then_branch);
auto if_reg = last_register_;
Emit(Instruction::Goto(1));
auto goto_offset = instructions_.size() - 1;
CompileTreeNode(node->else_branch);
auto else_reg = last_register_;
Emit(Instruction::Move(else_reg, if_reg));
last_register_ = if_reg;
auto else_offset = instructions_.size() - 1;
// Fixing offsets
instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
} else {
// For other non-branch conditions, move to then_branch directly
auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[cond->var] = CompileMatchValue(cond->val);
CompileTreeNode(node->then_branch);
}
} }
this->VisitExpr(func->body);
} }
};
/*! /*!
* \brief Compile a match value * \brief Compile a pattern match expression
* Generate byte code that compute the value specificed in val * It first converts the pattern match expression into a desicision tree, the condition
* * could be object comparison or variable binding. If any of the condition fails in a clause,
* \return The register number assigned for the final value * the decision tree switches to check the conditions of next clause and so on. If no clause
*/ * matches the value, a fatal node is inserted.
RegName CompileMatchValue(MatchValuePtr val, VMCompiler* compiler) { *
if (std::dynamic_pointer_cast<RegisterValue>(val)) { * After the decision tree is built, we convert it into bytecodes using If/Goto.
auto r = std::dynamic_pointer_cast<RegisterValue>(val); */
return r->rergister_num; void CompileMatch(Match match) {
} else { auto data = std::make_shared<RegisterValue>(last_register_);
auto path = std::dynamic_pointer_cast<AccessField>(val); auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
auto p = CompileMatchValue(path->parent, compiler); CompileTreeNode(decision_tree);
compiler->Emit(Instruction::GetField(p, path->index, compiler->NewRegister()));
path->reg = compiler->last_register;
return path->reg;
} }
}
/*! protected:
* \brief Condition in a decision tree /*! \brief Store the expression a variable points to. */
*/ std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map_;
struct ConditionNode { /*! \brief Instructions in the VMFunction. */
virtual ~ConditionNode() {} std::vector<Instruction> instructions_;
/*! \brief Parameter names of the function. */
std::vector<std::string> params_;
/*! \brief Map from var to register number. */
std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map_;
/*! \brief Last used register number. */
size_t last_register_;
/*! \brief Total number of virtual registers allocated. */
size_t registers_num_;
/*! \brief Compiler engine to lower primitive functions. */
CompileEngine engine_;
/*! \brief Global shared meta data */
VMCompilerContext* context_;
/*! \brief Target devices. */
TargetsMap targets_;
}; };
using ConditionNodePtr = std::shared_ptr<ConditionNode>;
/*!
* \brief A var binding condition
*/
struct VarBinding : ConditionNode {
Var var;
MatchValuePtr val;
VarBinding(Var var, MatchValuePtr val)
: var(var), val(val) {}
~VarBinding() {}
};
/*! class VMCompiler : public runtime::ModuleNode {
* \brief Compare the tag of the object public:
*/ PackedFunc GetFunction(const std::string& name,
struct TagCompare : ConditionNode { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
/*! \brief The object to be examined */ if (name == "compile") {
MatchValuePtr obj; return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
this->Compile(args[0], args[1], args[2]);
});
} else if (name == "get_vm") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
/*! \brief The expected tag */ const char* type_key() const final {
int target_tag; return "VMCompiler";
}
TagCompare(MatchValuePtr obj, size_t target) std::shared_ptr<VirtualMachine> GetVirtualMachine() const {
: obj(obj), target_tag(target) { return vm_;
} }
~TagCompare() {} void Compile(const Module& mod_ref,
}; const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
targets_ = targets;
target_host_ = target_host;
vm_ = std::make_shared<VirtualMachine>();
// Run some optimizations first; this code should
// be moved to the pass manager.
context_.module = OptimizeModule(mod_ref);
// Populate the global map.
//
// This maps global variables to a global index
// in the VMFunction table.
PopulateGlobalMap();
// Next we populate constant map.
auto constant_analysis_result = LayoutConstantPool(context_.module);
context_.const_map = std::get<0>(constant_analysis_result);
context_.const_tensor_shape_map = std::get<1>(constant_analysis_result);
// Next we get ready by allocating space for
// the global state.
vm_->functions.resize(context_.module->functions.size());
vm_->constants.resize(context_.const_map.size() + context_.const_tensor_shape_map.size());
for (auto pair : context_.const_map) {
vm_->constants[pair.second] = Object::Tensor(pair.first->data);
}
using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer; for (auto pair : context_.const_tensor_shape_map) {
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>; vm_->constants[pair.second.first] = Object::Tensor(pair.second.second);
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>; }
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;
void CompileTreeNode(TreeNodePtr tree, VMCompiler* compiler) { for (auto named_func : context_.module->functions) {
if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) { auto gvar = named_func.first;
auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree); auto func = named_func.second;
compiler->VisitExpr(node->body); VMFunctionCompiler func_compiler(&context_, targets_);
} else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) { auto vm_func = func_compiler.Compile(gvar, func);
compiler->Emit(Instruction::Fatal());
} else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) { size_t func_index = context_.global_map.at(gvar);
auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree); CHECK(func_index < vm_->functions.size());
if (std::dynamic_pointer_cast<TagCompare>(node->cond)) { vm_->functions[func_index] = vm_func;
// For tag comparison, generate branches
auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
auto r = CompileMatchValue(cond->obj, compiler);
compiler->Emit(Instruction::GetTag(r, compiler->NewRegister()));
auto operand1 = compiler->last_register;
compiler->Emit(Instruction::LoadConsti(cond->target_tag, compiler->NewRegister()));
auto operand2 = compiler->last_register;
compiler->Emit(Instruction::If(operand1, operand2, 1, 0));
auto cond_offset = compiler->instructions.size() - 1;
CompileTreeNode(node->then_branch, compiler);
auto if_reg = compiler->last_register;
compiler->Emit(Instruction::Goto(1));
auto goto_offset = compiler->instructions.size() - 1;
CompileTreeNode(node->else_branch, compiler);
auto else_reg = compiler->last_register;
compiler->Emit(Instruction::Move(else_reg, if_reg));
compiler->last_register = if_reg;
auto else_offset = compiler->instructions.size() - 1;
// Fixing offsets
compiler->instructions[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
compiler->instructions[goto_offset].pc_offset = else_offset - goto_offset + 1;
} else {
// For other non-branch conditions, move to then_branch directly
auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
compiler->var_register_map[cond->var] = CompileMatchValue(cond->val, compiler);
CompileTreeNode(node->then_branch, compiler);
} }
#ifdef USE_RELAY_DEBUG
for (auto vm_func : vm.functions) {
std::cout << "Function: " << vm_func.name << std::endl
<< vm_func << "-------------" << std::endl;
} }
} #endif // USE_RELAY_DEBUG
TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, PopulatePackedFuncMap();
Pattern pattern,
TreeNodePtr then_branch,
TreeNodePtr else_branch) {
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
} else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pat = pattern.as<PatternConstructorNode>();
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
size_t field_index = 0; for (auto gv : context_.global_map) {
for (auto& p : pattern->patterns) { vm_->global_map.insert({gv.first->name_hint, gv.second});
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
} }
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} }
}
TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data, protected:
Clause clause, Module OptimizeModule(const Module& mod) {
TreeNodePtr else_branch) { // TODO(@icemelon9): check number of targets and build config, add more optimization passes
return BuildDecisionTreeFromPattern(data, clause->lhs, transform::Sequential seq({transform::SimplifyInference(),
TreeLeafNode::Make(clause->rhs), else_branch); transform::ToANormalForm(),
} transform::InlinePrimitives(),
transform::LambdaLift(),
TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) { transform::InlinePrimitives(),
// When nothing matches, the VM throws a fatal error transform::FuseOps()});
TreeNodePtr else_branch = TreeLeafFatalNode::Make(); auto pass_ctx = transform::PassContext::Create();
// Start from the last clause tvm::With<relay::transform::PassContext> ctx(pass_ctx);
for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) { return seq(mod);
else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
} }
return else_branch;
}
void CompileMatch(Match match, VMCompiler* compiler) { void PopulateGlobalMap() {
auto data = std::make_shared<RegisterValue>(compiler->last_register); // First we populate global map.
auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses); size_t global_index = 0;
CompileTreeNode(decision_tree, compiler); for (auto named_func : context_.module->functions) {
} auto gvar = named_func.first;
context_.global_map.insert({gvar, global_index++});
}
}
void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs, void PopulatePackedFuncMap() {
std::vector<PackedFunc>* packed_funcs) { auto const& lowered_funcs = context_.lowered_funcs;
runtime::Module mod; if (lowered_funcs.size() == 0) {
if (lowered_funcs.size() > 0) { return;
// TODO(@jroesch): we need to read target from build config }
Target target = Target::Create("llvm"); runtime::Module mod;
// TODO(@icemelon9): support heterogeneous targets
Target target;
for (auto kv : targets_) {
target = kv.second;
}
if (const auto* f = runtime::Registry::Get("relay.backend.build")) { if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target); mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()),
target, target_host_);
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; LOG(FATAL) << "relay.backend.build is not registered";
} }
CHECK(mod.operator->()); CHECK(mod.operator->());
for (auto lfunc : lowered_funcs) { for (auto lfunc : lowered_funcs) {
packed_funcs->push_back(mod.GetFunction(lfunc->name)); vm_->packed_funcs.push_back(mod.GetFunction(lfunc->name));
} }
} }
}
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) { protected:
DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false); /*! \brief Target devices. */
size_t params = func->params.size(); TargetsMap targets_;
VMCompiler compiler(context); /*! \brief Target host device. */
compiler.Compile(func); tvm::Target target_host_;
// return the last evaluated expression /*! \brief Global shared meta data */
compiler.instructions.push_back(Instruction::Ret(compiler.last_register)); VMCompilerContext context_;
/*! \brief Compiled virtual machine. */
// Would like to refactor this so we only check if closure once. std::shared_ptr<VirtualMachine> vm_;
if (IsClosure(func)) { };
auto inner_params = Downcast<Function>(func->body)->params.size();
return VMFunction(var->name_hint, params + inner_params, compiler.instructions,
compiler.registers_num);
} else {
return VMFunction(var->name_hint, params, compiler.instructions, compiler.registers_num);
}
}
Module OptimizeModule(const Module& mod) {
transform::Sequential seq({transform::ToANormalForm(),
transform::InlinePrimitives(),
transform::LambdaLift(),
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
}
void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) { runtime::Module CreateVMCompiler() {
// First we populate global map. std::shared_ptr<VMCompiler> exec = std::make_shared<VMCompiler>();
size_t global_index = 0; return runtime::Module(exec);
for (auto named_func : mod->functions) {
auto gvar = named_func.first;
global_map->insert({gvar, global_index++});
}
} }
VirtualMachine CompileModule(const Module& mod_ref) { TVM_REGISTER_GLOBAL("relay._vm._VMCompiler")
Module mod = mod_ref; .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateVMCompiler();
// Run some optimizations first; this code should });
// be moved to the pass manager.
mod = OptimizeModule(mod);
VirtualMachine vm;
VMCompilerContext context;
context.module = mod;
// Populate the global map.
//
// This maps global variables to a global index
// in the VMFunction table.
PopulateGlobalMap(&context.global_map, mod);
// Next we populate constant map.
auto constant_analysis_result = LayoutConstantPool(mod);
context.const_map = std::get<0>(constant_analysis_result);
context.const_tensor_shape_map = std::get<1>(constant_analysis_result);
// Next we get ready by allocating space for
// the global state.
vm.functions.resize(mod->functions.size());
vm.constants.resize(context.const_map.size() + context.const_tensor_shape_map.size());
for (auto pair : context.const_map) {
vm.constants[pair.second] = Object::Tensor(pair.first->data);
}
for (auto pair : context.const_tensor_shape_map) {
vm.constants[pair.second.first] = Object::Tensor(pair.second.second);
}
for (auto named_func : mod->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
auto vm_func = CompileFunc(&context, gvar, func);
size_t func_index = context.global_map.at(gvar);
CHECK(func_index < vm.functions.size());
vm.functions[func_index] = vm_func;
}
#ifdef USE_RELAY_DEBUG
for (auto vm_func : vm.functions) {
std::cout << "Function: " << vm_func.name << std::endl
<< vm_func << "-------------" << std::endl;
}
#endif // USE_RELAY_DEBUG
PopulatePackedFuncMap(context.lowered_funcs, &vm.packed_funcs);
for (auto gv : context.global_map) {
vm.global_map_.insert({gv.first->name_hint, gv.second});
}
return vm;
}
} // namespace vm } // namespace vm
} // namespace relay } // namespace relay
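The packed functions registered above ("compile", "get_vm", and the "relay._vm._VMCompiler" constructor) are what the new Python front end drives. A hedged usage sketch follows, based on the updated test helper later in this commit (relay.vm.VMCompiler, vm.init, vm.run); the exact wrapper signatures are assumed from those tests rather than from this file.

```python
import numpy as np
import tvm
from tvm import relay

# Build a trivial module: main(x) = x + x.
x = relay.var("x", shape=(10, 10))
mod = relay.Module()
mod["main"] = relay.Function([x], x + x)

# Compile with the new VMCompiler, initialize the VM, and run it.
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, "llvm")      # invokes the "compile" packed function
vm.init(tvm.cpu())                      # invokes the VM's "init" packed function
data = np.random.rand(10, 10).astype("float32")
print(vm.run(data).asnumpy())
```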
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/vm.cc
* \brief The Relay virtual machine.
*/
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/module.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/analysis.h>
namespace tvm {
namespace relay {
namespace vm {
runtime::vm::VirtualMachine CompileModule(const Module& mod);
using tvm::runtime::Object;
using tvm::runtime::ObjectTag;
using tvm::runtime::vm::VirtualMachine;
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
auto vm = CompileModule(module);
vm.Init(ctxs);
return vm;
}
Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<Object>& vm_args) {
VirtualMachine vm = FromModule(module, ctxs);
// TODO(zhiics): This measurement is for temporary usage. Remove it later. We
// need to introduce a better profiling method.
#if ENABLE_PROFILING
DLOG(INFO) << "Entry function is main." << std::endl;
auto start = std::chrono::high_resolution_clock::now();
#endif // ENABLE_PROFILING
Object res = vm.Invoke("main", vm_args);
#if ENABLE_PROFILING
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
LOG(INFO) << "Inference time: " << duration << "ms\n";
#endif // ENABLE_PROFILING
return res;
}
Value VMToValue(const relay::Module& module, Object obj) {
CHECK(module.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
const auto& data_type = obj.AsDatatype();
tvm::Array<Value> fields;
for (size_t i = 0; i < data_type->fields.size(); ++i) {
fields.push_back(VMToValue(module, data_type->fields[i]));
}
return ConstructorValueNode::make(data_type->tag, fields);
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
return Value();
}
}
TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Object::Tensor(args[0]);
});
TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Tuple(fields);
});
template <typename T>
std::string ToString(const T& t) {
std::stringstream s;
s << t;
return s.str();
}
TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
Object obj = args[0];
*ret = ToString(obj->tag);
});
TVM_REGISTER_API("relay._vm._Datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Datatype(tag, fields);
});
TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef to_compile = args[0];
TVMContext ctx;
int dev_type = args[1];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[2];
Module module;
if (to_compile.as<FunctionNode>()) {
Function to_compile = args[0];
module = ModuleNode::FromExpr(to_compile);
} else if (to_compile.as<ModuleNode>()) {
module = args[0];
} else {
LOG(FATAL) << "expected function or module";
}
std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
vm_args.push_back(obj);
}
auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
*ret = VMToValue(module, result);
});
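These registrations can also be exercised directly through the FFI before the Python wrappers land. A hedged sketch, assuming the global names above and the argument order of relay._vm._evaluate_vm (expression or module, device type, device id, then VM objects):

```python
import numpy as np
import tvm
from tvm import relay

_tensor = tvm.get_global_func("relay._vm._Tensor")
_evaluate_vm = tvm.get_global_func("relay._vm._evaluate_vm")

x = relay.var("x", shape=(2,))
func = relay.Function([x], x * relay.const(2.0))

arg = _tensor(tvm.nd.array(np.array([1.0, 2.0], dtype="float32")))
ctx = tvm.cpu()
result = _evaluate_vm(func, int(ctx.device_type), ctx.device_id, arg)
print(result)  # a TensorValue holding [2.0, 4.0]
```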
} // namespace vm
} // namespace relay
} // namespace tvm
...@@ -25,7 +25,10 @@ ...@@ -25,7 +25,10 @@
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <iostream> #include <iostream>
#include "../runtime_base.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -87,5 +90,69 @@ NDArray ToNDArray(const Object& obj) { ...@@ -87,5 +90,69 @@ NDArray ToNDArray(const Object& obj) {
return tensor->data; return tensor->data;
} }
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsTensor();
*rv = cell->data;
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->tag);
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->fields.size());
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
int idx = args[1];
auto cell = obj.AsDatatype();
CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx];
});
TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Object::Tensor(args[0]);
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
}
*rv = Object::Tuple(fields);
});
TVM_REGISTER_GLOBAL("_vmobj.Datatype")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*rv = Object::Datatype(tag, fields);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN();
*tag = static_cast<int>(static_cast<ObjectCell*>(handle)->tag);
API_END();
}
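The globals above (plus TVMGetObjectTag for the ctypes path) are the surface the new Python vmobj module builds on. A hedged sketch of calling them directly; the tag value 0 below is arbitrary and purely illustrative.

```python
import numpy as np
import tvm

_tensor = tvm.get_global_func("_vmobj.Tensor")
_tensor_data = tvm.get_global_func("_vmobj.GetTensorData")
_datatype = tvm.get_global_func("_vmobj.Datatype")
_num_fields = tvm.get_global_func("_vmobj.GetDatatypeNumberOfFields")

t = _tensor(tvm.nd.array(np.arange(3, dtype="float32")))
print(_tensor_data(t).asnumpy())        # round-trips back to an NDArray

d = _datatype(0, t, t)                  # constructor tag 0 with two tensor fields
print(_num_fields(d))                   # 2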
...@@ -41,7 +41,7 @@ class PooledAllocator final : public Allocator { ...@@ -41,7 +41,7 @@ class PooledAllocator final : public Allocator {
static constexpr size_t kDefaultPageSize = 4096; static constexpr size_t kDefaultPageSize = 4096;
explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize) explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize)
: Allocator(), page_size_(page_size), used_memory_(0) {} : Allocator(), page_size_(page_size), used_memory_(0), ctx_(ctx) {}
~PooledAllocator() { ReleaseAll(); } ~PooledAllocator() { ReleaseAll(); }
......
...@@ -32,8 +32,8 @@ ...@@ -32,8 +32,8 @@
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "../../runtime/vm/memory_manager.h" #include "memory_manager.h"
#include "../../runtime/vm/naive_allocator.h" #include "naive_allocator.h"
using namespace tvm::runtime; using namespace tvm::runtime;
...@@ -554,6 +554,37 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) { ...@@ -554,6 +554,37 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os; return os;
} }
PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
std::vector<Object> func_args;
for (int i = 1; i < args.size(); ++i) {
Object obj = args[i];
func_args.push_back(obj);
}
*rv = this->Invoke(func_name, func_args);
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size() % 2, 0);
std::vector<TVMContext> contexts;
for (int i = 0; i < args.size() / 2; ++i) {
TVMContext ctx;
int device_type = args[i * 2];
ctx.device_type = DLDeviceType(device_type);
ctx.device_id = args[i * 2 + 1];
contexts.push_back(ctx);
}
this->Init(contexts);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
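With the "invoke" and "init" packed functions exposed here, a VM runtime::Module obtained from the compiler's "get_vm" function can be driven directly from Python. A minimal hedged sketch; the arguments to "invoke" must already be VM objects (for example, built with the _vmobj globals above).

```python
import tvm

def run_main(vm_module, *vm_args, ctx=None):
    """Initialize the VM on one context and invoke its 'main' function."""
    ctx = ctx or tvm.cpu()
    # "init" consumes (device_type, device_id) pairs.
    vm_module["init"](int(ctx.device_type), ctx.device_id)
    # "invoke" takes the function name followed by VM object arguments.
    return vm_module["invoke"]("main", *vm_args)
```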
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
frames.push_back(frame); frames.push_back(frame);
...@@ -573,11 +604,11 @@ Index VirtualMachine::PopFrame() { ...@@ -573,11 +604,11 @@ Index VirtualMachine::PopFrame() {
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) { void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) {
DLOG(INFO) << "Invoking global " << func.name << " " << args.size(); DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
PushFrame(func.params, this->pc + 1, func); PushFrame(func.params.size(), this->pc + 1, func);
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
WriteRegister(i, args[i]); WriteRegister(i, args[i]);
} }
DLOG(INFO) << "func.params= " << func.params; DLOG(INFO) << "func.params= " << func.params.size();
code = func.instructions.data(); code = func.instructions.data();
pc = 0; pc = 0;
...@@ -594,7 +625,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>& ...@@ -594,7 +625,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
} }
Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) { Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) {
auto func_index = this->global_map_[name]; auto func_index = this->global_map[name];
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args); return Invoke(this->functions[func_index], args);
} }
...@@ -719,12 +750,12 @@ void VirtualMachine::Run() { ...@@ -719,12 +750,12 @@ void VirtualMachine::Run() {
auto object = ReadRegister(instr.closure); auto object = ReadRegister(instr.closure);
const auto& closure = object.AsClosure(); const auto& closure = object.AsClosure();
std::vector<Object> args; std::vector<Object> args;
for (Index i = 0; i < instr.closure_args_num; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
for (auto free_var : closure->free_vars) { for (auto free_var : closure->free_vars) {
args.push_back(free_var); args.push_back(free_var);
} }
for (Index i = 0; i < instr.closure_args_num; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
InvokeGlobal(this->functions[closure->func_index], args); InvokeGlobal(this->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst; frames.back().caller_return_register = instr.dst;
goto main_loop; goto main_loop;
......
...@@ -92,9 +92,9 @@ def test_squeezenet(): ...@@ -92,9 +92,9 @@ def test_squeezenet():
def test_inception_v3(): def test_inception_v3():
image_shape = (1, 3, 299, 299) image_shape = (3, 299, 299)
mod, params = testing.inception_v3.get_workload(image_shape=image_shape) mod, params = testing.inception_v3.get_workload(image_shape=image_shape)
benchmark_execution(mod, params, data_shape=image_shape) benchmark_execution(mod, params, data_shape=(1, 3, 299, 299))
def test_dqn(): def test_dqn():
...@@ -107,7 +107,7 @@ def test_dqn(): ...@@ -107,7 +107,7 @@ def test_dqn():
def test_dcgan(): def test_dcgan():
image_shape = (1, 100) image_shape = (1, 100)
mod, params = testing.dcgan.get_workload(batch_size=1) mod, params = testing.dcgan.get_workload(batch_size=1)
benchmark_execution(mod, params, data_shape=image_shape) benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 3, 64, 64))
def test_mobilenet(): def test_mobilenet():
...@@ -126,8 +126,7 @@ if __name__ == '__main__': ...@@ -126,8 +126,7 @@ if __name__ == '__main__':
test_squeezenet() test_squeezenet()
test_mobilenet() test_mobilenet()
test_densenet() test_densenet()
# The following networks fail test_inception_v3()
# test_inception_v3() test_mlp()
# test_mlp() test_dqn()
# test_dqn() test_dcgan()
# test_dcgan()
...@@ -23,40 +23,44 @@ from tvm import relay ...@@ -23,40 +23,44 @@ from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
def veval(f, *args, ctx=tvm.cpu()): def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx) mod = relay.Module()
if len(args) == 0: mod["main"] = f
return ex.evaluate(f) compiler = relay.vm.VMCompiler()
else: vm = compiler.compile(mod, target)
return ex.evaluate(f)(*args) vm.init(tvm.cpu())
return vm.run(*args)
else: else:
assert isinstance(f, relay.Module), "expected expression or module" assert isinstance(f, relay.Module), "expected expression or module"
mod = f mod = f
ex = relay.create_executor('vm', mod=mod, ctx=ctx) compiler = relay.vm.VMCompiler()
if len(args) == 0: vm = compiler.compile(mod, target)
return ex.evaluate() vm.init(tvm.cpu())
else: ret = vm.run(*args)
return ex.evaluate()(*args) return ret
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue): if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
return [o.data.asnumpy().tolist()] return [o.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
result = [] result = []
for f in o.fields: for f in o:
result.extend(vmobj_to_list(f)) result.extend(vmobj_to_list(f))
return result return result
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def test_split(): def test_split():
x = relay.var('x', shape=(12,)) x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple() y = relay.split(x, 3, axis=0).astuple()
z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0) f = relay.Function([x], y)
f = relay.Function([x], z)
x_data = np.random.rand(12,).astype('float32') x_data = np.random.rand(12,).astype('float32')
res = veval(f, x_data) res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0]) ref_res = np.split(x_data, 3, axis=0)
for i in range(3):
tvm.testing.assert_allclose(res[i].asnumpy(), ref_res[i])
def test_split_no_fuse(): def test_split_no_fuse():
x = relay.var('x', shape=(12,)) x = relay.var('x', shape=(12,))
...@@ -68,9 +72,8 @@ def test_split_no_fuse(): ...@@ -68,9 +72,8 @@ def test_split_no_fuse():
res = veval(f, x_data) res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0]) tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
def test_id(): def test_id():
x = relay.var('x', shape=(10, 10)) x = relay.var('x', shape=(10, 10), dtype='float64')
f = relay.Function([x], x) f = relay.Function([x], x)
x_data = np.random.rand(10, 10).astype('float64') x_data = np.random.rand(10, 10).astype('float64')
res = veval(f, x_data) res = veval(f, x_data)
...@@ -209,7 +212,7 @@ def test_list_constructor(): ...@@ -209,7 +212,7 @@ def test_list_constructor():
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
obj = vmobj_to_list(result) obj = vmobj_to_list(result)
tvm.testing.assert_allclose(obj, np.array([3,2,1])) tvm.testing.assert_allclose(obj, np.array([3,2,1]))
...@@ -275,7 +278,7 @@ def test_compose(): ...@@ -275,7 +278,7 @@ def test_compose():
mod["main"] = f mod["main"] = f
x_data = np.array(np.random.rand()).astype('float32') x_data = np.array(np.random.rand()).astype('float32')
result = veval(mod)(x_data) result = veval(mod, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0) tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
...@@ -296,7 +299,7 @@ def test_list_hd(): ...@@ -296,7 +299,7 @@ def test_list_hd():
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 3) tvm.testing.assert_allclose(result.asnumpy(), 3)
@raises(Exception) @raises(Exception)
...@@ -312,7 +315,7 @@ def test_list_tl_empty_list(): ...@@ -312,7 +315,7 @@ def test_list_tl_empty_list():
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
print(result) print(result)
def test_list_tl(): def test_list_tl():
...@@ -332,7 +335,7 @@ def test_list_tl(): ...@@ -332,7 +335,7 @@ def test_list_tl():
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1])) tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
def test_list_nth(): def test_list_nth():
...@@ -351,7 +354,7 @@ def test_list_nth(): ...@@ -351,7 +354,7 @@ def test_list_nth():
f = relay.Function([], nth(l, relay.const(i))) f = relay.Function([], nth(l, relay.const(i)))
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), expected[i]) tvm.testing.assert_allclose(result.asnumpy(), expected[i])
def test_list_update(): def test_list_update():
...@@ -375,7 +378,7 @@ def test_list_update(): ...@@ -375,7 +378,7 @@ def test_list_update():
f = relay.Function([], l) f = relay.Function([], l)
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected)) tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
def test_list_length(): def test_list_length():
...@@ -397,7 +400,7 @@ def test_list_length(): ...@@ -397,7 +400,7 @@ def test_list_length():
f = relay.Function([], l) f = relay.Function([], l)
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 10) tvm.testing.assert_allclose(result.asnumpy(), 10)
def test_list_map(): def test_list_map():
...@@ -415,7 +418,7 @@ def test_list_map(): ...@@ -415,7 +418,7 @@ def test_list_map():
f = relay.Function([], map(add_one_func, l)) f = relay.Function([], map(add_one_func, l))
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2])) tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
def test_list_foldl(): def test_list_foldl():
...@@ -433,7 +436,7 @@ def test_list_foldl(): ...@@ -433,7 +436,7 @@ def test_list_foldl():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil()))) l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldl(rev_dup_func, nil(), l)) f = relay.Function([], foldl(rev_dup_func, nil(), l))
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1])) tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
def test_list_foldr(): def test_list_foldr():
...@@ -451,7 +454,7 @@ def test_list_foldr(): ...@@ -451,7 +454,7 @@ def test_list_foldr():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil()))) l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldr(identity_func, nil(), l)) f = relay.Function([], foldr(identity_func, nil(), l))
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3])) tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
def test_list_sum(): def test_list_sum():
...@@ -465,7 +468,7 @@ def test_list_sum(): ...@@ -465,7 +468,7 @@ def test_list_sum():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil()))) l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], sum(l)) f = relay.Function([], sum(l))
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 6) tvm.testing.assert_allclose(result.asnumpy(), 6)
def test_list_filter(): def test_list_filter():
...@@ -485,7 +488,7 @@ def test_list_filter(): ...@@ -485,7 +488,7 @@ def test_list_filter():
cons(relay.const(1), nil()))))) cons(relay.const(1), nil())))))
f = relay.Function([], filter(greater_than_one, l)) f = relay.Function([], filter(greater_than_one, l))
mod["main"] = f mod["main"] = f
result = veval(mod)() result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5])) tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
def test_closure(): def test_closure():
...@@ -513,6 +516,10 @@ if __name__ == "__main__": ...@@ -513,6 +516,10 @@ if __name__ == "__main__":
test_split() test_split()
test_split_no_fuse() test_split_no_fuse()
test_list_constructor() test_list_constructor()
test_let_tensor()
test_let_scalar()
test_compose()
test_list_hd()
test_list_tl_empty_list() test_list_tl_empty_list()
test_list_tl() test_list_tl()
test_list_nth() test_list_nth()
......