Commit 4052de6d by Zhi Committed by Haichen Shen

[relay][vm] Separate VM runtime with executable (#4100)

* [relay][vm] Separate VM runtime with executable

* Address comments

* move ctx back to vm

* make only vm related fields and methods protected

* integrate seriliaztion/deserialization to executable

* create stream
parent cf046972
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -430,15 +431,184 @@ struct VMFrame { ...@@ -430,15 +431,184 @@ struct VMFrame {
caller_return_register(0) {} caller_return_register(0) {}
}; };
/*! \brief The executable emitted by the VM compiler.
*
* The executable contains information (e.g. data in different memory regions)
* to run in a virtual machine.
*
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*/
class Executable : public ModuleNode {
public:
/*!
* \brief Get a PackedFunc from an executable module.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Save();
/*!
* \brief Load the saved VM executable.
*
* \param code The bytecode in string.
* \param lib The compiled runtime library.
*
* \return exe The constructed executable.
*/
static runtime::Module Load(const std::string& code, const runtime::Module lib);
/*!
* \brief Get the serialized form of the `functions`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
std::string GetBytecode() const;
/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
*/
std::string Stats() const;
/*! \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
*/
runtime::Module GetLib() const { return lib; }
virtual ~Executable() {}
const char* type_key() const final {
return "VMExecutable";
}
/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
private:
/*!
* \brief Save the globals.
*
* \param strm The input stream.
*/
void SaveGlobalSection(dmlc::Stream* strm);
/*!
* \brief Save the constant pool.
*
* \param strm The input stream.
*/
void SaveConstantSection(dmlc::Stream* strm);
/*!
* \brief Save primitive op names.
*
* \param strm The input stream.
*/
void SavePrimitiveOpNames(dmlc::Stream* strm);
/*!
* \brief Save the vm functions.
*
* \param strm The input stream.
*/
void SaveCodeSection(dmlc::Stream* strm);
/*!
* \brief Load the globals.
*
* \param strm The input stream.
*/
void LoadGlobalSection(dmlc::Stream* strm);
/*!
* \brief Load the constant pool.
*
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);
/*!
* \brief Load primitive op names.
*
* \param strm The input stream.
*/
void LoadPrimitiveOpNames(dmlc::Stream* strm);
/*!
* \brief Load the vm functions.
*
* \param strm The input stream.
*/
void LoadCodeSection(dmlc::Stream* strm);
/*! \brief The serialized bytecode. */
std::string code_;
};
/*! \brief The virtual machine. /*! \brief The virtual machine.
* *
* The virtual machine contains all the current execution state, * The virtual machine contains all the current execution state,
* as well as the global view of functions, the global constant * as well as the executable.
* table, the compiled operators.
* *
* The goal is to have a single self-contained object, * The goal is to have a single self-contained object,
* enabling one to easily pass around VMs, execute them on * enabling one to easily pass around VMs, execute them on
* multiple threads, or serialized them to disk or over the * multiple threads, or serialize them to disk or over the
* wire. * wire.
*/ */
class VirtualMachine : public runtime::ModuleNode { class VirtualMachine : public runtime::ModuleNode {
...@@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode {
return "VirtualMachine"; return "VirtualMachine";
} }
/*! \brief The runtime module/library that contains generated code. */ VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}
runtime::Module lib;
/*! \brief load the executable for the virtual machine.
* \param exec The executable.
*/
void LoadExecutable(const Executable* exec);
protected:
/*! \brief The virtual machine's packed function table. */ /*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs; std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
/*! \brief The current stack of call frames. */ /*! \brief The current stack of call frames. */
std::vector<VMFrame> frames; std::vector<VMFrame> frames;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief The fuction table index of the current function. */ /*! \brief The fuction table index of the current function. */
Index func_index; Index func_index;
/*! \brief The current pointer to the code section. */ /*! \brief The current pointer to the code section. */
...@@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The special return register. */ /*! \brief The special return register. */
ObjectRef return_register; ObjectRef return_register;
/*! \brief The executable the VM will operate on. */
const Executable* exec;
/*! \brief The set of TVM contexts the VM is currently executing on. */ /*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs; std::vector<TVMContext> ctxs;
...@@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/ */
ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args); ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
/*! \brief Initialize the virtual machine for a set of contexts. /*! \brief Initialize the virtual machine for a set of contexts.
* \param contexts The set of TVM contexts. * \param contexts The set of TVM contexts.
*/ */
...@@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/ */
TVMContext GetParamsContext() const; TVMContext GetParamsContext() const;
/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
*/
void LoadParams(const std::string& params);
/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list.
*/
std::unordered_map<std::string, Index> primitive_map;
private: private:
/*! \brief Invoke a global setting up the VM state to execute. /*! \brief Invoke a global setting up the VM state to execute.
* *
......
...@@ -37,8 +37,6 @@ from . import param_dict ...@@ -37,8 +37,6 @@ from . import param_dict
from . import feature from . import feature
from .backend import vm from .backend import vm
from .backend import profiler_vm from .backend import profiler_vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj from .backend import vmobj
# Root operators # Root operators
......
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine deserializer.
Python interface for deserializing a Relay VM.
"""
from tvm import module
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm
from . import vm as rly_vm
def _create_deserializer(code, lib):
"""Create a deserializer object.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
Returns
-------
ret : Deserializer
The created virtual machine deserializer.
"""
if isinstance(code, (bytes, str)):
code = bytearray(code)
elif not isinstance(code, (bytearray, TVMByteArray)):
raise TypeError("vm is expected to be the type of bytearray or " +
"TVMByteArray, but received {}".format(type(code)))
if not isinstance(lib, module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib)))
return _vm._Deserializer(code, lib)
class Deserializer:
"""Relay VM deserializer.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
"""
def __init__(self, code, lib):
self.mod = _create_deserializer(code, lib)
self._deserialize = self.mod["deserialize"]
def deserialize(self):
"""Deserialize the serialized bytecode into a Relay VM.
Returns
-------
ret : VirtualMachine
The deserialized Relay VM.
"""
return rly_vm.VirtualMachine(self._deserialize())
...@@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None): ...@@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None):
Returns Returns
------- -------
vm : VirtualMachineProfiler exec : Executable
The profile VM runtime. The executable with profiling code.
""" """
compiler = VMCompilerProfiler() compiler = VMCompilerProfiler()
target = compiler.update_target(target) target = compiler.update_target(target)
...@@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None): ...@@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None):
tophub_context = compiler.tophub_context(target) tophub_context = compiler.tophub_context(target)
with tophub_context: with tophub_context:
compiler._compile(mod, target, target_host) compiler._compile(mod, target, target_host)
return VirtualMachineProfiler(compiler._get_vm()) return vm.Executable(compiler._get_exec())
class VMCompilerProfiler(vm.VMCompiler): class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime.""" """Build Relay module to run on VM runtime."""
...@@ -68,13 +68,17 @@ class VMCompilerProfiler(vm.VMCompiler): ...@@ -68,13 +68,17 @@ class VMCompilerProfiler(vm.VMCompiler):
super().__init__() super().__init__()
self.mod = _vm._VMCompilerProfiler() self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"] self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"] self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"] self._set_params_func = self.mod["set_params"]
class VirtualMachineProfiler(vm.VirtualMachine): class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime.""" """Relay profile VM runtime."""
def __init__(self, mod): def __init__(self, mod):
super().__init__(mod) super().__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"] self._get_stat = self.mod["get_stat"]
def get_stat(self): def get_stat(self):
......
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine serializer.
Python interface for serializing a Relay VM.
"""
import tvm
from . import _vm
from . import vm as rly_vm
def _create_serializer(vm):
"""Create a VM serializer.
Parameters
----------
vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`]
The virtual machine to be serialized.
Returns
-------
ret : Serializer
The created virtual machine serializer.
"""
if isinstance(vm, rly_vm.VirtualMachine):
vm = vm.module
elif not isinstance(vm, tvm.module.Module):
raise TypeError("vm is expected to be the type of VirtualMachine or " +
"tvm.Module, but received {}".format(type(vm)))
return _vm._Serializer(vm)
class Serializer:
"""Relay VM serializer."""
def __init__(self, vm):
self.mod = _create_serializer(vm)
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_globals = self.mod["get_globals"]
self._get_stats = self.mod["get_stats"]
self._get_primitive_ops = self.mod["get_primitive_ops"]
self._serialize = self.mod["serialize"]
@property
def stats(self):
"""Get the statistics of the Relay VM.
Returns
-------
ret : String
The serialized statistic information.
"""
return self._get_stats()
@property
def primitive_ops(self):
"""Get the name of the primitive ops that are executed in the VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The list of primitive ops.
"""
return [prim_op.value for prim_op in self._get_primitive_ops()]
@property
def bytecode(self):
"""Get the bytecode of the Relay VM.
Returns
-------
ret : String
The serialized bytecode.
Notes
-----
The bytecode is in the following format:
func_name reg_file_size num_instructions
param1 param2 ... paramM
instruction1
instruction2
...
instructionN
Each instruction is printed in the following format:
hash opcode field1 ... fieldX # The text format.
The part starting from # is only used for visualization and debugging.
The real serialized code doesn't contain it, therefore the deserializer
doesn't need to deal with it as well.
"""
return self._get_bytecode()
@property
def globals(self):
"""Get the globals used by the Relay VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The serialized globals.
"""
return [glb.value for glb in self._get_globals()]
def serialize(self):
"""Serialize the Relay VM.
Returns
-------
code : bytearray
The binary blob representing a serialized Relay VM. It can then be
saved to disk and later deserialized into a new VM.
lib : :py:class:`~tvm.module.Module`
The runtime module that contains the generated code. It is
basically a library that is composed of hardware dependent code.
Notes
-----
The returned code is organized with the following sections in order.
- Global section. This section contains the globals used by the
virtual machine.
- Constant section. This section is used to store the constant pool of
a virtual machine.
- Primitive name section. This section is introduced to accommodate
the list of primitive operator names that will be invoked by the
virtual machine.
- Code section. The VM functions, including bytecode, are sitting in
this section.
Examples
--------
.. code-block:: python
import numpy as np
import tvm
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
mod = relay.Module({"main": f})
# create a Relay VM.
ctx = tvm.cpu()
target = "llvm"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(ctx)
# serialize.
ser = relay.serializer.Serializer(vm)
code, lib = ser.serialize()
# save and load the code and lib file.
tmp = tvm.contrib.util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())
# deserialize.
deser = relay.deserializer.Deserializer(loaded_code, loaded_lib)
des_vm = deser.deserialize()
# execute the deserialized vm.
des_vm.init(ctx)
x_data = np.random.rand(10, 10).astype('float32')
res = des_vm.run(x_data)
print(res.asnumpy())
"""
return self._serialize(), self._get_lib()
...@@ -24,8 +24,8 @@ import numpy as np ...@@ -24,8 +24,8 @@ import numpy as np
import tvm import tvm
from tvm import autotvm from tvm import autotvm
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm.relay import expr as _expr from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm from . import _vm
from . import vmobj as _obj from . import vmobj as _obj
from .interpreter import Executor from .interpreter import Executor
...@@ -44,6 +44,7 @@ def _convert(arg, cargs): ...@@ -44,6 +44,7 @@ def _convert(arg, cargs):
else: else:
raise "unsupported type" raise "unsupported type"
def convert(args): def convert(args):
cargs = [] cargs = []
for arg in args: for arg in args:
...@@ -52,12 +53,202 @@ def convert(args): ...@@ -52,12 +53,202 @@ def convert(args):
return cargs return cargs
class Executable(object):
"""Relay VM executable"""
def __init__(self, mod):
self.mod = mod
self._save = self.mod["save"]
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_stats = self.mod["get_stats"]
def save(self):
"""Save the Relay VM Executable.
Returns
-------
code : bytearray
The binary blob representing a serialized Relay VM executable. It
can then be saved to disk and later deserialized into a new
Executable.
lib : :py:class:`~tvm.module.Module`
The runtime module that contains the generated code. It is
basically a library that is composed of hardware dependent code.
Notes
-----
The returned code is organized with the following sections in order.
- Global section. This section contains the globals used by the
virtual machine.
- Constant section. This section is used to store the constant pool of
a virtual machine.
- Primitive name section. This section is introduced to accommodate
the list of primitive operator names that will be invoked by the
virtual machine.
- Code section. The VM functions, including bytecode, are sitting in
this section.
Examples
--------
.. code-block:: python
import numpy as np
import tvm
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
mod = relay.Module({"main": f})
# create a Relay VM.
ctx = tvm.cpu()
target = "llvm"
executable = relay.vm.compile(mod, target)
code, lib = executable.save()
# save and load the code and lib file.
tmp = tvm.contrib.util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.ro"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
# deserialize.
des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_code)
# execute the deserialized executable.
x_data = np.random.rand(10, 10).astype('float32')
des_vm = relay.vm.VirtualMachine(des_exec)
des_vm.init(ctx)
res = des_vm.run(x_data)
print(res.asnumpy())
"""
return self._save(), self._get_lib()
@staticmethod
def load_exec(bytecode, lib):
"""Construct an executable from saved artifacts.
Parameters
----------
bytecode : bytearray
The binary blob representing a the Relay VM bytecode.
lib : :py:class:`~tvm.module.Module`
The runtime module that contains the generated code.
Returns
-------
exec: Executable
An executable constructed using the provided artifacts.
"""
if isinstance(bytecode, (bytes, str)):
code = bytearray(bytecode)
elif not isinstance(bytecode, (bytearray, TVMByteArray)):
raise TypeError("bytecode is expected to be the type of bytearray " +
"or TVMByteArray, but received {}".format(type(code)))
if not isinstance(lib, tvm.module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib)))
return Executable(_vm.Load_Executable(bytecode, lib))
@property
def lib(self):
"""Get the library that contains hardware dependent code.
Returns
-------
ret : :py:class:`~tvm.Module`
The runtime module that contains hardware dependent code.
"""
return self._get_lib()
@property
def stats(self):
"""Get the statistics of the Relay VM executable.
Returns
-------
ret : String
The statistic information of the VM executable.
"""
return self._get_stats()
@property
def primitive_ops(self):
"""Get the name of the primitive ops contained in the executable.
Returns
-------
ret : List[String]
The list of primitive ops.
"""
ret = []
num_primitives = _vm.GetNumOfPrimitives(self.module)
for i in range(num_primitives):
ret.append(_vm.GetPrimitiveFields(self.module, i))
return ret
@property
def bytecode(self):
"""Get the bytecode of the Relay VM executable.
Returns
-------
ret : String
The bytecode of the executable.
Notes
-----
The bytecode is in the following format:
func_name reg_file_size num_instructions
param1 param2 ... paramM
instruction1
instruction2
...
instructionN
Each instruction is printed in the following format:
hash opcode field1 ... fieldX # The text format.
The part starting from # is only used for visualization and debugging.
The real serialized code doesn't contain it, therefore the deserializer
doesn't need to deal with it as well.
"""
return self._get_bytecode()
@property
def globals(self):
"""Get the globals used by the Relay VM executable.
Returns
-------
ret : List[String]
The globals contained in the executable.
"""
ret = []
num_globals = _vm.GetNumOfGlobals(self.module)
for i in range(num_globals):
ret.append(_vm.GetGlobalFields(self.module, i))
return ret
@property
def module(self):
"""Return the runtime module contained in a virtual machine executable."""
return self.mod
class VirtualMachine(object): class VirtualMachine(object):
"""Relay VM runtime.""" """Relay VM runtime."""
def __init__(self, mod): def __init__(self, mod):
self.mod = mod if not isinstance(mod, (Executable, tvm.module.Module)):
raise TypeError("mod is expected to be the type of Executable or " +
"tvm.Module, but received {}".format(type(mod)))
m = mod.module if isinstance(mod, Executable) else mod
self.mod = _vm._VirtualMachine(m)
self._init = self.mod["init"] self._init = self.mod["init"]
self._load_params = self.mod["load_params"]
self._invoke = self.mod["invoke"] self._invoke = self.mod["invoke"]
def init(self, ctx): def init(self, ctx):
...@@ -71,23 +262,6 @@ class VirtualMachine(object): ...@@ -71,23 +262,6 @@ class VirtualMachine(object):
args = [ctx.device_type, ctx.device_id] args = [ctx.device_type, ctx.device_id]
self._init(*args) self._init(*args)
def load_params(self, params):
"""Load parameters for the VM.
Parameters
----------
params : Union[bytearray, Dict]
The dictionary that contains serialized parameters.
"""
if isinstance(params, dict):
params = tvm.relay.save_param_dict(params)
elif isinstance(params, (bytes, str)):
params = bytearray(params)
if not isinstance(params, (bytearray, TVMByteArray)):
raise TypeError("params must be a bytearray")
self._load_params(bytearray(params))
def invoke(self, func_name, *args): def invoke(self, func_name, *args):
"""Invoke a function. """Invoke a function.
...@@ -122,11 +296,6 @@ class VirtualMachine(object): ...@@ -122,11 +296,6 @@ class VirtualMachine(object):
""" """
return self.invoke("main", *args) return self.invoke("main", *args)
@property
def module(self):
"""Return the runtime module contained in a virtual machine."""
return self.mod
def compile(mod, target=None, target_host=None, params=None): def compile(mod, target=None, target_host=None, params=None):
""" """
...@@ -155,8 +324,8 @@ def compile(mod, target=None, target_host=None, params=None): ...@@ -155,8 +324,8 @@ def compile(mod, target=None, target_host=None, params=None):
Returns Returns
------- -------
vm : VirtualMachine exec : Executable
The VM runtime. The VM executable that contains both library code and bytecode.
""" """
compiler = VMCompiler() compiler = VMCompiler()
...@@ -167,14 +336,14 @@ def compile(mod, target=None, target_host=None, params=None): ...@@ -167,14 +336,14 @@ def compile(mod, target=None, target_host=None, params=None):
tophub_context = compiler.tophub_context(target) tophub_context = compiler.tophub_context(target)
with tophub_context: with tophub_context:
compiler._compile(mod, target, target_host) compiler._compile(mod, target, target_host)
return VirtualMachine(compiler._get_vm()) return Executable(compiler._get_exec())
class VMCompiler(object): class VMCompiler(object):
"""Build Relay module to run on VM runtime.""" """Build Relay module to run on VM runtime."""
def __init__(self): def __init__(self):
self.mod = _vm._VMCompiler() self.mod = _vm._VMCompiler()
self._compile = self.mod["compile"] self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"] self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"] self._set_params_func = self.mod["set_params"]
def set_params(self, params): def set_params(self, params):
...@@ -240,7 +409,7 @@ class VMExecutor(Executor): ...@@ -240,7 +409,7 @@ class VMExecutor(Executor):
mod : :py:class:`~tvm.relay.module.Module` mod : :py:class:`~tvm.relay.module.Module`
The module to support the execution. The module to support the execution.
ctx : :py:class:`TVMContext` ctx : :py:class:`~tvm.TVMContext`
The runtime context to run the code on. The runtime context to run the code on.
target : :py:class:`Target` target : :py:class:`Target`
...@@ -252,7 +421,8 @@ class VMExecutor(Executor): ...@@ -252,7 +421,8 @@ class VMExecutor(Executor):
self.mod = mod self.mod = mod
self.ctx = ctx self.ctx = ctx
self.target = target self.target = target
self.vm = compile(mod, target) self.executable = compile(mod, target)
self.vm = VirtualMachine(self.executable)
self.vm.init(ctx) self.vm.init(ctx)
def _make_executor(self, expr=None): def _make_executor(self, expr=None):
......
...@@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, ...@@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
Module mod = args[0]; Module mod = args[0];
this->Compile(mod, args[1], args[2]); this->Compile(mod, args[1], args[2]);
}); });
} else if (name == "get_vm") { } else if (name == "get_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_); *rv = runtime::Module(exec_);
}); });
} else if (name == "set_params") { } else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
...@@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod, ...@@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod,
// Next we get ready by allocating space for // Next we get ready by allocating space for
// the global state. // the global state.
vm_->functions.resize(context_.module->functions.size()); exec_->functions.resize(context_.module->functions.size());
for (auto named_func : context_.module->functions) { for (auto named_func : context_.module->functions) {
auto gvar = named_func.first; auto gvar = named_func.first;
...@@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod, ...@@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod,
auto vm_func = func_compiler.Compile(gvar, func); auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar); size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < vm_->functions.size()); CHECK(func_index < exec_->functions.size());
vm_->functions[func_index] = vm_func; exec_->functions[func_index] = vm_func;
} }
#if USE_RELAY_DEBUG #if USE_RELAY_DEBUG
for (auto vm_func : vm_->functions) { for (auto vm_func : exec_->functions) {
DLOG(INFO) << vm_func << "-------------"; DLOG(INFO) << vm_func << "-------------";
} }
#endif // USE_RELAY_DEBUG #endif // USE_RELAY_DEBUG
// populate constants // populate constants
for (auto data : context_.constants) { for (auto data : context_.constants) {
vm_->constants.push_back(runtime::vm::Tensor(data)); exec_->constants.push_back(runtime::vm::Tensor(data));
} }
LibraryCodegen(); LibraryCodegen();
for (auto gv : context_.global_map) { for (auto gv : context_.global_map) {
vm_->global_map.insert({gv.first->name_hint, gv.second}); exec_->global_map.insert({gv.first->name_hint, gv.second});
} }
} }
...@@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() { ...@@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() {
// therefore target won't be used in the build function // therefore target won't be used in the build function
runtime::Module mod = (*f)(funcs, Target(), target_host_); runtime::Module mod = (*f)(funcs, Target(), target_host_);
CHECK(mod.operator->()); CHECK(mod.operator->());
vm_->lib = mod; exec_->lib = mod;
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; LOG(FATAL) << "relay.backend.build is not registered";
} }
size_t primitive_index = 0; size_t primitive_index = 0;
for (auto cfunc : cached_funcs) { for (auto cfunc : cached_funcs) {
vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
} }
} }
......
...@@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode {
return "VMCompiler"; return "VMCompiler";
} }
std::shared_ptr<VirtualMachine> GetVirtualMachine() const { void InitVM() {
return vm_; exec_ = std::make_shared<Executable>();
}
virtual void InitVM() {
vm_ = std::make_shared<VirtualMachine>();
} }
/*! /*!
...@@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode {
tvm::Target target_host_; tvm::Target target_host_;
/*! \brief Global shared meta data */ /*! \brief Global shared meta data */
VMCompilerContext context_; VMCompilerContext context_;
/*! \brief Compiled virtual machine. */ /*! \brief Compiled executable. */
std::shared_ptr<VirtualMachine> vm_; std::shared_ptr<Executable> exec_;
/*! \brief parameters */ /*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_; std::unordered_map<std::string, runtime::NDArray> params_;
}; };
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/deserializer.cc
* \brief Implementation of APIs to deserialize the serialized VM bytecode.
*/
#include "deserializer.h"
#include <tvm/runtime/registry.h>
#include <memory>
#include <sstream>
#include "serialize_util.h"
namespace tvm {
namespace relay {
namespace vm {
#define STREAM_CHECK(val, section) \
CHECK(val) << "Invalid VM file format in the " << section << " section." \
<< "\n";
void Deserializer::Init(const std::string& code, const runtime::Module& lib) {
code_ = code;
vm_ = std::make_shared<VirtualMachine>();
vm_->lib = lib;
strm_ = new dmlc::MemoryStringStream(&code_);
}
runtime::PackedFunc Deserializer::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "deserialize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Deserialize();
*rv = runtime::Module(vm_);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
void Deserializer::Deserialize() {
// Check header.
uint64_t header;
STREAM_CHECK(strm_->Read(&header), "header");
STREAM_CHECK(header == kTVMVMBytecodeMagic, "header");
// Check version.
std::string version;
STREAM_CHECK(strm_->Read(&version), "version");
STREAM_CHECK(version == TVM_VERSION, "version");
// Global section.
DeserializeGlobalSection();
// Constant section.
DeserializeConstantSection();
// Primitive names that will be invoked by `InvokePacked` instructions.
DeserializePrimitiveOpNames();
// Code section.
DeserializeCodeSection();
}
void Deserializer::DeserializeGlobalSection() {
std::vector<std::string> globals;
STREAM_CHECK(strm_->Read(&globals), "global");
for (size_t i = 0; i < globals.size(); i++) {
vm_->global_map.insert({globals[i], i});
}
}
void Deserializer::DeserializeConstantSection() {
uint64_t sz;
// Load the number of constants.
STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant");
size_t size = static_cast<size_t>(sz);
// Load each of the constants.
for (size_t i = 0; i < size; i++) {
runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm_), "constant");
runtime::ObjectRef obj = runtime::vm::Tensor(constant);
vm_->constants.push_back(obj);
}
}
void Deserializer::DeserializePrimitiveOpNames() {
std::vector<std::string> primitive_names;
STREAM_CHECK(strm_->Read(&primitive_names), "primitive name");
for (size_t i = 0; i < primitive_names.size(); i++) {
vm_->primitive_map.insert({primitive_names[i], i});
}
}
// Extract the `cnt` number of fields started at `start` from the list
// `instr_fields`.
inline std::vector<Index> ExtractFields(const std::vector<Index>& instr_fields,
Index start,
Index cnt) {
CHECK_LE(static_cast<size_t>(start + cnt), instr_fields.size());
std::vector<Index> ret;
for (auto i = start; i < start + cnt; i++) {
ret.push_back(instr_fields[i]);
}
return ret;
}
Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
Opcode opcode = static_cast<Opcode>(instr.opcode);
switch (opcode) {
case Opcode::Move: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::Move(instr.fields[0], instr.fields[1]);
}
case Opcode::Ret: {
// Number of fields = 1
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Ret(instr.fields[0]);
}
case Opcode::Fatal: {
// Number of fields = 0
DCHECK(instr.fields.empty());
return Instruction::Fatal();
}
case Opcode::InvokePacked: {
// Number of fields = 3 + instr.arity
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index packed_index = instr.fields[0];
Index arity = instr.fields[1];
Index output_size = instr.fields[2];
std::vector<RegName> args = ExtractFields(instr.fields, 3, arity);
return Instruction::InvokePacked(packed_index, arity, output_size, args);
}
case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim
DCHECK_GE(instr.fields.size(), 5U);
DCHECK_EQ(instr.fields.size(), 5U + static_cast<size_t>(instr.fields[3]));
DLDataType dtype;
dtype.code = instr.fields[0];
dtype.bits = instr.fields[1];
dtype.lanes = instr.fields[2];
Index ndim = instr.fields[3];
RegName dst = instr.fields[4];
std::vector<Index> shape = ExtractFields(instr.fields, 5, ndim);
return Instruction::AllocTensor(shape, dtype, dst);
}
case Opcode::AllocTensorReg: {
// Number of fields = 5
DCHECK_EQ(instr.fields.size(), 5U);
Index shape_register = instr.fields[0];
DLDataType dtype;
dtype.code = instr.fields[1];
dtype.bits = instr.fields[2];
dtype.lanes = instr.fields[3];
RegName dst = instr.fields[4];
return Instruction::AllocTensorReg(shape_register, dtype, dst);
}
case Opcode::AllocDatatype: {
// Number of fields = 3 + instr.num_fields
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index constructor_tag = instr.fields[0];
Index num_fields = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields);
return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst);
}
case Opcode::AllocClosure: {
// Number of fields = 3 + instr.num_freevar
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index clo_index = instr.fields[0];
Index num_freevar = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> free_vars = ExtractFields(instr.fields, 3, num_freevar);
return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
}
case Opcode::If: {
// Number of fields = 4
DCHECK_EQ(instr.fields.size(), 4U);
Index test = instr.fields[0];
Index target = instr.fields[1];
Index true_offset = instr.fields[2];
Index false_offset = instr.fields[3];
return Instruction::If(test, target, true_offset, false_offset);
}
case Opcode::Invoke: {
// Number of fields = 3 + instr.num_args
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index func_index = instr.fields[0];
Index num_args = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> args = ExtractFields(instr.fields, 3, num_args);
return Instruction::Invoke(func_index, args, dst);
}
case Opcode::InvokeClosure: {
// Number of fields = 3 + instr.num_closure_args
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index closure = instr.fields[0];
Index num_closure_args = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> args = ExtractFields(instr.fields, 3, num_closure_args);
return Instruction::InvokeClosure(closure, args, dst);
}
case Opcode::LoadConst: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::LoadConst(instr.fields[0], instr.fields[1]);
}
case Opcode::LoadConsti: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::LoadConsti(instr.fields[0], instr.fields[1]);
}
case Opcode::GetField: {
// Number of fields = 3
DCHECK_EQ(instr.fields.size(), 3U);
return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]);
}
case Opcode::GetTag: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::GetTag(instr.fields[0], instr.fields[1]);
}
case Opcode::Goto: {
// Number of fields = 1
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Goto(instr.fields[0]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
}
}
void Deserializer::DeserializeCodeSection() {
// Load the number of functions.
uint64_t sz;
STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code");
size_t num_funcs = static_cast<size_t>(sz);
vm_->functions.resize(num_funcs);
for (size_t i = 0; i < num_funcs; i++) {
// Load the function info.
VMFunctionSerializer loaded_func;
STREAM_CHECK(loaded_func.Load(strm_), "code/function");
// Load the instructions.
std::vector<Instruction> instructions;
for (size_t j = 0; j < loaded_func.num_instructions; j++) {
VMInstructionSerializer instr;
std::vector<Index> instr_fields;
STREAM_CHECK(instr.Load(strm_), "code/instruction");
instructions.push_back(DeserializeInstruction(instr));
}
// Create the VM function.
VMFunction vm_func = VMFunction(loaded_func.name,
loaded_func.params,
instructions,
loaded_func.register_file_size);
auto it = vm_->global_map.find(loaded_func.name);
CHECK(it != vm_->global_map.end());
CHECK_LE(it->second, vm_->global_map.size());
vm_->functions[it->second] = vm_func;
}
}
runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Deserializer> exec = std::make_shared<Deserializer>();
exec->Init(code, lib);
return runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("relay._vm._Deserializer")
.set_body_typed(CreateDeserializer);
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/deserializer.h
* \brief Define a deserializer for the serialized Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#include <dmlc/memory_io.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime::vm;
namespace runtime = tvm::runtime;
class Deserializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the deserializer for creating a virtual machine object.
*
* \param code The serialized code.
* \param lib The serialized runtime module/library that contains the
* hardware dependent code.
*/
inline void Init(const std::string& code, const runtime::Module& lib);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Deserializer"; }
/*! \brief Deserialize the serialized VM. */
void Deserialize();
virtual ~Deserializer() { delete strm_; }
private:
/*! \brief Deserialize the globals in `vm_`. */
void DeserializeGlobalSection();
/*! \brief Deserialize the constant pool in `vm_`. */
void DeserializeConstantSection();
/*! \brief Deserialize primitive op names in `vm_`. */
void DeserializePrimitiveOpNames();
/*! \brief Deserialize the vm functions in `vm_`. */
void DeserializeCodeSection();
/*! \brief The code to be serialized. */
std::string code_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The VM to be created. */
std::shared_ptr<VirtualMachine> vm_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
...@@ -33,7 +33,6 @@ namespace vm { ...@@ -33,7 +33,6 @@ namespace vm {
class VMCompilerDebug : public VMCompiler { class VMCompilerDebug : public VMCompiler {
public: public:
VMCompilerDebug() {} VMCompilerDebug() {}
void InitVM() override { vm_ = std::make_shared<VirtualMachineDebug>(); }
virtual ~VMCompilerDebug() {} virtual ~VMCompilerDebug() {}
}; };
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serializer.h
* \brief Define a serializer for the Relay VM.
*
* The following components of a Relay VM will be serialized:
* - The `constants`, e.g., the constant pool, that contains the
* constants used in a Relay program.
* - The `packed_funcs` that essentially contains the generated code for
* a specific target. We return it as a runtime module that can be exported as
* a library file (e.g., .so, .o, or .tar).
* - The `global_map` that contains the globals.
* - The `primitive_map` that contains the name of individual primitive operators.
* - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of
* a list of instructions/bytecode.
*
* Note that only the library is returned as a separate module. All othere parts
* are stored in a single serialized code that is organized with the following
* sections in order.
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*
* The code section is again organized as follows for each VM function:
* func_name, register_file_size, num_instructions (N)
* param1, param2, ..., paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/ir.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
/*!
* \brief The Relay VM serializer.
*/
class Serializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the serializer for a virtual machine.
*
* \param vm The Relay virtual machine.
*/
inline void Init(const VirtualMachine* vm);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Serializer"; }
/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
*/
std::string Stats() const;
/*!
* \brief Serialize the `vm_` into global section, constant section, and code
* section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Serialize();
/*!
* \brief Get a list of the globals used by the `_vm`.
*
* \return The global map in the form a list.
*/
tvm::Array<tvm::Expr> GetGlobals() const;
/*!
* \brief Get the primitive operators that are contained in the Relay VM.
*
* \return The list of primitve operators.
*/
tvm::Array<tvm::Expr> GetPrimitiveOps() const;
/*!
* \brief Get the serialized form of the `functions` in `vm_`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
std::string GetBytecode() const;
/*! \brief Get the `lib` module in vm_. Serialization of `runtime::module`
* has already been supported by TVM. Therefore, we only return the runtime
* module and let users have the flexibility to call `export_library` from
* the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
*/
inline runtime::Module GetLib() const;
virtual ~Serializer() { delete strm_; }
private:
/*! \brief Serialize the globals in vm_. */
void SerializeGlobalSection();
/*! \brief Serialize the constant pool in vm_. */
void SerializeConstantSection();
/*! \brief Serialize primitive op names in vm_. */
void SerializePrimitiveOpNames();
/*! \brief Serialize the vm functions in vm_. */
void SerializeCodeSection();
/*! \brief The Relay virtual machine for to be serialized. */
const VirtualMachine* vm_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The serialized code. */
std::string code_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_
...@@ -19,16 +19,18 @@ ...@@ -19,16 +19,18 @@
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serializer.cc * \file tvm/runtime/vm/executable.cc
* \brief Implementation of serializing APIs for the Relay VM. * \brief The implementation of a virtual machine executable APIs.
*/ */
#include "serializer.h"
#include <tvm/runtime/registry.h> #include <dmlc/memory_io.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/vm.h>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <iostream>
#include <sstream> #include <sstream>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -36,70 +38,88 @@ ...@@ -36,70 +38,88 @@
#include "serialize_util.h" #include "serialize_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace runtime {
namespace vm { namespace vm {
void Serializer::Init(const VirtualMachine* vm) { #define STREAM_CHECK(val, section) \
vm_ = vm; CHECK(val) << "Invalid VM file format in the " << section << " section." \
// Initialize the stream object. << "\n";
strm_ = new dmlc::MemoryStringStream(&code_);
}
runtime::PackedFunc Serializer::GetFunction( // Helper to serialize a vm instruction.
const std::string& name, VMInstructionSerializer SerializeInstruction(const Instruction& instr);
// Helper to deserialize a serialized vm instruction.
Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
PackedFunc Executable::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "get_lib") { if (name == "get_lib") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetLib(); *rv = this->GetLib();
}); });
} else if (name == "get_primitive_ops") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetPrimitiveOps();
});
} else if (name == "get_bytecode") { } else if (name == "get_bytecode") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetBytecode(); *rv = this->GetBytecode();
}); });
} else if (name == "get_globals") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetGlobals();
});
} else if (name == "get_stats") { } else if (name == "get_stats") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Stats(); *rv = this->Stats();
}); });
} else if (name == "serialize") { } else if (name == "save") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Serialize(); *rv = this->Save();
}); });
} else { } else {
LOG(FATAL) << "Unknown packed function: " << name; LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); return PackedFunc(nullptr);
} }
} }
tvm::Array<tvm::Expr> Serializer::GetPrimitiveOps() const { std::string Executable::GetBytecode() const {
std::vector<tvm::Expr> ret; std::ostringstream oss;
for (const auto& it : vm_->primitive_map) {
auto packed_name = tvm::ir::StringImm::make(it.first); for (const auto& func : functions) {
auto packed_index = static_cast<size_t>(it.second); // Print the header of the function format.
if (ret.size() <= packed_index) { oss << "# func name, reg file size, param count, inst count:"
ret.resize(packed_index + 1); << std::endl;
oss << func.name << " "
<< func.register_file_size << " "
<< func.params.size() << " "
<< func.instructions.size() << std::endl;
// Print pramams of a `VMFunction`.
oss << "# Parameters: "<< std::endl;
for (const auto& param : func.params) {
oss << param << " ";
}
oss << std::endl;
// Print the instructions of a `VMFunction`.
// The part after ";" is the instruction in text format.
oss << "hash, opcode, fields # inst(text):"<< std::endl;
for (const auto& instr : func.instructions) {
const auto& serialized_instr = SerializeInstruction(instr);
oss << std::hex << "0x" << serialized_instr.Hash() << " "
<< std::dec << serialized_instr.opcode << " ";
for (auto it : serialized_instr.fields) {
oss << it << " ";
}
oss << " # " << instr;
if (oss.str().back() != '\n') oss << std::endl;
} }
ret[packed_index] = packed_name;
} }
return ret;
return oss.str();
} }
std::string Serializer::Stats() const { std::string Executable::Stats() const {
std::ostringstream oss; std::ostringstream oss;
oss << "Relay VM statistics:" << std::endl; oss << "Relay VM executable statistics:" << std::endl;
// Get the number of constants and the shape of each of them. // Get the number of constants and the shape of each of them.
oss << " Constant shapes (# " << vm_->constants.size() << "): ["; oss << " Constant shapes (# " << constants.size() << "): [";
for (const auto& it : vm_->constants) { for (const auto& it : constants) {
auto* cell = it.as<runtime::vm::TensorObj>(); const auto* cell = it.as<TensorObj>();
CHECK(cell != nullptr); CHECK(cell);
runtime::NDArray data = cell->data; runtime::NDArray data = cell->data;
const auto& shape = data.Shape(); const auto& shape = data.Shape();
...@@ -116,20 +136,27 @@ std::string Serializer::Stats() const { ...@@ -116,20 +136,27 @@ std::string Serializer::Stats() const {
oss.seekp(-2, oss.cur); oss.seekp(-2, oss.cur);
oss << "], " << std::endl; oss << "], " << std::endl;
} }
if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); if (!constants.empty()) oss.seekp(-2, oss.cur);
oss << "]" << std::endl; oss << "]" << std::endl;
// Get the number of globals and the name of each of them. // Get the number of globals and the name of each of them.
oss << " Globals (#" << vm_->global_map.size() << "): ["; oss << " Globals (#" << global_map.size() << "): [";
for (const auto& it : vm_->global_map) { for (const auto& it : global_map) {
oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; oss << "(\"" << it.first << "\", " << it.second << ")" << ", ";
} }
if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); if (!global_map.empty()) oss.seekp(-2, oss.cur);
oss << "]" << std::endl; oss << "]" << std::endl;
// Get the number of primitive ops and the name of each of them. // Get the number of primitive ops and the name of each of them.
oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; oss << " Primitive ops (#" << primitive_map.size() << "): [";
const auto& prim_ops = GetPrimitiveOps(); std::vector<std::string> prim_ops;
for (const auto& it : primitive_map) {
auto packed_index = static_cast<size_t>(it.second);
if (prim_ops.size() <= packed_index) {
prim_ops.resize(packed_index + 1);
}
prim_ops[packed_index] = it.first;
}
for (const auto& it : prim_ops) { for (const auto& it : prim_ops) {
oss << it << ", "; oss << it << ", ";
} }
...@@ -139,23 +166,32 @@ std::string Serializer::Stats() const { ...@@ -139,23 +166,32 @@ std::string Serializer::Stats() const {
return oss.str(); return oss.str();
} }
TVMByteArray Serializer::Serialize() { void SaveHeader(dmlc::Stream* strm) {
uint64_t header = kTVMVMBytecodeMagic; uint64_t header = kTVMVMBytecodeMagic;
strm_->Write(header); strm->Write(header);
std::string version = TVM_VERSION; std::string version = TVM_VERSION;
strm_->Write(version); strm->Write(version);
}
TVMByteArray Executable::Save() {
// Initialize the stream object.
code_.clear();
dmlc::MemoryStringStream strm(&code_);
// Save header
SaveHeader(&strm);
// Global section. // Global section.
SerializeGlobalSection(); SaveGlobalSection(&strm);
// Constant section. // Constant section.
SerializeConstantSection(); SaveConstantSection(&strm);
// Primitive names. // Primitive names.
SerializePrimitiveOpNames(); SavePrimitiveOpNames(&strm);
// Code section. // Code section.
SerializeCodeSection(); SaveCodeSection(&strm);
TVMByteArray arr; TVMByteArray arr;
arr.data = code_.c_str(); arr.data = code_.c_str();
...@@ -163,36 +199,46 @@ TVMByteArray Serializer::Serialize() { ...@@ -163,36 +199,46 @@ TVMByteArray Serializer::Serialize() {
return arr; return arr;
} }
void Serializer::SerializeGlobalSection() { void Executable::SaveGlobalSection(dmlc::Stream* strm) {
auto globals = GetGlobals(); std::vector<std::pair<std::string, Index> > globals(this->global_map.begin(),
this->global_map.end());
auto comp = [](const std::pair<std::string, Index>& a,
const std::pair<std::string, Index>& b) {
return a.second < b.second;
};
std::sort(globals.begin(), globals.end(), comp);
std::vector<std::string> glbs; std::vector<std::string> glbs;
for (const auto& it : globals) { for (const auto& it : globals) {
glbs.push_back(it.as<tvm::ir::StringImm>()->value); glbs.push_back(it.first);
} }
strm_->Write(glbs); strm->Write(glbs);
} }
void Serializer::SerializeConstantSection() { void Executable::SaveConstantSection(dmlc::Stream* strm) {
std::vector<DLTensor*> arrays; std::vector<DLTensor*> arrays;
for (const auto& obj : vm_->constants) { for (const auto& obj : this->constants) {
const auto* cell = obj.as<runtime::vm::TensorObj>(); const auto* cell = obj.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr); CHECK(cell != nullptr);
runtime::NDArray data = cell->data; runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->())); arrays.push_back(const_cast<DLTensor*>(data.operator->()));
} }
strm_->Write(static_cast<uint64_t>(vm_->constants.size())); strm->Write(static_cast<uint64_t>(this->constants.size()));
for (const auto& it : arrays) { for (const auto& it : arrays) {
runtime::SaveDLTensor(strm_, it); runtime::SaveDLTensor(strm, it);
} }
} }
void Serializer::SerializePrimitiveOpNames() { void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) {
auto names = GetPrimitiveOps();
std::vector<std::string> primitive_names; std::vector<std::string> primitive_names;
for (const auto& it : names) { for (const auto& it : this->primitive_map) {
primitive_names.push_back(it.as<tvm::ir::StringImm>()->value); auto packed_index = static_cast<size_t>(it.second);
if (primitive_names.size() <= packed_index) {
primitive_names.resize(packed_index + 1);
}
primitive_names[packed_index] = it.first;
} }
strm_->Write(primitive_names); strm->Write(primitive_names);
} }
// Serialize a virtual machine instruction. It creates a list that contains the // Serialize a virtual machine instruction. It creates a list that contains the
...@@ -206,7 +252,7 @@ void Serializer::SerializePrimitiveOpNames() { ...@@ -206,7 +252,7 @@ void Serializer::SerializePrimitiveOpNames() {
// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` // `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn`
// //
// where hash is the hash of serialized instruction that is computed internally // where hash is the hash of serialized instruction that is computed internally
// by the `VMInstructionSerializer`. It is used for sanity check before decoding. // by the `VMInstructionExecutable`. It is used for sanity check before decoding.
// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` // 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)`
// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` // represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register`
// is the destination register, and the rest of it together indicates the shape // is the destination register, and the rest of it together indicates the shape
...@@ -344,96 +390,345 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { ...@@ -344,96 +390,345 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
return VMInstructionSerializer(static_cast<Index>(instr.op), fields); return VMInstructionSerializer(static_cast<Index>(instr.op), fields);
} }
void Serializer::SerializeCodeSection() { void Executable::SaveCodeSection(dmlc::Stream* strm) {
// Save the number of functions. // Save the number of functions.
strm_->Write(static_cast<uint64_t>(vm_->functions.size())); strm->Write(static_cast<uint64_t>(this->functions.size()));
for (const auto& func : vm_->functions) { for (const auto& func : this->functions) {
// Serialize the function info. // Save the function info.
VMFunctionSerializer func_format(func.name, VMFunctionSerializer func_format(func.name,
func.register_file_size, func.register_file_size,
func.instructions.size(), func.instructions.size(),
func.params); func.params);
func_format.Save(strm_); func_format.Save(strm);
// Serialize each instruction. // Serialize each instruction.
for (const auto& instr : func.instructions) { for (const auto& instr : func.instructions) {
const auto& serialized_instr = SerializeInstruction(instr); const auto& serialized_instr = SerializeInstruction(instr);
serialized_instr.Save(strm_); serialized_instr.Save(strm);
} }
} }
} }
tvm::Array<tvm::Expr> Serializer::GetGlobals() const { void LoadHeader(dmlc::Stream* strm) {
tvm::Array<tvm::Expr> ret; // Check header.
std::vector<std::pair<std::string, Index> > globals(vm_->global_map.begin(), uint64_t header;
vm_->global_map.end()); STREAM_CHECK(strm->Read(&header), "header");
auto comp = [](const std::pair<std::string, Index>& a, STREAM_CHECK(header == kTVMVMBytecodeMagic, "header");
const std::pair<std::string, Index>& b) {
return a.second < b.second; // Check version.
}; std::string version;
std::sort(globals.begin(), globals.end(), comp); STREAM_CHECK(strm->Read(&version), "version");
for (const auto& it : globals) { STREAM_CHECK(version == TVM_VERSION, "version");
ret.push_back(tvm::ir::StringImm::make(it.first)); }
runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Executable> exec = std::make_shared<Executable>();
exec->lib = lib;
exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_);
// Load header.
LoadHeader(&strm);
// Global section.
exec->LoadGlobalSection(&strm);
// Constant section.
exec->LoadConstantSection(&strm);
// Primitive names that will be invoked by `InvokePacked` instructions.
exec->LoadPrimitiveOpNames(&strm);
// Code section.
exec->LoadCodeSection(&strm);
return runtime::Module(exec);
}
void Executable::LoadGlobalSection(dmlc::Stream* strm) {
std::vector<std::string> globals;
STREAM_CHECK(strm->Read(&globals), "global");
for (size_t i = 0; i < globals.size(); i++) {
this->global_map.insert({globals[i], i});
}
}
void Executable::LoadConstantSection(dmlc::Stream* strm) {
uint64_t sz;
// Load the number of constants.
STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant");
size_t size = static_cast<size_t>(sz);
// Load each of the constants.
for (size_t i = 0; i < size; i++) {
runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm), "constant");
runtime::ObjectRef obj = runtime::vm::Tensor(constant);
this->constants.push_back(obj);
}
}
void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) {
std::vector<std::string> primitive_names;
STREAM_CHECK(strm->Read(&primitive_names), "primitive name");
for (size_t i = 0; i < primitive_names.size(); i++) {
this->primitive_map.insert({primitive_names[i], i});
}
}
// Extract the `cnt` number of fields started at `start` from the list
// `instr_fields`.
inline std::vector<Index> ExtractFields(const std::vector<Index>& instr_fields,
Index start,
Index cnt) {
CHECK_LE(static_cast<size_t>(start + cnt), instr_fields.size());
std::vector<Index> ret;
for (auto i = start; i < start + cnt; i++) {
ret.push_back(instr_fields[i]);
} }
return ret; return ret;
} }
std::string Serializer::GetBytecode() const { Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
std::ostringstream oss; Opcode opcode = static_cast<Opcode>(instr.opcode);
switch (opcode) {
case Opcode::Move: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::Move(instr.fields[0], instr.fields[1]);
}
case Opcode::Ret: {
// Number of fields = 1
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Ret(instr.fields[0]);
}
case Opcode::Fatal: {
// Number of fields = 0
DCHECK(instr.fields.empty());
return Instruction::Fatal();
}
case Opcode::InvokePacked: {
// Number of fields = 3 + instr.arity
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index packed_index = instr.fields[0];
Index arity = instr.fields[1];
Index output_size = instr.fields[2];
std::vector<RegName> args = ExtractFields(instr.fields, 3, arity);
return Instruction::InvokePacked(packed_index, arity, output_size, args);
}
case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim
DCHECK_GE(instr.fields.size(), 5U);
DCHECK_EQ(instr.fields.size(), 5U + static_cast<size_t>(instr.fields[3]));
for (const auto& func : vm_->functions) { DLDataType dtype;
// Print the header of the function format. dtype.code = instr.fields[0];
oss << "# func name, reg file size, param count, inst count:" dtype.bits = instr.fields[1];
<< std::endl; dtype.lanes = instr.fields[2];
oss << func.name << " "
<< func.register_file_size << " "
<< func.params.size() << " "
<< func.instructions.size() << std::endl;
// Print pramams of a `VMFunction`. Index ndim = instr.fields[3];
oss << "# Parameters:"<< std::endl; RegName dst = instr.fields[4];
for (const auto& param : func.params) {
oss << param << " "; std::vector<Index> shape = ExtractFields(instr.fields, 5, ndim);
return Instruction::AllocTensor(shape, dtype, dst);
} }
oss << std::endl; case Opcode::AllocTensorReg: {
// Number of fields = 5
DCHECK_EQ(instr.fields.size(), 5U);
Index shape_register = instr.fields[0];
// Print the instructions of a `VMFunction`. DLDataType dtype;
// The part after ";" is the instruction in text format. dtype.code = instr.fields[1];
oss << "hash, opcode, fields # inst(text):"<< std::endl; dtype.bits = instr.fields[2];
for (const auto& instr : func.instructions) { dtype.lanes = instr.fields[3];
const auto& serialized_instr = SerializeInstruction(instr);
oss << std::hex << "0x" << serialized_instr.Hash() << " " RegName dst = instr.fields[4];
<< std::dec << serialized_instr.opcode << " ";
for (auto it : serialized_instr.fields) { return Instruction::AllocTensorReg(shape_register, dtype, dst);
oss << it << " ";
}
oss << " # " << instr;
if (oss.str().back() != '\n') oss << std::endl;
} }
} case Opcode::AllocDatatype: {
// Number of fields = 3 + instr.num_fields
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
return oss.str(); Index constructor_tag = instr.fields[0];
} Index num_fields = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields);
return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst);
}
case Opcode::AllocClosure: {
// Number of fields = 3 + instr.num_freevar
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index clo_index = instr.fields[0];
Index num_freevar = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> free_vars = ExtractFields(instr.fields, 3, num_freevar);
return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
}
case Opcode::If: {
// Number of fields = 4
DCHECK_EQ(instr.fields.size(), 4U);
Index test = instr.fields[0];
Index target = instr.fields[1];
Index true_offset = instr.fields[2];
Index false_offset = instr.fields[3];
return Instruction::If(test, target, true_offset, false_offset);
}
case Opcode::Invoke: {
// Number of fields = 3 + instr.num_args
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index func_index = instr.fields[0];
Index num_args = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> args = ExtractFields(instr.fields, 3, num_args);
runtime::Module Serializer::GetLib() const { return Instruction::Invoke(func_index, args, dst);
return vm_->lib; }
case Opcode::InvokeClosure: {
// Number of fields = 3 + instr.num_closure_args
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index closure = instr.fields[0];
Index num_closure_args = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> args = ExtractFields(instr.fields, 3, num_closure_args);
return Instruction::InvokeClosure(closure, args, dst);
}
case Opcode::LoadConst: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::LoadConst(instr.fields[0], instr.fields[1]);
}
case Opcode::LoadConsti: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::LoadConsti(instr.fields[0], instr.fields[1]);
}
case Opcode::GetField: {
// Number of fields = 3
DCHECK_EQ(instr.fields.size(), 3U);
return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]);
}
case Opcode::GetTag: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::GetTag(instr.fields[0], instr.fields[1]);
}
case Opcode::Goto: {
// Number of fields = 1
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Goto(instr.fields[0]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
}
} }
runtime::Module CreateSerializer(const VirtualMachine* vm) { void Executable::LoadCodeSection(dmlc::Stream* strm) {
std::shared_ptr<Serializer> exec = std::make_shared<Serializer>(); // Load the number of functions.
exec->Init(vm); uint64_t sz;
return runtime::Module(exec); STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code");
size_t num_funcs = static_cast<size_t>(sz);
this->functions.resize(num_funcs);
for (size_t i = 0; i < num_funcs; i++) {
// Load the function info.
VMFunctionSerializer loaded_func;
STREAM_CHECK(loaded_func.Load(strm), "code/function");
// Load the instructions.
std::vector<Instruction> instructions;
for (size_t j = 0; j < loaded_func.num_instructions; j++) {
VMInstructionSerializer instr;
std::vector<Index> instr_fields;
STREAM_CHECK(instr.Load(strm), "code/instruction");
instructions.push_back(DeserializeInstruction(instr));
}
// Create the VM function.
VMFunction vm_func = VMFunction(loaded_func.name,
loaded_func.params,
instructions,
loaded_func.register_file_size);
auto it = this->global_map.find(loaded_func.name);
CHECK(it != this->global_map.end());
CHECK_LE(it->second, this->global_map.size());
this->functions[it->second] = vm_func;
}
} }
TVM_REGISTER_GLOBAL("relay._vm._Serializer") TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* vm = dynamic_cast<VirtualMachine*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(vm) << "Virtual machine has not been defined yet." CHECK(exec);
<< "\n"; *rv = static_cast<int>(exec->global_map.size());
*rv = CreateSerializer(vm); });
TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
int idx = args[1];
std::vector<std::pair<std::string, Index> > globals(exec->global_map.begin(),
exec->global_map.end());
auto comp = [](const std::pair<std::string, Index>& a,
const std::pair<std::string, Index>& b) {
return a.second < b.second;
};
std::sort(globals.begin(), globals.end(), comp);
CHECK_LT(idx, globals.size());
*rv = globals[idx].first;
});
TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
*rv = static_cast<int>(exec->primitive_map.size());
});
TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec);
int idx = args[1];
CHECK_GE(idx, 0);
CHECK_LT(idx, exec->primitive_map.size());
for (const auto& it : exec->primitive_map) {
if (idx == static_cast<int>(it.second)) {
*rv = it.first;
break;
}
}
});
TVM_REGISTER_GLOBAL("relay._vm.Load_Executable")
.set_body_typed<runtime::Module(std::string, runtime::Module)>([](
std::string code,
runtime::Module lib) {
return Executable::Load(code, lib);
}); });
} // namespace vm } // namespace vm
} // namespace relay } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction( ...@@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction(
} }
} }
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) { void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
VirtualMachine::Init(ctxs); VirtualMachine::LoadExecutable(exec);
for (auto kv : primitive_map) { CHECK(this->exec);
for (auto kv : this->exec->primitive_map) {
packed_index_map[kv.second] = kv.first; packed_index_map[kv.second] = kv.first;
op_invokes[kv.second] = 0; op_invokes[kv.second] = 0;
} }
} }
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
VirtualMachine::Init(ctxs);
}
void VirtualMachineDebug::InvokePacked(Index packed_index, void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count, const PackedFunc& func, Index arg_count,
Index output_size, Index output_size,
const std::vector<ObjectRef>& args) { const std::vector<ObjectRef>& args) {
auto ctx = VirtualMachine::GetParamsContext(); CHECK(this->exec);
auto ctx = this->GetParamsContext();
// warmup // warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args); args);
...@@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, ...@@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
op_invokes[packed_index] += 1; op_invokes[packed_index] += 1;
} }
runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>();
vm->LoadExecutable(exec);
return runtime::Module(vm);
}
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec) << "Virtual machine has not been defined yet."
<< "\n";
*rv = CreateVirtualMachineDebug(exec);
});
} // namespace vm } // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine { ...@@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine {
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final; Index output_size, const std::vector<ObjectRef>& args) final;
void LoadExecutable(const Executable* exec);
~VirtualMachineDebug() {} ~VirtualMachineDebug() {}
private: private:
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serialize_util.h * \file src/runtime/vm/serialize_util.h
* \brief Definitions of helpers for serializing and deserializing a Relay VM. * \brief Definitions of helpers for serializing and deserializing a Relay VM.
*/ */
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ #ifndef TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ #define TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
#include <dmlc/common.h> #include <dmlc/common.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include <vector> #include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace runtime {
namespace vm { namespace vm {
/*! \brief The magic number for the serialized VM bytecode file */ /*! \brief The magic number for the serialized VM bytecode file */
...@@ -158,7 +158,7 @@ struct VMInstructionSerializer { ...@@ -158,7 +158,7 @@ struct VMInstructionSerializer {
}; };
} // namespace vm } // namespace vm
} // namespace relay } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ #endif // TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
...@@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") { if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(exec) << "The executable is not created yet.";
std::string func_name = args[0]; std::string func_name = args[0];
auto gvit = this->global_map.find(func_name); auto gvit = exec->global_map.find(func_name);
CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second; auto func_index = gvit->second;
const auto& vm_func = this->functions[func_index]; const auto& vm_func = exec->functions[func_index];
const auto& param_names = vm_func.params; const auto& param_names = vm_func.params;
auto ctx = this->GetParamsContext(); auto ctx = this->GetParamsContext();
...@@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} }
this->Init(contexts); this->Init(contexts);
}); });
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0].operator std::string());
});
} else { } else {
LOG(FATAL) << "Unknown packed function: " << name; LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
...@@ -628,43 +625,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -628,43 +625,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} }
TVMContext VirtualMachine::GetParamsContext() const { TVMContext VirtualMachine::GetParamsContext() const {
CHECK(!ctxs.empty()) << "Context has not been initialized yet."
<< "\n";
// Use the fallback device if no device index is available. // Use the fallback device if no device index is available.
int fallback_device_type = static_cast<int>(ctxs[0].device_type); int fallback_device_type = static_cast<int>(ctxs[0].device_type);
// TODO(wweic): For heterogeneous execution, get device information from byte // TODO(wweic): For heterogeneous execution, get device information from byte
const auto& cit = const auto& cit =
std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
return fallback_device_type == static_cast<int>(c.device_type); return fallback_device_type == static_cast<int>(c.device_type);
}); });
return (cit == ctxs.end() ? ctxs[0] : *cit); return (cit == ctxs.end() ? ctxs[0] : *cit);
} }
void VirtualMachine::LoadParams(const std::string& params) {
dmlc::MemoryStringStream mss(const_cast<std::string*>(&params));
dmlc::Stream* strm = &mss;
uint64_t header, reserved;
CHECK(strm->Read(&header)) << "Invalid parameter file";
CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file";
CHECK(strm->Read(&reserved)) << "Invalid parameter file";
std::vector<std::string> names;
CHECK(strm->Read(&names)) << "Invalid parameter file";
uint64_t sz;
strm->Read(&sz);
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) << "Invalid parameter file";
auto ctx = GetParamsContext();
for (size_t i = 0; i < size; i++) {
NDArray arr;
CHECK(arr.Load(strm)) << "Invalid parameter file";
ObjectRef obj = Tensor(arr);
auto copy = CopyTo(obj, ctx);
params_.emplace(std::make_pair(names[i], copy));
}
}
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
frames.push_back(frame); frames.push_back(frame);
...@@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec ...@@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec
InvokeGlobal(func, args); InvokeGlobal(func, args);
RunLoop(); RunLoop();
// TODO(wweic) ctx could be obtained from the ctxs list.
auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]); auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
return return_register; return return_register;
} }
ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) { ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
auto func_index = this->global_map[name]; CHECK(exec) << "The executable has not been created yet.";
auto func_index = exec->global_map.at(name);
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args); return Invoke(exec->functions[func_index], args);
} }
void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
...@@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, ...@@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
} }
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { void VirtualMachine::LoadExecutable(const Executable* exec) {
this->ctxs = ctxs; CHECK(exec) << "The executable is not created yet.";
this->exec = exec;
runtime::Module lib = this->exec->lib;
// Get the list of packed functions. // Get the list of packed functions.
CHECK(primitive_map.empty() || lib.operator->()) CHECK(exec->primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions" << "runtime module should have been built for primitive functions"
<< "\n"; << "\n";
for (const auto& it : primitive_map) { for (const auto& it : this->exec->primitive_map) {
const auto& packed_name = it.first; const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second); auto packed_index = static_cast<size_t>(it.second);
if (packed_funcs.size() <= packed_index) { if (packed_funcs.size() <= packed_index) {
...@@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { ...@@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
} }
} }
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
this->ctxs = ctxs;
}
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
frames.back().register_file[r] = val; frames.back().register_file[r] = val;
} }
...@@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const { ...@@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
void VirtualMachine::RunLoop() { void VirtualMachine::RunLoop() {
CHECK(this->code); CHECK(this->code);
CHECK(this->exec);
this->pc = 0; this->pc = 0;
Index frame_start = frames.size(); Index frame_start = frames.size();
while (true) { while (true) {
...@@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() { ...@@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() {
throw std::runtime_error("VM encountered fatal error"); throw std::runtime_error("VM encountered fatal error");
} }
case Opcode::LoadConst: { case Opcode::LoadConst: {
auto constant_obj = this->constants[instr.const_index]; auto constant_obj = exec->constants[instr.const_index];
// TODO(wweic) ctx could be obtained from the ctxs list.
auto device_obj = CopyTo(constant_obj, ctxs[0]); auto device_obj = CopyTo(constant_obj, ctxs[0]);
WriteRegister(instr.dst, device_obj); WriteRegister(instr.dst, device_obj);
pc++; pc++;
...@@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() { ...@@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_args; ++i) { for (Index i = 0; i < instr.num_args; ++i) {
args.push_back(ReadRegister(instr.invoke_args_registers[i])); args.push_back(ReadRegister(instr.invoke_args_registers[i]));
} }
InvokeGlobal(this->functions[instr.func_index], args); InvokeGlobal(exec->functions[instr.func_index], args);
frames.back().caller_return_register = instr.dst; frames.back().caller_return_register = instr.dst;
goto main_loop; goto main_loop;
} }
...@@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() { ...@@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_closure_args; ++i) { for (Index i = 0; i < instr.num_closure_args; ++i) {
args.push_back(ReadRegister(instr.closure_args[i])); args.push_back(ReadRegister(instr.closure_args[i]));
} }
InvokeGlobal(this->functions[closure->func_index], args); InvokeGlobal(exec->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst; frames.back().caller_return_register = instr.dst;
goto main_loop; goto main_loop;
} }
...@@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() { ...@@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() {
for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
shape[i] = instr.alloc_tensor.shape[i]; shape[i] = instr.alloc_tensor.shape[i];
} }
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]);
auto obj = Tensor(data); auto obj = Tensor(data);
...@@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() { ...@@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() {
auto num_dims = shape_tensor->shape[0]; auto num_dims = shape_tensor->shape[0];
auto shape = std::vector<int64_t>(shape_tensor->shape[0]); auto shape = std::vector<int64_t>(shape_tensor->shape[0]);
shape.assign(dims, dims + num_dims); shape.assign(dims, dims + num_dims);
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]);
auto obj = Tensor(data); auto obj = Tensor(data);
...@@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() { ...@@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() {
} }
} }
runtime::Module CreateVirtualMachine(const Executable* exec) {
std::shared_ptr<VirtualMachine> vm = std::make_shared<VirtualMachine>();
vm->LoadExecutable(exec);
return runtime::Module(vm);
}
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec) << "The virtual machine executable has not been defined yet."
<< "\n";
*rv = CreateVirtualMachine(exec);
});
} // namespace vm } // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
vm = relay.vm.compile(mod, target) exe = relay.vm.compile(mod, target)
vm.init(tvm.cpu()) vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
return vm.invoke("main", *args) return vm.invoke("main", *args)
else: else:
assert isinstance(f, relay.Module), "expected expression or module" assert isinstance(f, relay.Module), "expected expression or module"
mod = f mod = f
vm = relay.vm.compile(mod, target) exe = relay.vm.compile(mod, target)
vm.init(tvm.cpu()) vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
ret = vm.invoke("main", *args) ret = vm.invoke("main", *args)
return ret return ret
...@@ -573,25 +575,6 @@ def test_add_op_broadcast(): ...@@ -573,25 +575,6 @@ def test_add_op_broadcast():
mod["main"] = func mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod) check_result([x_data, y_data], x_data + y_data, mod=mod)
def test_set_params():
mod = relay.Module()
x = relay.var('x', shape=(10, 5))
w = relay.var('w', shape=(6, 5))
b = relay.var('b', shape=(6,))
y = relay.nn.bias_add(relay.nn.dense(x, w), b)
mod["main"] = relay.Function([x, w, b], y)
vm = relay.vm.compile(mod, 'llvm')
vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32')
w_np = np.random.uniform(size=(6, 5)).astype('float32')
b_np = np.random.uniform(size=(6,)).astype('float32')
ref_np = np.dot(x_np, w_np.T) + b_np
params = {'w': w_np}
vm.load_params(params)
out = vm.run(x_np, b_np)
tvm.testing.assert_allclose(out.asnumpy(), ref_np)
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
...@@ -626,4 +609,3 @@ if __name__ == "__main__": ...@@ -626,4 +609,3 @@ if __name__ == "__main__":
test_add_op_scalar() test_add_op_scalar()
test_add_op_tensor() test_add_op_tensor()
test_add_op_broadcast() test_add_op_broadcast()
test_set_params()
...@@ -22,29 +22,25 @@ import tvm ...@@ -22,29 +22,25 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.module import Module as rly_module from tvm.relay.module import Module as rly_module
from tvm.relay import vm as _vm from tvm.relay import vm as _vm
from tvm.relay import serializer, deserializer
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.contrib import util from tvm.contrib import util
from tvm.relay import testing from tvm.relay import testing
def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None): def create_exec(f, target="llvm", params=None):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
vm = _vm.compile(mod, target=target, params=params) executable = _vm.compile(mod, target=target, params=params)
vm.init(ctx) return executable
return vm
else: else:
assert isinstance(f, relay.Module), "expected mod as relay.Module" assert isinstance(f, relay.Module), "expected mod as relay.Module"
vm = _vm.compile(f, target=target, params=params) executable = _vm.compile(f, target=target, params=params)
vm.init(ctx) return executable
return vm
def veval(vm, *args, ctx=tvm.cpu()): def veval(vm, *args, ctx=tvm.cpu()):
assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine" assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine"
vm.init(ctx)
ret = vm.run(*args) ret = vm.run(*args)
return ret return ret
...@@ -59,13 +55,11 @@ def run_network(mod, ...@@ -59,13 +55,11 @@ def run_network(mod,
return result.asnumpy().astype(dtype) return result.asnumpy().astype(dtype)
def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
vm = create_vm(mod, ctx, target, params=params) exe = create_exec(mod, target, params=params)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize()
des_vm.init(ctx) des_vm.init(ctx)
des_vm.load_params(params)
result = des_vm.run(data) result = des_vm.run(data)
return result.asnumpy().astype(dtype) return result.asnumpy().astype(dtype)
...@@ -99,26 +93,25 @@ def test_serializer(): ...@@ -99,26 +93,25 @@ def test_serializer():
main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1))
mod["main"] = main mod["main"] = main
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm)
glbs = ser.globals glbs = exe.globals
assert len(glbs) == 3 assert len(glbs) == 3
assert "f1" in glbs assert "f1" in glbs
assert "f2" in glbs assert "f2" in glbs
assert "main" in glbs assert "main" in glbs
prim_ops = ser.primitive_ops prim_ops = exe.primitive_ops
assert any(item.startswith('fused_add') for item in prim_ops) assert any(item.startswith('fused_add') for item in prim_ops)
assert any(item.startswith('fused_subtract') for item in prim_ops) assert any(item.startswith('fused_subtract') for item in prim_ops)
assert any(item.startswith('fused_multiply') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops)
code = ser.bytecode code = exe.bytecode
assert "main 5 2 5" in code assert "main 5 2 5" in code
assert "f1 2 1 3" in code assert "f1 2 1 3" in code
assert "f2 2 1 3" in code assert "f2 2 1 3" in code
code, lib = ser.serialize() code, lib = exe.save()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
assert isinstance(lib, tvm.module.Module) assert isinstance(lib, tvm.module.Module)
...@@ -129,24 +122,24 @@ def test_save_load(): ...@@ -129,24 +122,24 @@ def test_save_load():
x_data = np.random.rand(10, 10).astype('float32') x_data = np.random.rand(10, 10).astype('float32')
# serialize. # serialize.
vm = create_vm(f) vm = create_exec(f)
ser = serializer.Serializer(vm) code, lib = vm.save()
code, lib = ser.serialize()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
# save and load the code and lib file. # save and load the code and lib file.
tmp = util.tempdir() tmp = util.tempdir()
path_lib = tmp.relpath("lib.so") path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib) lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo: with open(tmp.relpath("code.ro"), "wb") as fo:
fo.write(code) fo.write(code)
loaded_lib = tvm.module.load(path_lib) loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
# deserialize. # deserialize.
deser = deserializer.Deserializer(loaded_code, loaded_lib) des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib)
des_vm = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
res = veval(des_vm, x_data) res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
...@@ -156,12 +149,12 @@ def test_const(): ...@@ -156,12 +149,12 @@ def test_const():
c = relay.const(1.0, "float32") c = relay.const(1.0, "float32")
x = relay.var('x', shape=(10, 10), dtype='float32') x = relay.var('x', shape=(10, 10), dtype='float32')
f = relay.Function([x], x + c) f = relay.Function([x], x + c)
vm = create_vm(f) exe = create_exec(f)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
deser = deserializer.Deserializer(code, lib) des_exec = _vm.Executable.load_exec(code, lib)
des_vm = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
x_data = np.random.rand(10, 10).astype('float32') x_data = np.random.rand(10, 10).astype('float32')
res = veval(des_vm, x_data) res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
...@@ -177,11 +170,11 @@ def test_if(): ...@@ -177,11 +170,11 @@ def test_if():
x_data = np.random.rand(10, 10).astype('float32') x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32')
vm = create_vm(f) exe = create_exec(f)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
# same # same
res = veval(des_vm, x_data, x_data) res = veval(des_vm, x_data, x_data)
...@@ -213,11 +206,11 @@ def test_loop(): ...@@ -213,11 +206,11 @@ def test_loop():
aarg = relay.var('accum', shape=[], dtype='int32') aarg = relay.var('accum', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
result = veval(des_vm, i_data, accum_data) result = veval(des_vm, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
...@@ -230,11 +223,11 @@ def test_tuple(): ...@@ -230,11 +223,11 @@ def test_tuple():
i_data = np.random.rand(41).astype('float32') i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32') j_data = np.random.rand(10).astype('float32')
vm = create_vm(f) exe = create_exec(f)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
result = veval(des_vm, (i_data, j_data)) result = veval(des_vm, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data) tvm.testing.assert_allclose(result.asnumpy(), j_data)
...@@ -251,11 +244,11 @@ def test_adt_list(): ...@@ -251,11 +244,11 @@ def test_adt_list():
f = relay.Function([], l321) f = relay.Function([], l321)
mod["main"] = f mod["main"] = f
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
result = veval(des_vm) result = veval(des_vm)
assert len(result) == 2 assert len(result) == 2
...@@ -297,11 +290,11 @@ def test_adt_compose(): ...@@ -297,11 +290,11 @@ def test_adt_compose():
f = relay.Function([y], add_two_body) f = relay.Function([y], add_two_body)
mod["main"] = f mod["main"] = f
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
x_data = np.array(np.random.rand()).astype('float32') x_data = np.array(np.random.rand()).astype('float32')
result = veval(des_vm, x_data) result = veval(des_vm, x_data)
...@@ -317,11 +310,11 @@ def test_closure(): ...@@ -317,11 +310,11 @@ def test_closure():
clo = ff(relay.const(1.0)) clo = ff(relay.const(1.0))
main = clo(relay.const(2.0)) main = clo(relay.const(2.0))
vm = create_vm(main) exe = create_exec(main)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
res = veval(des_vm) res = veval(des_vm)
tvm.testing.assert_allclose(res.asnumpy(), 3.0) tvm.testing.assert_allclose(res.asnumpy(), 3.0)
......
...@@ -26,9 +26,9 @@ def test_basic(): ...@@ -26,9 +26,9 @@ def test_basic():
mod, params = resnet.get_workload() mod, params = resnet.get_workload()
target = 'llvm' target = 'llvm'
ctx = tvm.cpu() ctx = tvm.cpu()
vm = relay.profiler_vm.compile(mod, target) exe = relay.profiler_vm.compile(mod, target, params=params)
vm = relay.profiler_vm.VirtualMachineProfiler(exe)
vm.init(ctx) vm.init(ctx)
vm.load_params(params)
data = np.random.rand(1, 3, 224, 224).astype('float32') data = np.random.rand(1, 3, 224, 224).astype('float32')
res = vm.invoke("main", [data]) res = vm.invoke("main", [data])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment