Commit 4052de6d by Zhi Committed by Haichen Shen

[relay][vm] Separate VM runtime with executable (#4100)

* [relay][vm] Separate VM runtime with executable

* Address comments

* move ctx back to vm

* make only vm related fields and methods protected

* integrate seriliaztion/deserialization to executable

* create stream
parent cf046972
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -430,15 +431,184 @@ struct VMFrame { ...@@ -430,15 +431,184 @@ struct VMFrame {
caller_return_register(0) {} caller_return_register(0) {}
}; };
/*! \brief The executable emitted by the VM compiler.
*
* The executable contains information (e.g. data in different memory regions)
* to run in a virtual machine.
*
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*/
class Executable : public ModuleNode {
public:
/*!
* \brief Get a PackedFunc from an executable module.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Save();
/*!
* \brief Load the saved VM executable.
*
* \param code The bytecode in string.
* \param lib The compiled runtime library.
*
* \return exe The constructed executable.
*/
static runtime::Module Load(const std::string& code, const runtime::Module lib);
/*!
* \brief Get the serialized form of the `functions`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
std::string GetBytecode() const;
/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
*/
std::string Stats() const;
/*! \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
*/
runtime::Module GetLib() const { return lib; }
virtual ~Executable() {}
const char* type_key() const final {
return "VMExecutable";
}
/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
private:
/*!
* \brief Save the globals.
*
* \param strm The input stream.
*/
void SaveGlobalSection(dmlc::Stream* strm);
/*!
* \brief Save the constant pool.
*
* \param strm The input stream.
*/
void SaveConstantSection(dmlc::Stream* strm);
/*!
* \brief Save primitive op names.
*
* \param strm The input stream.
*/
void SavePrimitiveOpNames(dmlc::Stream* strm);
/*!
* \brief Save the vm functions.
*
* \param strm The input stream.
*/
void SaveCodeSection(dmlc::Stream* strm);
/*!
* \brief Load the globals.
*
* \param strm The input stream.
*/
void LoadGlobalSection(dmlc::Stream* strm);
/*!
* \brief Load the constant pool.
*
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);
/*!
* \brief Load primitive op names.
*
* \param strm The input stream.
*/
void LoadPrimitiveOpNames(dmlc::Stream* strm);
/*!
* \brief Load the vm functions.
*
* \param strm The input stream.
*/
void LoadCodeSection(dmlc::Stream* strm);
/*! \brief The serialized bytecode. */
std::string code_;
};
/*! \brief The virtual machine. /*! \brief The virtual machine.
* *
* The virtual machine contains all the current execution state, * The virtual machine contains all the current execution state,
* as well as the global view of functions, the global constant * as well as the executable.
* table, the compiled operators.
* *
* The goal is to have a single self-contained object, * The goal is to have a single self-contained object,
* enabling one to easily pass around VMs, execute them on * enabling one to easily pass around VMs, execute them on
* multiple threads, or serialized them to disk or over the * multiple threads, or serialize them to disk or over the
* wire. * wire.
*/ */
class VirtualMachine : public runtime::ModuleNode { class VirtualMachine : public runtime::ModuleNode {
...@@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode {
return "VirtualMachine"; return "VirtualMachine";
} }
/*! \brief The runtime module/library that contains generated code. */ VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}
runtime::Module lib;
/*! \brief load the executable for the virtual machine.
* \param exec The executable.
*/
void LoadExecutable(const Executable* exec);
protected:
/*! \brief The virtual machine's packed function table. */ /*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs; std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
/*! \brief The current stack of call frames. */ /*! \brief The current stack of call frames. */
std::vector<VMFrame> frames; std::vector<VMFrame> frames;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief The fuction table index of the current function. */ /*! \brief The fuction table index of the current function. */
Index func_index; Index func_index;
/*! \brief The current pointer to the code section. */ /*! \brief The current pointer to the code section. */
...@@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The special return register. */ /*! \brief The special return register. */
ObjectRef return_register; ObjectRef return_register;
/*! \brief The executable the VM will operate on. */
const Executable* exec;
/*! \brief The set of TVM contexts the VM is currently executing on. */ /*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs; std::vector<TVMContext> ctxs;
...@@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/ */
ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args); ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
/*! \brief Initialize the virtual machine for a set of contexts. /*! \brief Initialize the virtual machine for a set of contexts.
* \param contexts The set of TVM contexts. * \param contexts The set of TVM contexts.
*/ */
...@@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/ */
TVMContext GetParamsContext() const; TVMContext GetParamsContext() const;
/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
*/
void LoadParams(const std::string& params);
/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list.
*/
std::unordered_map<std::string, Index> primitive_map;
private: private:
/*! \brief Invoke a global setting up the VM state to execute. /*! \brief Invoke a global setting up the VM state to execute.
* *
......
...@@ -37,8 +37,6 @@ from . import param_dict ...@@ -37,8 +37,6 @@ from . import param_dict
from . import feature from . import feature
from .backend import vm from .backend import vm
from .backend import profiler_vm from .backend import profiler_vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj from .backend import vmobj
# Root operators # Root operators
......
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine deserializer.
Python interface for deserializing a Relay VM.
"""
from tvm import module
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm
from . import vm as rly_vm
def _create_deserializer(code, lib):
"""Create a deserializer object.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
Returns
-------
ret : Deserializer
The created virtual machine deserializer.
"""
if isinstance(code, (bytes, str)):
code = bytearray(code)
elif not isinstance(code, (bytearray, TVMByteArray)):
raise TypeError("vm is expected to be the type of bytearray or " +
"TVMByteArray, but received {}".format(type(code)))
if not isinstance(lib, module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib)))
return _vm._Deserializer(code, lib)
class Deserializer:
"""Relay VM deserializer.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
"""
def __init__(self, code, lib):
self.mod = _create_deserializer(code, lib)
self._deserialize = self.mod["deserialize"]
def deserialize(self):
"""Deserialize the serialized bytecode into a Relay VM.
Returns
-------
ret : VirtualMachine
The deserialized Relay VM.
"""
return rly_vm.VirtualMachine(self._deserialize())
...@@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None): ...@@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None):
Returns Returns
------- -------
vm : VirtualMachineProfiler exec : Executable
The profile VM runtime. The executable with profiling code.
""" """
compiler = VMCompilerProfiler() compiler = VMCompilerProfiler()
target = compiler.update_target(target) target = compiler.update_target(target)
...@@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None): ...@@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None):
tophub_context = compiler.tophub_context(target) tophub_context = compiler.tophub_context(target)
with tophub_context: with tophub_context:
compiler._compile(mod, target, target_host) compiler._compile(mod, target, target_host)
return VirtualMachineProfiler(compiler._get_vm()) return vm.Executable(compiler._get_exec())
class VMCompilerProfiler(vm.VMCompiler): class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime.""" """Build Relay module to run on VM runtime."""
...@@ -68,13 +68,17 @@ class VMCompilerProfiler(vm.VMCompiler): ...@@ -68,13 +68,17 @@ class VMCompilerProfiler(vm.VMCompiler):
super().__init__() super().__init__()
self.mod = _vm._VMCompilerProfiler() self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"] self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"] self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"] self._set_params_func = self.mod["set_params"]
class VirtualMachineProfiler(vm.VirtualMachine): class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime.""" """Relay profile VM runtime."""
def __init__(self, mod): def __init__(self, mod):
super().__init__(mod) super().__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"] self._get_stat = self.mod["get_stat"]
def get_stat(self): def get_stat(self):
......
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine serializer.
Python interface for serializing a Relay VM.
"""
import tvm
from . import _vm
from . import vm as rly_vm
def _create_serializer(vm):
"""Create a VM serializer.
Parameters
----------
vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`]
The virtual machine to be serialized.
Returns
-------
ret : Serializer
The created virtual machine serializer.
"""
if isinstance(vm, rly_vm.VirtualMachine):
vm = vm.module
elif not isinstance(vm, tvm.module.Module):
raise TypeError("vm is expected to be the type of VirtualMachine or " +
"tvm.Module, but received {}".format(type(vm)))
return _vm._Serializer(vm)
class Serializer:
"""Relay VM serializer."""
def __init__(self, vm):
self.mod = _create_serializer(vm)
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_globals = self.mod["get_globals"]
self._get_stats = self.mod["get_stats"]
self._get_primitive_ops = self.mod["get_primitive_ops"]
self._serialize = self.mod["serialize"]
@property
def stats(self):
"""Get the statistics of the Relay VM.
Returns
-------
ret : String
The serialized statistic information.
"""
return self._get_stats()
@property
def primitive_ops(self):
"""Get the name of the primitive ops that are executed in the VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The list of primitive ops.
"""
return [prim_op.value for prim_op in self._get_primitive_ops()]
@property
def bytecode(self):
"""Get the bytecode of the Relay VM.
Returns
-------
ret : String
The serialized bytecode.
Notes
-----
The bytecode is in the following format:
func_name reg_file_size num_instructions
param1 param2 ... paramM
instruction1
instruction2
...
instructionN
Each instruction is printed in the following format:
hash opcode field1 ... fieldX # The text format.
The part starting from # is only used for visualization and debugging.
The real serialized code doesn't contain it, therefore the deserializer
doesn't need to deal with it as well.
"""
return self._get_bytecode()
@property
def globals(self):
"""Get the globals used by the Relay VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The serialized globals.
"""
return [glb.value for glb in self._get_globals()]
def serialize(self):
"""Serialize the Relay VM.
Returns
-------
code : bytearray
The binary blob representing a serialized Relay VM. It can then be
saved to disk and later deserialized into a new VM.
lib : :py:class:`~tvm.module.Module`
The runtime module that contains the generated code. It is
basically a library that is composed of hardware dependent code.
Notes
-----
The returned code is organized with the following sections in order.
- Global section. This section contains the globals used by the
virtual machine.
- Constant section. This section is used to store the constant pool of
a virtual machine.
- Primitive name section. This section is introduced to accommodate
the list of primitive operator names that will be invoked by the
virtual machine.
- Code section. The VM functions, including bytecode, are sitting in
this section.
Examples
--------
.. code-block:: python
import numpy as np
import tvm
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
mod = relay.Module({"main": f})
# create a Relay VM.
ctx = tvm.cpu()
target = "llvm"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(ctx)
# serialize.
ser = relay.serializer.Serializer(vm)
code, lib = ser.serialize()
# save and load the code and lib file.
tmp = tvm.contrib.util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())
# deserialize.
deser = relay.deserializer.Deserializer(loaded_code, loaded_lib)
des_vm = deser.deserialize()
# execute the deserialized vm.
des_vm.init(ctx)
x_data = np.random.rand(10, 10).astype('float32')
res = des_vm.run(x_data)
print(res.asnumpy())
"""
return self._serialize(), self._get_lib()
...@@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, ...@@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
Module mod = args[0]; Module mod = args[0];
this->Compile(mod, args[1], args[2]); this->Compile(mod, args[1], args[2]);
}); });
} else if (name == "get_vm") { } else if (name == "get_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_); *rv = runtime::Module(exec_);
}); });
} else if (name == "set_params") { } else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
...@@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod, ...@@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod,
// Next we get ready by allocating space for // Next we get ready by allocating space for
// the global state. // the global state.
vm_->functions.resize(context_.module->functions.size()); exec_->functions.resize(context_.module->functions.size());
for (auto named_func : context_.module->functions) { for (auto named_func : context_.module->functions) {
auto gvar = named_func.first; auto gvar = named_func.first;
...@@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod, ...@@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod,
auto vm_func = func_compiler.Compile(gvar, func); auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar); size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < vm_->functions.size()); CHECK(func_index < exec_->functions.size());
vm_->functions[func_index] = vm_func; exec_->functions[func_index] = vm_func;
} }
#if USE_RELAY_DEBUG #if USE_RELAY_DEBUG
for (auto vm_func : vm_->functions) { for (auto vm_func : exec_->functions) {
DLOG(INFO) << vm_func << "-------------"; DLOG(INFO) << vm_func << "-------------";
} }
#endif // USE_RELAY_DEBUG #endif // USE_RELAY_DEBUG
// populate constants // populate constants
for (auto data : context_.constants) { for (auto data : context_.constants) {
vm_->constants.push_back(runtime::vm::Tensor(data)); exec_->constants.push_back(runtime::vm::Tensor(data));
} }
LibraryCodegen(); LibraryCodegen();
for (auto gv : context_.global_map) { for (auto gv : context_.global_map) {
vm_->global_map.insert({gv.first->name_hint, gv.second}); exec_->global_map.insert({gv.first->name_hint, gv.second});
} }
} }
...@@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() { ...@@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() {
// therefore target won't be used in the build function // therefore target won't be used in the build function
runtime::Module mod = (*f)(funcs, Target(), target_host_); runtime::Module mod = (*f)(funcs, Target(), target_host_);
CHECK(mod.operator->()); CHECK(mod.operator->());
vm_->lib = mod; exec_->lib = mod;
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; LOG(FATAL) << "relay.backend.build is not registered";
} }
size_t primitive_index = 0; size_t primitive_index = 0;
for (auto cfunc : cached_funcs) { for (auto cfunc : cached_funcs) {
vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
} }
} }
......
...@@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode {
return "VMCompiler"; return "VMCompiler";
} }
std::shared_ptr<VirtualMachine> GetVirtualMachine() const { void InitVM() {
return vm_; exec_ = std::make_shared<Executable>();
}
virtual void InitVM() {
vm_ = std::make_shared<VirtualMachine>();
} }
/*! /*!
...@@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode {
tvm::Target target_host_; tvm::Target target_host_;
/*! \brief Global shared meta data */ /*! \brief Global shared meta data */
VMCompilerContext context_; VMCompilerContext context_;
/*! \brief Compiled virtual machine. */ /*! \brief Compiled executable. */
std::shared_ptr<VirtualMachine> vm_; std::shared_ptr<Executable> exec_;
/*! \brief parameters */ /*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_; std::unordered_map<std::string, runtime::NDArray> params_;
}; };
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/deserializer.h
* \brief Define a deserializer for the serialized Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#include <dmlc/memory_io.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime::vm;
namespace runtime = tvm::runtime;
class Deserializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the deserializer for creating a virtual machine object.
*
* \param code The serialized code.
* \param lib The serialized runtime module/library that contains the
* hardware dependent code.
*/
inline void Init(const std::string& code, const runtime::Module& lib);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Deserializer"; }
/*! \brief Deserialize the serialized VM. */
void Deserialize();
virtual ~Deserializer() { delete strm_; }
private:
/*! \brief Deserialize the globals in `vm_`. */
void DeserializeGlobalSection();
/*! \brief Deserialize the constant pool in `vm_`. */
void DeserializeConstantSection();
/*! \brief Deserialize primitive op names in `vm_`. */
void DeserializePrimitiveOpNames();
/*! \brief Deserialize the vm functions in `vm_`. */
void DeserializeCodeSection();
/*! \brief The code to be serialized. */
std::string code_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The VM to be created. */
std::shared_ptr<VirtualMachine> vm_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
...@@ -33,7 +33,6 @@ namespace vm { ...@@ -33,7 +33,6 @@ namespace vm {
class VMCompilerDebug : public VMCompiler { class VMCompilerDebug : public VMCompiler {
public: public:
VMCompilerDebug() {} VMCompilerDebug() {}
void InitVM() override { vm_ = std::make_shared<VirtualMachineDebug>(); }
virtual ~VMCompilerDebug() {} virtual ~VMCompilerDebug() {}
}; };
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serializer.h
* \brief Define a serializer for the Relay VM.
*
* The following components of a Relay VM will be serialized:
* - The `constants`, e.g., the constant pool, that contains the
* constants used in a Relay program.
* - The `packed_funcs` that essentially contains the generated code for
* a specific target. We return it as a runtime module that can be exported as
* a library file (e.g., .so, .o, or .tar).
* - The `global_map` that contains the globals.
* - The `primitive_map` that contains the name of individual primitive operators.
* - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of
* a list of instructions/bytecode.
*
* Note that only the library is returned as a separate module. All othere parts
* are stored in a single serialized code that is organized with the following
* sections in order.
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*
* The code section is again organized as follows for each VM function:
* func_name, register_file_size, num_instructions (N)
* param1, param2, ..., paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/ir.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
/*!
* \brief The Relay VM serializer.
*/
class Serializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the serializer for a virtual machine.
*
* \param vm The Relay virtual machine.
*/
inline void Init(const VirtualMachine* vm);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Serializer"; }
/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
*/
std::string Stats() const;
/*!
* \brief Serialize the `vm_` into global section, constant section, and code
* section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Serialize();
/*!
* \brief Get a list of the globals used by the `_vm`.
*
* \return The global map in the form a list.
*/
tvm::Array<tvm::Expr> GetGlobals() const;
/*!
* \brief Get the primitive operators that are contained in the Relay VM.
*
* \return The list of primitve operators.
*/
tvm::Array<tvm::Expr> GetPrimitiveOps() const;
/*!
* \brief Get the serialized form of the `functions` in `vm_`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
std::string GetBytecode() const;
/*! \brief Get the `lib` module in vm_. Serialization of `runtime::module`
* has already been supported by TVM. Therefore, we only return the runtime
* module and let users have the flexibility to call `export_library` from
* the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
*/
inline runtime::Module GetLib() const;
virtual ~Serializer() { delete strm_; }
private:
/*! \brief Serialize the globals in vm_. */
void SerializeGlobalSection();
/*! \brief Serialize the constant pool in vm_. */
void SerializeConstantSection();
/*! \brief Serialize primitive op names in vm_. */
void SerializePrimitiveOpNames();
/*! \brief Serialize the vm functions in vm_. */
void SerializeCodeSection();
/*! \brief The Relay virtual machine for to be serialized. */
const VirtualMachine* vm_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The serialized code. */
std::string code_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_
...@@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction( ...@@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction(
} }
} }
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) { void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
VirtualMachine::Init(ctxs); VirtualMachine::LoadExecutable(exec);
for (auto kv : primitive_map) { CHECK(this->exec);
for (auto kv : this->exec->primitive_map) {
packed_index_map[kv.second] = kv.first; packed_index_map[kv.second] = kv.first;
op_invokes[kv.second] = 0; op_invokes[kv.second] = 0;
} }
} }
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
VirtualMachine::Init(ctxs);
}
void VirtualMachineDebug::InvokePacked(Index packed_index, void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count, const PackedFunc& func, Index arg_count,
Index output_size, Index output_size,
const std::vector<ObjectRef>& args) { const std::vector<ObjectRef>& args) {
auto ctx = VirtualMachine::GetParamsContext(); CHECK(this->exec);
auto ctx = this->GetParamsContext();
// warmup // warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args); args);
...@@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, ...@@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
op_invokes[packed_index] += 1; op_invokes[packed_index] += 1;
} }
runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>();
vm->LoadExecutable(exec);
return runtime::Module(vm);
}
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec) << "Virtual machine has not been defined yet."
<< "\n";
*rv = CreateVirtualMachineDebug(exec);
});
} // namespace vm } // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine { ...@@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine {
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final; Index output_size, const std::vector<ObjectRef>& args) final;
void LoadExecutable(const Executable* exec);
~VirtualMachineDebug() {} ~VirtualMachineDebug() {}
private: private:
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serialize_util.h * \file src/runtime/vm/serialize_util.h
* \brief Definitions of helpers for serializing and deserializing a Relay VM. * \brief Definitions of helpers for serializing and deserializing a Relay VM.
*/ */
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ #ifndef TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ #define TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
#include <dmlc/common.h> #include <dmlc/common.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include <vector> #include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace runtime {
namespace vm { namespace vm {
/*! \brief The magic number for the serialized VM bytecode file */ /*! \brief The magic number for the serialized VM bytecode file */
...@@ -158,7 +158,7 @@ struct VMInstructionSerializer { ...@@ -158,7 +158,7 @@ struct VMInstructionSerializer {
}; };
} // namespace vm } // namespace vm
} // namespace relay } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ #endif // TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
...@@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") { if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(exec) << "The executable is not created yet.";
std::string func_name = args[0]; std::string func_name = args[0];
auto gvit = this->global_map.find(func_name); auto gvit = exec->global_map.find(func_name);
CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second; auto func_index = gvit->second;
const auto& vm_func = this->functions[func_index]; const auto& vm_func = exec->functions[func_index];
const auto& param_names = vm_func.params; const auto& param_names = vm_func.params;
auto ctx = this->GetParamsContext(); auto ctx = this->GetParamsContext();
...@@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} }
this->Init(contexts); this->Init(contexts);
}); });
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0].operator std::string());
});
} else { } else {
LOG(FATAL) << "Unknown packed function: " << name; LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
...@@ -628,43 +625,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -628,43 +625,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} }
TVMContext VirtualMachine::GetParamsContext() const { TVMContext VirtualMachine::GetParamsContext() const {
CHECK(!ctxs.empty()) << "Context has not been initialized yet."
<< "\n";
// Use the fallback device if no device index is available. // Use the fallback device if no device index is available.
int fallback_device_type = static_cast<int>(ctxs[0].device_type); int fallback_device_type = static_cast<int>(ctxs[0].device_type);
// TODO(wweic): For heterogeneous execution, get device information from byte // TODO(wweic): For heterogeneous execution, get device information from byte
const auto& cit = const auto& cit =
std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
return fallback_device_type == static_cast<int>(c.device_type); return fallback_device_type == static_cast<int>(c.device_type);
}); });
return (cit == ctxs.end() ? ctxs[0] : *cit); return (cit == ctxs.end() ? ctxs[0] : *cit);
} }
void VirtualMachine::LoadParams(const std::string& params) {
dmlc::MemoryStringStream mss(const_cast<std::string*>(&params));
dmlc::Stream* strm = &mss;
uint64_t header, reserved;
CHECK(strm->Read(&header)) << "Invalid parameter file";
CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file";
CHECK(strm->Read(&reserved)) << "Invalid parameter file";
std::vector<std::string> names;
CHECK(strm->Read(&names)) << "Invalid parameter file";
uint64_t sz;
strm->Read(&sz);
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) << "Invalid parameter file";
auto ctx = GetParamsContext();
for (size_t i = 0; i < size; i++) {
NDArray arr;
CHECK(arr.Load(strm)) << "Invalid parameter file";
ObjectRef obj = Tensor(arr);
auto copy = CopyTo(obj, ctx);
params_.emplace(std::make_pair(names[i], copy));
}
}
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
frames.push_back(frame); frames.push_back(frame);
...@@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec ...@@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec
InvokeGlobal(func, args); InvokeGlobal(func, args);
RunLoop(); RunLoop();
// TODO(wweic) ctx could be obtained from the ctxs list.
auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]); auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
return return_register; return return_register;
} }
ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) { ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
auto func_index = this->global_map[name]; CHECK(exec) << "The executable has not been created yet.";
auto func_index = exec->global_map.at(name);
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args); return Invoke(exec->functions[func_index], args);
} }
void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
...@@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, ...@@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
} }
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { void VirtualMachine::LoadExecutable(const Executable* exec) {
this->ctxs = ctxs; CHECK(exec) << "The executable is not created yet.";
this->exec = exec;
runtime::Module lib = this->exec->lib;
// Get the list of packed functions. // Get the list of packed functions.
CHECK(primitive_map.empty() || lib.operator->()) CHECK(exec->primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions" << "runtime module should have been built for primitive functions"
<< "\n"; << "\n";
for (const auto& it : primitive_map) { for (const auto& it : this->exec->primitive_map) {
const auto& packed_name = it.first; const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second); auto packed_index = static_cast<size_t>(it.second);
if (packed_funcs.size() <= packed_index) { if (packed_funcs.size() <= packed_index) {
...@@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { ...@@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
} }
} }
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
this->ctxs = ctxs;
}
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
frames.back().register_file[r] = val; frames.back().register_file[r] = val;
} }
...@@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const { ...@@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
void VirtualMachine::RunLoop() { void VirtualMachine::RunLoop() {
CHECK(this->code); CHECK(this->code);
CHECK(this->exec);
this->pc = 0; this->pc = 0;
Index frame_start = frames.size(); Index frame_start = frames.size();
while (true) { while (true) {
...@@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() { ...@@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() {
throw std::runtime_error("VM encountered fatal error"); throw std::runtime_error("VM encountered fatal error");
} }
case Opcode::LoadConst: { case Opcode::LoadConst: {
auto constant_obj = this->constants[instr.const_index]; auto constant_obj = exec->constants[instr.const_index];
// TODO(wweic) ctx could be obtained from the ctxs list.
auto device_obj = CopyTo(constant_obj, ctxs[0]); auto device_obj = CopyTo(constant_obj, ctxs[0]);
WriteRegister(instr.dst, device_obj); WriteRegister(instr.dst, device_obj);
pc++; pc++;
...@@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() { ...@@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_args; ++i) { for (Index i = 0; i < instr.num_args; ++i) {
args.push_back(ReadRegister(instr.invoke_args_registers[i])); args.push_back(ReadRegister(instr.invoke_args_registers[i]));
} }
InvokeGlobal(this->functions[instr.func_index], args); InvokeGlobal(exec->functions[instr.func_index], args);
frames.back().caller_return_register = instr.dst; frames.back().caller_return_register = instr.dst;
goto main_loop; goto main_loop;
} }
...@@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() { ...@@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_closure_args; ++i) { for (Index i = 0; i < instr.num_closure_args; ++i) {
args.push_back(ReadRegister(instr.closure_args[i])); args.push_back(ReadRegister(instr.closure_args[i]));
} }
InvokeGlobal(this->functions[closure->func_index], args); InvokeGlobal(exec->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst; frames.back().caller_return_register = instr.dst;
goto main_loop; goto main_loop;
} }
...@@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() { ...@@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() {
for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
shape[i] = instr.alloc_tensor.shape[i]; shape[i] = instr.alloc_tensor.shape[i];
} }
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]);
auto obj = Tensor(data); auto obj = Tensor(data);
...@@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() { ...@@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() {
auto num_dims = shape_tensor->shape[0]; auto num_dims = shape_tensor->shape[0];
auto shape = std::vector<int64_t>(shape_tensor->shape[0]); auto shape = std::vector<int64_t>(shape_tensor->shape[0]);
shape.assign(dims, dims + num_dims); shape.assign(dims, dims + num_dims);
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]);
auto obj = Tensor(data); auto obj = Tensor(data);
...@@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() { ...@@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() {
} }
} }
runtime::Module CreateVirtualMachine(const Executable* exec) {
std::shared_ptr<VirtualMachine> vm = std::make_shared<VirtualMachine>();
vm->LoadExecutable(exec);
return runtime::Module(vm);
}
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec) << "The virtual machine executable has not been defined yet."
<< "\n";
*rv = CreateVirtualMachine(exec);
});
} // namespace vm } // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
vm = relay.vm.compile(mod, target) exe = relay.vm.compile(mod, target)
vm.init(tvm.cpu()) vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
return vm.invoke("main", *args) return vm.invoke("main", *args)
else: else:
assert isinstance(f, relay.Module), "expected expression or module" assert isinstance(f, relay.Module), "expected expression or module"
mod = f mod = f
vm = relay.vm.compile(mod, target) exe = relay.vm.compile(mod, target)
vm.init(tvm.cpu()) vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
ret = vm.invoke("main", *args) ret = vm.invoke("main", *args)
return ret return ret
...@@ -573,25 +575,6 @@ def test_add_op_broadcast(): ...@@ -573,25 +575,6 @@ def test_add_op_broadcast():
mod["main"] = func mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod) check_result([x_data, y_data], x_data + y_data, mod=mod)
def test_set_params():
mod = relay.Module()
x = relay.var('x', shape=(10, 5))
w = relay.var('w', shape=(6, 5))
b = relay.var('b', shape=(6,))
y = relay.nn.bias_add(relay.nn.dense(x, w), b)
mod["main"] = relay.Function([x, w, b], y)
vm = relay.vm.compile(mod, 'llvm')
vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32')
w_np = np.random.uniform(size=(6, 5)).astype('float32')
b_np = np.random.uniform(size=(6,)).astype('float32')
ref_np = np.dot(x_np, w_np.T) + b_np
params = {'w': w_np}
vm.load_params(params)
out = vm.run(x_np, b_np)
tvm.testing.assert_allclose(out.asnumpy(), ref_np)
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
...@@ -626,4 +609,3 @@ if __name__ == "__main__": ...@@ -626,4 +609,3 @@ if __name__ == "__main__":
test_add_op_scalar() test_add_op_scalar()
test_add_op_tensor() test_add_op_tensor()
test_add_op_broadcast() test_add_op_broadcast()
test_set_params()
...@@ -22,29 +22,25 @@ import tvm ...@@ -22,29 +22,25 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.module import Module as rly_module from tvm.relay.module import Module as rly_module
from tvm.relay import vm as _vm from tvm.relay import vm as _vm
from tvm.relay import serializer, deserializer
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.contrib import util from tvm.contrib import util
from tvm.relay import testing from tvm.relay import testing
def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None): def create_exec(f, target="llvm", params=None):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
vm = _vm.compile(mod, target=target, params=params) executable = _vm.compile(mod, target=target, params=params)
vm.init(ctx) return executable
return vm
else: else:
assert isinstance(f, relay.Module), "expected mod as relay.Module" assert isinstance(f, relay.Module), "expected mod as relay.Module"
vm = _vm.compile(f, target=target, params=params) executable = _vm.compile(f, target=target, params=params)
vm.init(ctx) return executable
return vm
def veval(vm, *args, ctx=tvm.cpu()): def veval(vm, *args, ctx=tvm.cpu()):
assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine" assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine"
vm.init(ctx)
ret = vm.run(*args) ret = vm.run(*args)
return ret return ret
...@@ -59,13 +55,11 @@ def run_network(mod, ...@@ -59,13 +55,11 @@ def run_network(mod,
return result.asnumpy().astype(dtype) return result.asnumpy().astype(dtype)
def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
vm = create_vm(mod, ctx, target, params=params) exe = create_exec(mod, target, params=params)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize()
des_vm.init(ctx) des_vm.init(ctx)
des_vm.load_params(params)
result = des_vm.run(data) result = des_vm.run(data)
return result.asnumpy().astype(dtype) return result.asnumpy().astype(dtype)
...@@ -99,26 +93,25 @@ def test_serializer(): ...@@ -99,26 +93,25 @@ def test_serializer():
main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1))
mod["main"] = main mod["main"] = main
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm)
glbs = ser.globals glbs = exe.globals
assert len(glbs) == 3 assert len(glbs) == 3
assert "f1" in glbs assert "f1" in glbs
assert "f2" in glbs assert "f2" in glbs
assert "main" in glbs assert "main" in glbs
prim_ops = ser.primitive_ops prim_ops = exe.primitive_ops
assert any(item.startswith('fused_add') for item in prim_ops) assert any(item.startswith('fused_add') for item in prim_ops)
assert any(item.startswith('fused_subtract') for item in prim_ops) assert any(item.startswith('fused_subtract') for item in prim_ops)
assert any(item.startswith('fused_multiply') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops)
code = ser.bytecode code = exe.bytecode
assert "main 5 2 5" in code assert "main 5 2 5" in code
assert "f1 2 1 3" in code assert "f1 2 1 3" in code
assert "f2 2 1 3" in code assert "f2 2 1 3" in code
code, lib = ser.serialize() code, lib = exe.save()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
assert isinstance(lib, tvm.module.Module) assert isinstance(lib, tvm.module.Module)
...@@ -129,24 +122,24 @@ def test_save_load(): ...@@ -129,24 +122,24 @@ def test_save_load():
x_data = np.random.rand(10, 10).astype('float32') x_data = np.random.rand(10, 10).astype('float32')
# serialize. # serialize.
vm = create_vm(f) vm = create_exec(f)
ser = serializer.Serializer(vm) code, lib = vm.save()
code, lib = ser.serialize()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
# save and load the code and lib file. # save and load the code and lib file.
tmp = util.tempdir() tmp = util.tempdir()
path_lib = tmp.relpath("lib.so") path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib) lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo: with open(tmp.relpath("code.ro"), "wb") as fo:
fo.write(code) fo.write(code)
loaded_lib = tvm.module.load(path_lib) loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
# deserialize. # deserialize.
deser = deserializer.Deserializer(loaded_code, loaded_lib) des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib)
des_vm = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
res = veval(des_vm, x_data) res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
...@@ -156,12 +149,12 @@ def test_const(): ...@@ -156,12 +149,12 @@ def test_const():
c = relay.const(1.0, "float32") c = relay.const(1.0, "float32")
x = relay.var('x', shape=(10, 10), dtype='float32') x = relay.var('x', shape=(10, 10), dtype='float32')
f = relay.Function([x], x + c) f = relay.Function([x], x + c)
vm = create_vm(f) exe = create_exec(f)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
deser = deserializer.Deserializer(code, lib) des_exec = _vm.Executable.load_exec(code, lib)
des_vm = deser.deserialize() des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
x_data = np.random.rand(10, 10).astype('float32') x_data = np.random.rand(10, 10).astype('float32')
res = veval(des_vm, x_data) res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
...@@ -177,11 +170,11 @@ def test_if(): ...@@ -177,11 +170,11 @@ def test_if():
x_data = np.random.rand(10, 10).astype('float32') x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32')
vm = create_vm(f) exe = create_exec(f)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
# same # same
res = veval(des_vm, x_data, x_data) res = veval(des_vm, x_data, x_data)
...@@ -213,11 +206,11 @@ def test_loop(): ...@@ -213,11 +206,11 @@ def test_loop():
aarg = relay.var('accum', shape=[], dtype='int32') aarg = relay.var('accum', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
result = veval(des_vm, i_data, accum_data) result = veval(des_vm, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
...@@ -230,11 +223,11 @@ def test_tuple(): ...@@ -230,11 +223,11 @@ def test_tuple():
i_data = np.random.rand(41).astype('float32') i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32') j_data = np.random.rand(10).astype('float32')
vm = create_vm(f) exe = create_exec(f)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
result = veval(des_vm, (i_data, j_data)) result = veval(des_vm, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data) tvm.testing.assert_allclose(result.asnumpy(), j_data)
...@@ -251,11 +244,11 @@ def test_adt_list(): ...@@ -251,11 +244,11 @@ def test_adt_list():
f = relay.Function([], l321) f = relay.Function([], l321)
mod["main"] = f mod["main"] = f
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
result = veval(des_vm) result = veval(des_vm)
assert len(result) == 2 assert len(result) == 2
...@@ -297,11 +290,11 @@ def test_adt_compose(): ...@@ -297,11 +290,11 @@ def test_adt_compose():
f = relay.Function([y], add_two_body) f = relay.Function([y], add_two_body)
mod["main"] = f mod["main"] = f
vm = create_vm(mod) exe = create_exec(mod)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
x_data = np.array(np.random.rand()).astype('float32') x_data = np.array(np.random.rand()).astype('float32')
result = veval(des_vm, x_data) result = veval(des_vm, x_data)
...@@ -317,11 +310,11 @@ def test_closure(): ...@@ -317,11 +310,11 @@ def test_closure():
clo = ff(relay.const(1.0)) clo = ff(relay.const(1.0))
main = clo(relay.const(2.0)) main = clo(relay.const(2.0))
vm = create_vm(main) exe = create_exec(main)
ser = serializer.Serializer(vm) code, lib = exe.save()
code, lib = ser.serialize() des_exec = _vm.Executable.load_exec(code, lib)
deser = deserializer.Deserializer(code, lib) des_vm = _vm.VirtualMachine(des_exec)
des_vm = deser.deserialize() des_vm.init(tvm.cpu())
res = veval(des_vm) res = veval(des_vm)
tvm.testing.assert_allclose(res.asnumpy(), 3.0) tvm.testing.assert_allclose(res.asnumpy(), 3.0)
......
...@@ -26,9 +26,9 @@ def test_basic(): ...@@ -26,9 +26,9 @@ def test_basic():
mod, params = resnet.get_workload() mod, params = resnet.get_workload()
target = 'llvm' target = 'llvm'
ctx = tvm.cpu() ctx = tvm.cpu()
vm = relay.profiler_vm.compile(mod, target) exe = relay.profiler_vm.compile(mod, target, params=params)
vm = relay.profiler_vm.VirtualMachineProfiler(exe)
vm.init(ctx) vm.init(ctx)
vm.load_params(params)
data = np.random.rand(1, 3, 224, 224).astype('float32') data = np.random.rand(1, 3, 224, 224).astype('float32')
res = vm.invoke("main", [data]) res = vm.invoke("main", [data])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment