Commit b6dc7826 by Haichen Shen Committed by Jared Roesch

[Relay][VM] Port VM, VM compiler, and Object into python (#3391)

* tmp

* Port vm and object to python

* clean up

* update vm build module

* update

* x

* tweak

* cleanup

* update

* fix rebase

* Rename to VMCompiler

* fix
parent afd4b3e4
......@@ -103,7 +103,7 @@ typedef enum {
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kObject = 14U,
kObjectCell = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
......@@ -176,6 +176,8 @@ typedef void* TVMRetValueHandle;
* can be NULL, which indicates the default one.
*/
typedef void* TVMStreamHandle;
/*! \brief Handle to Object. */
typedef void* TVMObjectHandle;
/*!
* \brief Used for implementing C API function.
......@@ -545,6 +547,15 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
TVMStreamHandle src,
TVMStreamHandle dst);
/*!
* \brief Get the tag from an object.
*
* \param obj The object handle.
* \param tag The tag of object.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
......
......@@ -324,6 +324,7 @@ class Object {
Object() : ptr_() {}
Object(const Object& obj) : ptr_(obj.ptr_) {}
ObjectCell* operator->() { return this->ptr_.operator->(); }
const ObjectCell* operator->() const { return this->ptr_.operator->(); }
/*! \brief Construct a tensor object. */
static Object Tensor(const NDArray& data);
......
......@@ -491,7 +491,7 @@ class TVMPODValue_ {
}
operator Object() const {
if (type_code_ == kNull) return Object();
TVM_CHECK_TYPE_CODE(type_code_, kObject);
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
return Object(static_cast<ObjectCell*>(value_.v_handle));
}
operator TVMContext() const {
......@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
}
TVMRetValue& operator=(Object other) {
this->Clear();
type_code_ = kObject;
type_code_ = kObjectCell;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
return *this;
......@@ -861,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kObject: {
case kObjectCell: {
*this = other.operator Object();
break;
}
......@@ -912,7 +912,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kObject: {
case kObjectCell: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef();
break;
}
......@@ -945,7 +945,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObject: return "Object";
case kObjectCell: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......
......@@ -143,7 +143,7 @@ struct Instruction {
/*! \brief The registers containing the arguments. */
RegName* invoke_args_registers;
};
struct /* Const Operands */ {
struct /* LoadConst Operands */ {
/* \brief The index into the constant pool. */
Index const_index;
};
......@@ -308,14 +308,14 @@ struct Instruction {
struct VMFunction {
/*! \brief The function's name. */
std::string name;
/*! \brief The number of function parameters. */
Index params;
/*! \brief The function parameter names. */
std::vector<std::string> params;
/*! \brief The instructions representing the function. */
std::vector<Instruction> instructions;
/*! \brief The size of the frame for this function */
Index register_file_size;
VMFunction(const std::string& name, Index params,
VMFunction(const std::string& name, std::vector<std::string> params,
const std::vector<Instruction>& instructions,
Index register_file_size)
: name(name),
......@@ -370,7 +370,15 @@ struct VMFrame {
* multiple threads, or serialized them to disk or over the
* wire.
*/
struct VirtualMachine {
class VirtualMachine : public runtime::ModuleNode {
public:
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final {
return "VirtualMachine";
}
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
......@@ -442,7 +450,7 @@ struct VirtualMachine {
/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map_;
std::unordered_map<std::string, Index> global_map;
private:
/*! \brief Invoke a global setting up the VM state to execute.
......
......@@ -37,6 +37,7 @@ from . import node as _node
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
......@@ -162,9 +163,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
elif isinstance(arg, ObjectBase):
elif isinstance(arg, _CLASS_OBJECT):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT
type_codes[i] = TypeCode.OBJECT_CELL
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
......@@ -236,6 +237,7 @@ def _return_module(x):
handle = ModuleHandle(handle)
return _CLASS_MODULE(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
......@@ -243,18 +245,11 @@ def _handle_return_func(x):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
RETURN_SWITCH[TypeCode.OBJECT] = lambda x: _CLASS_OBJECT(x.v_handle)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH
ObjectHandle = ctypes.c_void_p
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
def _register_object(index, cls):
"""register object class"""
OBJECT_TYPE[index] = cls
def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tag = ctypes.c_int()
check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag)))
cls = OBJECT_TYPE.get(tag.value, ObjectBase)
obj = cls(handle)
return obj
RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
......@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kObject = 14
kObjectCell = 14
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
......@@ -130,6 +130,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
int TVMGetObjectTag(ObjectHandle obj, int* tag)
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
......
......@@ -19,3 +19,4 @@ include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
include "./vmobj.pxi"
......@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kObject or
tcode == kObjectCell or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
......@@ -160,7 +160,7 @@ cdef inline int make_arg(object arg,
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObject
tcode[0] = kObjectCell
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
......@@ -212,8 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode == kObject:
return _CLASS_OBJECT(ctypes_handle(value.v_handle))
elif tcode == kObjectCell:
return make_ret_object(value.v_handle)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
......@@ -310,27 +310,6 @@ cdef class FunctionBase:
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Maps object type to its constructor"""
OBJECT_TYPE = []
def _register_object(int index, object cls):
"""register node class"""
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
cdef int tag
cdef list object_type
cdef object cls
cdef object handle
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMGetObjectTag(chandle, &tag))
if tag < len(object_type):
cls = object_type[tag]
if cls is not None:
obj = cls(handle)
else:
obj = ObjectBase(handle)
else:
obj = ObjectBase(handle)
return obj
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
......@@ -22,6 +22,7 @@ from __future__ import absolute_import
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from . import vmobj as _vmobj
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -30,28 +31,19 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module, _set_class_object
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module, _set_class_object
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module, _set_class_object
from ._ctypes.function import ObjectBase as _ObjectBase
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
class Object(_ObjectBase):
# TODO(@jroesch): Eventually add back introspection functionality.
pass
_set_class_object(Object)
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
......
......@@ -42,7 +42,7 @@ class TypeCode(object):
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
OBJECT = 14
OBJECT_CELL = 14
EXT_BEGIN = 15
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import
import sys
from .base import _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_object
from ._ctypes.vmobj import ObjectBase as _ObjectBase
from ._ctypes.vmobj import _register_object
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 0
CLOSURE = 1
DATATYPE = 2
class Object(_ObjectBase):
"""The VM Object used in Relay virtual machine."""
def register_object(cls):
_register_object(cls.tag, cls)
return cls
_set_class_object(Object)
......@@ -33,6 +33,8 @@ from . import parser
from . import debug
from . import param_dict
from . import feature
from .backend import vm
from .backend import vmobj
# Root operators
from .op import Op
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The VM Object FFI namespace."""
from tvm._ffi.function import _init_api
_init_api("_vmobj", __name__)
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
"""
The Relay Virtual Vachine.
......@@ -23,55 +23,38 @@ Implements a Python interface to compiling and executing on the Relay VM.
import numpy as np
import tvm
from tvm._ffi.function import Object
from .. import transform
from ..backend.interpreter import Executor
from ..expr import GlobalVar, Expr
from . import _vm
Object = Object
def optimize(mod):
"""Perform several optimizations on a module before executing it in the
Relay virtual machine.
Parameters
----------
mod : tvm.relay.Module
The module to optimize.
Returns
-------
ret : tvm.relay.Module
The optimized module.
"""
main_func = mod["main"]
opt_passes = []
if not main_func.params and isinstance(main_func.body, GlobalVar):
opt_passes.append(transform.EtaExpand())
opt_passes = opt_passes + [
transform.SimplifyInference(),
transform.FuseOps(),
transform.InferType()
]
seq = transform.Sequential(opt_passes)
return seq(mod)
from . import vmobj as _obj
from .interpreter import Executor
def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts
def _convert(arg, cargs):
if isinstance(arg, np.ndarray):
tensor = _vm._Tensor(tvm.nd.array(arg))
cargs.append(tensor)
elif isinstance(arg, tvm.nd.NDArray):
tensor = _vm._Tensor(arg)
cargs.append(tensor)
elif isinstance(arg, tuple):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.tensor_object(arg))
elif isinstance(arg, (tuple, list)):
field_args = []
for field in arg:
_convert(field, field_args)
cargs.append(_vm._Tuple(*field_args))
cargs.append(_obj.tuple_object(field_args))
else:
raise "unsupported type"
......@@ -82,28 +65,97 @@ def convert(args):
return cargs
def _eval_vm(mod, ctx, *args):
class VirtualMachine(object):
"""Relay VM runtime."""
def __init__(self, mod):
self.mod = mod
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
def init(self, ctx):
"""Initialize the context in the VM.
Parameters
----------
ctx : :py:class:`TVMContext`
The runtime context to run the code on.
"""
Evaluate a module on a given context with the provided arguments.
args = [ctx.device_type, ctx.device_id]
self._init(*args)
def invoke(self, func_name, *args):
"""Invoke a function.
Parameters
----------
mod: relay.Module
The module to optimize, will execute its entry_func.
func_name : str
The name of the function.
ctx: tvm.Context
The TVM context to execute on.
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
Returns
-------
result : Object
The output.
"""
mod = optimize(mod)
args = list(args)
assert isinstance(args, list)
cargs = convert(args)
return self._invoke(func_name, *cargs)
def run(self, *args):
"""Run the main function.
Parameters
----------
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
Returns
-------
result : Object
The output.
"""
return self.invoke("main", *args)
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
def __init__(self):
self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
def compile(self, mod, target=None, target_host=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
Returns
-------
vm : VirtualMachine
The VM runtime.
"""
target = _update_target(target)
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
result = _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs)
return result
class VMExecutor(Executor):
"""
......@@ -126,19 +178,21 @@ class VMExecutor(Executor):
The target option to build the function.
"""
def __init__(self, mod, ctx, target):
if mod is None:
raise RuntimeError("Must provide module to get VM executor.")
self.mod = mod
self.ctx = ctx
self.target = target
compiler = VMCompiler()
self.vm = compiler.compile(mod, target)
self.vm.init(ctx)
def _make_executor(self, expr=None):
expr = expr if expr else self.mod
assert expr, "either expr or self.mod should be not null."
if isinstance(expr, Expr):
self.mod["main"] = expr
assert expr is None
main = self.mod["main"]
def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
return _eval_vm(self.mod, self.ctx, *args)
return self.vm.run(*args)
return _vm_wrapper
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Runtime Object API."""
from __future__ import absolute_import as _abs
import numpy as _np
from tvm._ffi.vmobj import Object, ObjectTag, register_object
from tvm import ndarray as _nd
from . import _vmobj
# TODO(@icemelon9): Add ClosureObject
@register_object
class TensorObject(Object):
"""Tensor object."""
tag = ObjectTag.TENSOR
def __init__(self, handle):
"""Constructs a Tensor object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : TensorObject
A tensor object.
"""
super(TensorObject, self).__init__(handle)
self.data = _vmobj.GetTensorData(self)
def asnumpy(self):
"""Convert data to numpy array
Returns
-------
np_arr : numpy.ndarray
The corresponding numpy array.
"""
return self.data.asnumpy()
@register_object
class DatatypeObject(Object):
"""Datatype object."""
tag = ObjectTag.DATATYPE
def __init__(self, handle):
"""Constructs a Datatype object
Parameters
----------
handle : object
Object handle
Returns
-------
obj : DatatypeObject
A Datatype object.
"""
super(DatatypeObject, self).__init__(handle)
self.tag = _vmobj.GetDatatypeTag(self)
num_fields = _vmobj.GetDatatypeNumberOfFields(self)
self.fields = []
for i in range(num_fields):
self.fields.append(_vmobj.GetDatatypeFields(self, i))
def __getitem__(self, idx):
return self.fields[idx]
def __len__(self):
return self.num_fields
def __iter__(self):
return iter(self.fields)
# TODO(icemelon9): Add closure object
def tensor_object(arr, ctx=_nd.cpu(0)):
"""Create a tensor object from source arr.
Parameters
----------
arr : numpy.ndarray or tvm.nd.NDArray
The source array.
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : TensorObject
The created object.
"""
if isinstance(arr, _np.ndarray):
tensor = _vmobj.Tensor(_nd.array(arr, ctx))
elif isinstance(arr, _nd.NDArray):
tensor = _vmobj.Tensor(arr)
else:
raise RuntimeError("Unsupported type for tensor object.")
return tensor
def tuple_object(fields):
"""Create a datatype object from source tuple.
Parameters
----------
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Tuple(*fields)
def datatype_object(tag, fields):
"""Create a datatype object from tag and source fields.
Parameters
----------
tag : int
The tag of datatype.
fields : list[Object] or tuple[Object]
The source tuple.
Returns
-------
ret : DatatypeObject
The created object.
"""
for f in fields:
assert isinstance(f, Object)
return _vmobj.Datatype(tag, *fields)
......@@ -529,6 +529,18 @@ def CanonicalizeCast():
return _transform.CanonicalizeCast()
def LambdaLift():
"""
Lift the closure to global function.
Returns
-------
ret : tvm.relay.Pass
The registered pass that lifts the lambda function.
"""
return _transform.LambdaLift()
def PrintIR():
"""
Print the IR for a module to help debugging.
......
......@@ -24,6 +24,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <memory>
......
......@@ -63,6 +63,7 @@ using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
using GlobalMap = NodeMap<GlobalVar, Index>;
using ConstMap = NodeMap<Constant, Index>;
using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
using TargetsMap = Map<tvm::Integer, tvm::Target>;
struct VMCompilerContext {
// The module context for the compilation
......@@ -153,47 +154,134 @@ struct AccessField : MatchValue {
~AccessField() {}
};
struct VMCompiler;
/*!
* \brief Condition in a decision tree
*/
struct ConditionNode {
virtual ~ConditionNode() {}
};
using ConditionNodePtr = std::shared_ptr<ConditionNode>;
/*!
* \brief Compile a pattern match expression
* It first converts the pattern match expression into a desicision tree, the condition
* could be object comparison or variable binding. If any of the condition fails in a clause,
* the decision tree switches to check the conditions of next clause and so on. If no clause
* matches the value, a fatal node is inserted.
*
* After the decision tree is built, we convert it into bytecodes using If/Goto.
* \brief A var binding condition
*/
void CompileMatch(Match match, VMCompiler* compiler);
struct VarBinding : ConditionNode {
Var var;
MatchValuePtr val;
struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
/*! \brief Store the expression a variable points to. */
std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map;
VarBinding(Var var, MatchValuePtr val)
: var(var), val(val) {}
std::vector<Instruction> instructions;
~VarBinding() {}
};
// var -> register num
std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map;
/*!
* \brief Compare the tag of the object
*/
struct TagCompare : ConditionNode {
/*! \brief The object to be examined */
MatchValuePtr obj;
size_t last_register;
/*! \brief The expected tag */
int target_tag;
// Total number of virtual registers allocated
size_t registers_num;
CompileEngine engine;
TagCompare(MatchValuePtr obj, size_t target)
: obj(obj), target_tag(target) {
}
/*! \brief Global shared meta data */
VMCompilerContext* context;
~TagCompare() {}
};
using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;
TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
Pattern pattern,
TreeNodePtr then_branch,
TreeNodePtr else_branch) {
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
} else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pat = pattern.as<PatternConstructorNode>();
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
size_t field_index = 0;
for (auto& p : pattern->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
}
}
TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data,
Clause clause,
TreeNodePtr else_branch) {
return BuildDecisionTreeFromPattern(data, clause->lhs,
TreeLeafNode::Make(clause->rhs), else_branch);
}
VMCompiler(VMCompilerContext* context)
: instructions(),
var_register_map(),
last_register(0),
registers_num(0),
engine(CompileEngine::Global()),
context(context)
{}
TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
// When nothing matches, the VM throws fatal error
TreeNodePtr else_branch = TreeLeafFatalNode::Make();
// Start from the last clause
for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
}
return else_branch;
}
size_t NewRegister() { return registers_num++; }
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets)
: last_register_(0),
registers_num_(0),
engine_(CompileEngine::Global()),
context_(context),
targets_(targets) {}
VMFunction Compile(const GlobalVar& var, const Function& func) {
size_t i = 0;
// We then assign register num to the free variables
for (auto param : func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map_.insert({param, arg_register});
params_.push_back(param->name_hint());
++i;
}
if (IsClosure(func)) {
Function inner_func = Downcast<Function>(func->body);
for (auto param : inner_func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map_.insert({param, arg_register});
params_.push_back(param->name_hint());
++i;
}
this->VisitExpr(inner_func->body);
} else {
this->VisitExpr(func->body);
}
instructions_.push_back(Instruction::Ret(last_register_));
return VMFunction(var->name_hint, params_, instructions_, registers_num_);
}
protected:
size_t NewRegister() { return registers_num_++; }
inline void Emit(const Instruction& instr) {
DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
......@@ -210,10 +298,10 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::AllocClosure:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register = instr.dst;
last_register_ = instr.dst;
break;
case Opcode::InvokePacked:
last_register = instr.packed_args[instr.arity - 1];
last_register_ = instr.packed_args[instr.arity - 1];
break;
case Opcode::If:
case Opcode::Ret:
......@@ -221,21 +309,21 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::Fatal:
break;
}
instructions.push_back(instr);
instructions_.push_back(instr);
}
void VisitExpr_(const ConstantNode* const_node) {
auto rconst = GetRef<Constant>(const_node);
auto it = this->context->const_map.find(rconst);
CHECK(it != this->context->const_map.end());
auto it = this->context_->const_map.find(rconst);
CHECK(it != this->context_->const_map.end());
Emit(Instruction::LoadConst(it->second, NewRegister()));
}
void VisitExpr_(const VarNode* var_node) {
auto var = GetRef<Var>(var_node);
auto reg_it = this->var_register_map.find(var);
CHECK(reg_it != this->var_register_map.end());
last_register = reg_it->second;
auto reg_it = this->var_register_map_.find(var);
CHECK(reg_it != this->var_register_map_.end());
last_register_ = reg_it->second;
}
void VisitExpr_(const TupleNode* tuple_node) {
......@@ -244,7 +332,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
for (auto& field : tuple->fields) {
this->VisitExpr(field);
fields_registers.push_back(last_register);
fields_registers.push_back(last_register_);
}
// TODO(@jroesch): use correct tag
......@@ -259,29 +347,28 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
auto match = GetRef<Match>(match_node);
this->VisitExpr(match->data);
CompileMatch(match, this);
CompileMatch(match);
}
void VisitExpr_(const LetNode* let_node) {
DLOG(INFO) << let_node->value;
this->VisitExpr(let_node->value);
DLOG(INFO) << this->last_register;
var_register_map.insert({let_node->var, this->last_register});
var_register_map_.insert({let_node->var, this->last_register_});
this->VisitExpr(let_node->body);
}
void VisitExpr_(const TupleGetItemNode* get_node) {
auto get = GetRef<TupleGetItem>(get_node);
this->VisitExpr(get->tuple);
auto tuple_register = last_register;
auto tuple_register = last_register_;
Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
}
void VisitExpr_(const GlobalVarNode* gvar) {
auto var = GetRef<GlobalVar>(gvar);
auto func = this->context->module->Lookup(var);
auto it = this->context->global_map.find(var);
CHECK(it != this->context->global_map.end());
auto func = context_->module->Lookup(var);
auto it = context_->global_map.find(var);
CHECK(it != context_->global_map.end());
// Allocate closure with zero free vars
Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
}
......@@ -289,30 +376,30 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
void VisitExpr_(const IfNode* if_node) {
this->VisitExpr(if_node->cond);
size_t test_register = last_register;
size_t test_register = last_register_;
this->Emit(Instruction::LoadConsti(1, NewRegister()));
auto after_cond = this->instructions.size();
auto target_register = this->last_register;
auto after_cond = instructions_.size();
auto target_register = last_register_;
this->Emit(Instruction::If(test_register, target_register, 0, 0));
this->VisitExpr(if_node->true_branch);
size_t true_register = last_register;
size_t true_register = last_register_;
Emit(Instruction::Goto(0));
// Finally store how many instructions there are in the
// true branch.
auto after_true = this->instructions.size();
auto after_true = this->instructions_.size();
this->VisitExpr(if_node->false_branch);
size_t false_register = last_register;
size_t false_register = last_register_;
// In else-branch, override the then-branch register
Emit(Instruction::Move(false_register, true_register));
// Compute the total number of instructions
// after generating false.
auto after_false = this->instructions.size();
auto after_false = this->instructions_.size();
// Now we will compute the jump targets in order
// to properly patch the instruction with the
......@@ -322,13 +409,13 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// we patch up the if instruction, and goto.
auto true_offset = 1;
auto false_offset = after_true - after_cond;
this->instructions[after_cond].if_op.true_offset = true_offset;
this->instructions[after_cond].if_op.false_offset = false_offset;
instructions_[after_cond].if_op.true_offset = true_offset;
instructions_[after_cond].if_op.false_offset = false_offset;
// Patch the Goto.
this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1;
this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;
this->last_register = true_register;
this->last_register_ = true_register;
}
Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
......@@ -399,18 +486,27 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// Next generate the invoke instruction.
CHECK(func->IsPrimitive());
auto target = Target::Create("llvm");
Target target;
if (targets_.size() == 1) {
// homogeneous execution.
for (auto kv : targets_) {
target = kv.second;
}
} else {
// heterogeneous execution.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
}
auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine->Lower(key);
auto cfunc = engine_->Lower(key);
// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) {
op_index = this->context->lowered_funcs.size();
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
this->context->seen_funcs[cfunc->funcs[0]] = op_index;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = context_->lowered_funcs.size();
context_->lowered_funcs.push_back(cfunc->funcs[0]);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = this->context->seen_funcs[cfunc->funcs[0]];
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
......@@ -430,7 +526,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
for (auto arg : call_node->args) {
this->VisitExpr(arg);
args_registers.push_back(last_register);
args_registers.push_back(last_register_);
}
Expr op = call_node->op;
......@@ -440,12 +536,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type());
} else if (auto global_node = op.as<GlobalVarNode>()) {
auto global = GetRef<GlobalVar>(global_node);
auto it = this->context->global_map.find(global);
CHECK(it != this->context->global_map.end());
auto it = context_->global_map.find(global);
CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;
auto func = this->context->module->Lookup(global);
auto func = context_->module->Lookup(global);
if (IsClosure(func)) {
auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
......@@ -458,7 +554,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
} else {
LOG(FATAL) << "unsupported case in vm compiler: " << op;
}
......@@ -472,300 +568,175 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
}
}
void CompileClosure(const Function& func) {
// We first layout the function arguments.
auto inner_func = Downcast<Function>(func->body);
size_t i = 0;
for (auto param : inner_func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map.insert({param, arg_register});
i++;
}
// We then assign register num to the free variables
for (auto param : func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
var_register_map.insert({param, arg_register});
i++;
}
// We will now process the body like normal.
this->VisitExpr(inner_func->body);
}
void Compile(const Function& func) {
// We need to generate code specially for lifted closures.
if (IsClosure(func)) {
CompileClosure(func);
return;
}
for (size_t i = 0; i < func->params.size(); ++i) {
auto arg_register = NewRegister();
CHECK_EQ(arg_register, i);
var_register_map.insert({func->params[i], arg_register});
}
this->VisitExpr(func->body);
}
};
/*!
/*!
* \brief Compile a match value
* Generate byte code that compute the value specificed in val
*
* \return The register number assigned for the final value
*/
RegName CompileMatchValue(MatchValuePtr val, VMCompiler* compiler) {
RegName CompileMatchValue(MatchValuePtr val) {
if (std::dynamic_pointer_cast<RegisterValue>(val)) {
auto r = std::dynamic_pointer_cast<RegisterValue>(val);
return r->rergister_num;
} else {
auto path = std::dynamic_pointer_cast<AccessField>(val);
auto p = CompileMatchValue(path->parent, compiler);
compiler->Emit(Instruction::GetField(p, path->index, compiler->NewRegister()));
path->reg = compiler->last_register;
auto p = CompileMatchValue(path->parent);
Emit(Instruction::GetField(p, path->index, NewRegister()));
path->reg = last_register_;
return path->reg;
}
}
/*!
* \brief Condition in a decision tree
*/
struct ConditionNode {
virtual ~ConditionNode() {}
};
using ConditionNodePtr = std::shared_ptr<ConditionNode>;
/*!
* \brief A var binding condition
*/
struct VarBinding : ConditionNode {
Var var;
MatchValuePtr val;
VarBinding(Var var, MatchValuePtr val)
: var(var), val(val) {}
~VarBinding() {}
};
/*!
* \brief Compare the tag of the object
*/
struct TagCompare : ConditionNode {
/*! \brief The object to be examined */
MatchValuePtr obj;
/*! \brief The expected tag */
int target_tag;
TagCompare(MatchValuePtr obj, size_t target)
: obj(obj), target_tag(target) {
}
~TagCompare() {}
};
using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;
void CompileTreeNode(TreeNodePtr tree, VMCompiler* compiler) {
void CompileTreeNode(TreeNodePtr tree) {
if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
compiler->VisitExpr(node->body);
VisitExpr(node->body);
} else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
compiler->Emit(Instruction::Fatal());
Emit(Instruction::Fatal());
} else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
// For Tag compariton, generate branches
auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
auto r = CompileMatchValue(cond->obj, compiler);
compiler->Emit(Instruction::GetTag(r, compiler->NewRegister()));
auto operand1 = compiler->last_register;
compiler->Emit(Instruction::LoadConsti(cond->target_tag, compiler->NewRegister()));
auto operand2 = compiler->last_register;
compiler->Emit(Instruction::If(operand1, operand2, 1, 0));
auto cond_offset = compiler->instructions.size() - 1;
CompileTreeNode(node->then_branch, compiler);
auto if_reg = compiler->last_register;
compiler->Emit(Instruction::Goto(1));
auto goto_offset = compiler->instructions.size() - 1;
CompileTreeNode(node->else_branch, compiler);
auto else_reg = compiler->last_register;
compiler->Emit(Instruction::Move(else_reg, if_reg));
compiler->last_register = if_reg;
auto else_offset = compiler->instructions.size() - 1;
auto r = CompileMatchValue(cond->obj);
Emit(Instruction::GetTag(r, NewRegister()));
auto operand1 = last_register_;
Emit(Instruction::LoadConsti(cond->target_tag, NewRegister()));
auto operand2 = last_register_;
Emit(Instruction::If(operand1, operand2, 1, 0));
auto cond_offset = instructions_.size() - 1;
CompileTreeNode(node->then_branch);
auto if_reg = last_register_;
Emit(Instruction::Goto(1));
auto goto_offset = instructions_.size() - 1;
CompileTreeNode(node->else_branch);
auto else_reg = last_register_;
Emit(Instruction::Move(else_reg, if_reg));
last_register_ = if_reg;
auto else_offset = instructions_.size() - 1;
// Fixing offsets
compiler->instructions[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
compiler->instructions[goto_offset].pc_offset = else_offset - goto_offset + 1;
instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
} else {
// For other non-branch conditions, move to then_branch directly
auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
compiler->var_register_map[cond->var] = CompileMatchValue(cond->val, compiler);
CompileTreeNode(node->then_branch, compiler);
var_register_map_[cond->var] = CompileMatchValue(cond->val);
CompileTreeNode(node->then_branch);
}
}
}
TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
Pattern pattern,
TreeNodePtr then_branch,
TreeNodePtr else_branch) {
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
} else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pat = pattern.as<PatternConstructorNode>();
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
size_t field_index = 0;
for (auto& p : pattern->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
}
}
TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data,
Clause clause,
TreeNodePtr else_branch) {
return BuildDecisionTreeFromPattern(data, clause->lhs,
TreeLeafNode::Make(clause->rhs), else_branch);
}
TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
// When nothing matches, the VM throws fatal error
TreeNodePtr else_branch = TreeLeafFatalNode::Make();
// Start from the last clause
for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
/*!
* \brief Compile a pattern match expression
* It first converts the pattern match expression into a desicision tree, the condition
* could be object comparison or variable binding. If any of the condition fails in a clause,
* the decision tree switches to check the conditions of next clause and so on. If no clause
* matches the value, a fatal node is inserted.
*
* After the decision tree is built, we convert it into bytecodes using If/Goto.
*/
void CompileMatch(Match match) {
auto data = std::make_shared<RegisterValue>(last_register_);
auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
CompileTreeNode(decision_tree);
}
return else_branch;
}
void CompileMatch(Match match, VMCompiler* compiler) {
auto data = std::make_shared<RegisterValue>(compiler->last_register);
auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
CompileTreeNode(decision_tree, compiler);
}
protected:
/*! \brief Store the expression a variable points to. */
std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map_;
/*! \brief Instructions in the VMFunction. */
std::vector<Instruction> instructions_;
/*! \brief Parameter names of the function. */
std::vector<std::string> params_;
/*! \brief Map from var to register number. */
std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map_;
/*! \brief Last used register number. */
size_t last_register_;
/*! \brief Total number of virtual registers allocated. */
size_t registers_num_;
/*! \brief Compiler engine to lower primitive functions. */
CompileEngine engine_;
/*! \brief Global shared meta data */
VMCompilerContext* context_;
/*! \brief Target devices. */
TargetsMap targets_;
};
void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
std::vector<PackedFunc>* packed_funcs) {
runtime::Module mod;
if (lowered_funcs.size() > 0) {
// TODO(@jroesch): we need to read target from build config
Target target = Target::Create("llvm");
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target);
class VMCompiler : public runtime::ModuleNode {
public:
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
if (name == "compile") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
this->Compile(args[0], args[1], args[2]);
});
} else if (name == "get_vm") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_);
});
} else {
LOG(FATAL) << "relay.backend.build is not registered";
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
CHECK(mod.operator->());
for (auto lfunc : lowered_funcs) {
packed_funcs->push_back(mod.GetFunction(lfunc->name));
}
}
}
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false);
size_t params = func->params.size();
VMCompiler compiler(context);
compiler.Compile(func);
// return the last evaluated expression
compiler.instructions.push_back(Instruction::Ret(compiler.last_register));
// Would like to refactor this so we only check if closure once.
if (IsClosure(func)) {
auto inner_params = Downcast<Function>(func->body)->params.size();
return VMFunction(var->name_hint, params + inner_params, compiler.instructions,
compiler.registers_num);
} else {
return VMFunction(var->name_hint, params, compiler.instructions, compiler.registers_num);
const char* type_key() const final {
return "VMCompiler";
}
}
Module OptimizeModule(const Module& mod) {
transform::Sequential seq({transform::ToANormalForm(),
transform::InlinePrimitives(),
transform::LambdaLift(),
transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
}
void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
// First we populate global map.
size_t global_index = 0;
for (auto named_func : mod->functions) {
auto gvar = named_func.first;
global_map->insert({gvar, global_index++});
std::shared_ptr<VirtualMachine> GetVirtualMachine() const {
return vm_;
}
}
VirtualMachine CompileModule(const Module& mod_ref) {
Module mod = mod_ref;
void Compile(const Module& mod_ref,
const TargetsMap& targets,
const tvm::Target& target_host) {
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
targets_ = targets;
target_host_ = target_host;
vm_ = std::make_shared<VirtualMachine>();
// Run some optimizations first, this code should
// be moved to pass manager.
mod = OptimizeModule(mod);
VirtualMachine vm;
VMCompilerContext context;
context.module = mod;
context_.module = OptimizeModule(mod_ref);
// Populate the global map.
//
// This maps global variables to a global index
// in the VMFunction table.
PopulateGlobalMap(&context.global_map, mod);
PopulateGlobalMap();
// Next we populate constant map.
auto constant_analysis_result = LayoutConstantPool(mod);
context.const_map = std::get<0>(constant_analysis_result);
context.const_tensor_shape_map = std::get<1>(constant_analysis_result);
auto constant_analysis_result = LayoutConstantPool(context_.module);
context_.const_map = std::get<0>(constant_analysis_result);
context_.const_tensor_shape_map = std::get<1>(constant_analysis_result);
// Next we get ready by allocating space for
// the global state.
vm.functions.resize(mod->functions.size());
vm.constants.resize(context.const_map.size() + context.const_tensor_shape_map.size());
vm_->functions.resize(context_.module->functions.size());
vm_->constants.resize(context_.const_map.size() + context_.const_tensor_shape_map.size());
for (auto pair : context.const_map) {
vm.constants[pair.second] = Object::Tensor(pair.first->data);
for (auto pair : context_.const_map) {
vm_->constants[pair.second] = Object::Tensor(pair.first->data);
}
for (auto pair : context.const_tensor_shape_map) {
vm.constants[pair.second.first] = Object::Tensor(pair.second.second);
for (auto pair : context_.const_tensor_shape_map) {
vm_->constants[pair.second.first] = Object::Tensor(pair.second.second);
}
for (auto named_func : mod->functions) {
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
auto vm_func = CompileFunc(&context, gvar, func);
VMFunctionCompiler func_compiler(&context_, targets_);
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context.global_map.at(gvar);
CHECK(func_index < vm.functions.size());
vm.functions[func_index] = vm_func;
size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < vm_->functions.size());
vm_->functions[func_index] = vm_func;
}
#ifdef USE_RELAY_DEBUG
......@@ -775,15 +746,80 @@ VirtualMachine CompileModule(const Module& mod_ref) {
}
#endif // USE_RELAY_DEBUG
PopulatePackedFuncMap(context.lowered_funcs, &vm.packed_funcs);
PopulatePackedFuncMap();
for (auto gv : context_.global_map) {
vm_->global_map.insert({gv.first->name_hint, gv.second});
}
}
protected:
Module OptimizeModule(const Module& mod) {
// TODO(@icemelon9): check number of targets and build config, add more optimization pass
transform::Sequential seq({transform::SimplifyInference(),
transform::ToANormalForm(),
transform::InlinePrimitives(),
transform::LambdaLift(),
transform::InlinePrimitives(),
transform::FuseOps()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
}
void PopulateGlobalMap() {
// First we populate global map.
size_t global_index = 0;
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
context_.global_map.insert({gvar, global_index++});
}
}
for (auto gv : context.global_map) {
vm.global_map_.insert({gv.first->name_hint, gv.second});
void PopulatePackedFuncMap() {
auto const& lowered_funcs = context_.lowered_funcs;
if (lowered_funcs.size() == 0) {
return;
}
runtime::Module mod;
// TODO(@icemelon9): support heterogeneous targets
Target target;
for (auto kv : targets_) {
target = kv.second;
}
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()),
target, target_host_);
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
CHECK(mod.operator->());
for (auto lfunc : lowered_funcs) {
vm_->packed_funcs.push_back(mod.GetFunction(lfunc->name));
}
}
return vm;
protected:
/*! \brief Target devices. */
TargetsMap targets_;
/*! \brief Target host device. */
tvm::Target target_host_;
/*! \brief Global shared meta data */
VMCompilerContext context_;
/*! \brief Compiled virtual machine. */
std::shared_ptr<VirtualMachine> vm_;
};
runtime::Module CreateVMCompiler() {
std::shared_ptr<VMCompiler> exec = std::make_shared<VMCompiler>();
return runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("relay._vm._VMCompiler")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateVMCompiler();
});
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/vm.cc
* \brief The Relay virtual machine.
*/
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/module.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/analysis.h>
namespace tvm {
namespace relay {
namespace vm {
runtime::vm::VirtualMachine CompileModule(const Module& mod);
using tvm::runtime::Object;
using tvm::runtime::ObjectTag;
using tvm::runtime::vm::VirtualMachine;
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
auto vm = CompileModule(module);
vm.Init(ctxs);
return vm;
}
Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<Object>& vm_args) {
VirtualMachine vm = FromModule(module, ctxs);
// TODO(zhiics): This measurement is for temporary usage. Remove it later. We
// need to introduce a better profiling method.
#if ENABLE_PROFILING
DLOG(INFO) << "Entry function is main." << std::endl;
auto start = std::chrono::high_resolution_clock::now();
#endif // ENABLE_PROFILING
Object res = vm.Invoke("main", vm_args);
#if ENABLE_PROFILING
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
LOG(INFO) << "Inference time: " << duration << "ms\n";
#endif // ENABLE_PROFILING
return res;
}
Value VMToValue(const relay::Module& module, Object obj) {
CHECK(module.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
const auto& data_type = obj.AsDatatype();
tvm::Array<Value> fields;
for (size_t i = 0; i < data_type->fields.size(); ++i) {
fields.push_back(VMToValue(module, data_type->fields[i]));
}
return ConstructorValueNode::make(data_type->tag, fields);
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
return Value();
}
}
TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Object::Tensor(args[0]);
});
TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Tuple(fields);
});
template <typename T>
std::string ToString(const T& t) {
std::stringstream s;
s << t;
return s.str();
}
TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
Object obj = args[0];
*ret = ToString(obj->tag);
});
TVM_REGISTER_API("relay._vm._Datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*ret = Object::Datatype(tag, fields);
});
TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef to_compile = args[0];
TVMContext ctx;
int dev_type = args[1];
ctx.device_type = static_cast<DLDeviceType>(dev_type);
ctx.device_id = args[2];
Module module;
if (to_compile.as<FunctionNode>()) {
Function to_compile = args[0];
module = ModuleNode::FromExpr(to_compile);
} else if (to_compile.as<ModuleNode>()) {
module = args[0];
} else {
LOG(FATAL) << "expected function or module";
}
std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
vm_args.push_back(obj);
}
auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
*ret = VMToValue(module, result);
});
} // namespace vm
} // namespace relay
} // namespace tvm
......@@ -25,7 +25,10 @@
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <iostream>
#include "../runtime_base.h"
namespace tvm {
namespace runtime {
......@@ -87,5 +90,69 @@ NDArray ToNDArray(const Object& obj) {
return tensor->data;
}
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsTensor();
*rv = cell->data;
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->tag);
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->fields.size());
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
int idx = args[1];
auto cell = obj.AsDatatype();
CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx];
});
TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Object::Tensor(args[0]);
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<Object> fields;
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
}
*rv = Object::Tuple(fields);
});
TVM_REGISTER_GLOBAL("_vmobj.Datatype")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*rv = Object::Datatype(tag, fields);
});
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN();
*tag = static_cast<int>(static_cast<ObjectCell*>(handle)->tag);
API_END();
}
......@@ -41,7 +41,7 @@ class PooledAllocator final : public Allocator {
static constexpr size_t kDefaultPageSize = 4096;
explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize)
: Allocator(), page_size_(page_size), used_memory_(0) {}
: Allocator(), page_size_(page_size), used_memory_(0), ctx_(ctx) {}
~PooledAllocator() { ReleaseAll(); }
......
......@@ -32,8 +32,8 @@
#include <stdexcept>
#include <vector>
#include "../../runtime/vm/memory_manager.h"
#include "../../runtime/vm/naive_allocator.h"
#include "memory_manager.h"
#include "naive_allocator.h"
using namespace tvm::runtime;
......@@ -554,6 +554,37 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os;
}
PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
std::vector<Object> func_args;
for (int i = 1; i < args.size(); ++i) {
Object obj = args[i];
func_args.push_back(obj);
}
*rv = this->Invoke(func_name, func_args);
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size() % 2, 0);
std::vector<TVMContext> contexts;
for (int i = 0; i < args.size() / 2; ++i) {
TVMContext ctx;
int device_type = args[i * 2];
ctx.device_type = DLDeviceType(device_type);
ctx.device_id = args[i * 2 + 1];
contexts.push_back(ctx);
}
this->Init(contexts);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
frames.push_back(frame);
......@@ -573,11 +604,11 @@ Index VirtualMachine::PopFrame() {
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) {
DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
PushFrame(func.params, this->pc + 1, func);
PushFrame(func.params.size(), this->pc + 1, func);
for (size_t i = 0; i < args.size(); ++i) {
WriteRegister(i, args[i]);
}
DLOG(INFO) << "func.params= " << func.params;
DLOG(INFO) << "func.params= " << func.params.size();
code = func.instructions.data();
pc = 0;
......@@ -594,7 +625,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
}
Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) {
auto func_index = this->global_map_[name];
auto func_index = this->global_map[name];
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args);
}
......@@ -719,12 +750,12 @@ void VirtualMachine::Run() {
auto object = ReadRegister(instr.closure);
const auto& closure = object.AsClosure();
std::vector<Object> args;
for (Index i = 0; i < instr.closure_args_num; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
for (auto free_var : closure->free_vars) {
args.push_back(free_var);
}
for (Index i = 0; i < instr.closure_args_num; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
InvokeGlobal(this->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst;
goto main_loop;
......
......@@ -92,9 +92,9 @@ def test_squeezenet():
def test_inception_v3():
image_shape = (1, 3, 299, 299)
image_shape = (3, 299, 299)
mod, params = testing.inception_v3.get_workload(image_shape=image_shape)
benchmark_execution(mod, params, data_shape=image_shape)
benchmark_execution(mod, params, data_shape=(1, 3, 299, 299))
def test_dqn():
......@@ -107,7 +107,7 @@ def test_dqn():
def test_dcgan():
image_shape = (1, 100)
mod, params = testing.dcgan.get_workload(batch_size=1)
benchmark_execution(mod, params, data_shape=image_shape)
benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 3, 64, 64))
def test_mobilenet():
......@@ -126,8 +126,7 @@ if __name__ == '__main__':
test_squeezenet()
test_mobilenet()
test_densenet()
# The following networks fail
# test_inception_v3()
# test_mlp()
# test_dqn()
# test_dcgan()
test_inception_v3()
test_mlp()
test_dqn()
test_dcgan()
......@@ -23,40 +23,44 @@ from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
def veval(f, *args, ctx=tvm.cpu()):
def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
if len(args) == 0:
return ex.evaluate(f)
else:
return ex.evaluate(f)(*args)
mod = relay.Module()
mod["main"] = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(tvm.cpu())
return vm.run(*args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
ex = relay.create_executor('vm', mod=mod, ctx=ctx)
if len(args) == 0:
return ex.evaluate()
else:
return ex.evaluate()(*args)
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(tvm.cpu())
ret = vm.run(*args)
return ret
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
result = []
for f in o.fields:
for f in o:
result.extend(vmobj_to_list(f))
return result
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def test_split():
x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple()
z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
f = relay.Function([x], z)
f = relay.Function([x], y)
x_data = np.random.rand(12,).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
ref_res = np.split(x_data, 3, axis=0)
for i in range(3):
tvm.testing.assert_allclose(res[i].asnumpy(), ref_res[i])
def test_split_no_fuse():
x = relay.var('x', shape=(12,))
......@@ -68,9 +72,8 @@ def test_split_no_fuse():
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
def test_id():
x = relay.var('x', shape=(10, 10))
x = relay.var('x', shape=(10, 10), dtype='float64')
f = relay.Function([x], x)
x_data = np.random.rand(10, 10).astype('float64')
res = veval(f, x_data)
......@@ -209,7 +212,7 @@ def test_list_constructor():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
obj = vmobj_to_list(result)
tvm.testing.assert_allclose(obj, np.array([3,2,1]))
......@@ -275,7 +278,7 @@ def test_compose():
mod["main"] = f
x_data = np.array(np.random.rand()).astype('float32')
result = veval(mod)(x_data)
result = veval(mod, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
......@@ -296,7 +299,7 @@ def test_list_hd():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 3)
@raises(Exception)
......@@ -312,7 +315,7 @@ def test_list_tl_empty_list():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
print(result)
def test_list_tl():
......@@ -332,7 +335,7 @@ def test_list_tl():
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
def test_list_nth():
......@@ -351,7 +354,7 @@ def test_list_nth():
f = relay.Function([], nth(l, relay.const(i)))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), expected[i])
def test_list_update():
......@@ -375,7 +378,7 @@ def test_list_update():
f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
def test_list_length():
......@@ -397,7 +400,7 @@ def test_list_length():
f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 10)
def test_list_map():
......@@ -415,7 +418,7 @@ def test_list_map():
f = relay.Function([], map(add_one_func, l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
def test_list_foldl():
......@@ -433,7 +436,7 @@ def test_list_foldl():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldl(rev_dup_func, nil(), l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
def test_list_foldr():
......@@ -451,7 +454,7 @@ def test_list_foldr():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldr(identity_func, nil(), l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
def test_list_sum():
......@@ -465,7 +468,7 @@ def test_list_sum():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], sum(l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 6)
def test_list_filter():
......@@ -485,7 +488,7 @@ def test_list_filter():
cons(relay.const(1), nil())))))
f = relay.Function([], filter(greater_than_one, l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
def test_closure():
......@@ -513,6 +516,10 @@ if __name__ == "__main__":
test_split()
test_split_no_fuse()
test_list_constructor()
test_let_tensor()
test_let_scalar()
test_compose()
test_list_hd()
test_list_tl_empty_list()
test_list_tl()
test_list_nth()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment