Commit 4052de6d by Zhi Committed by Haichen Shen

[relay][vm] Separate VM runtime with executable (#4100)

* [relay][vm] Separate VM runtime with executable

* Address comments

* move ctx back to vm

* make only vm related fields and methods protected

* integrate serialization/deserialization to executable

* create stream
parent cf046972
......@@ -26,6 +26,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include <unordered_map>
......@@ -430,15 +431,184 @@ struct VMFrame {
caller_return_register(0) {}
};
/*! \brief The executable emitted by the VM compiler.
*
* The executable contains information (e.g. data in different memory regions)
* to run in a virtual machine.
*
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*/
class Executable : public ModuleNode {
public:
/*!
* \brief Get a PackedFunc from an executable module.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Save();
/*!
* \brief Load the saved VM executable.
*
* \param code The bytecode in string.
* \param lib The compiled runtime library.
*
* \return exe The constructed executable.
*/
static runtime::Module Load(const std::string& code, const runtime::Module lib);
/*!
* \brief Get the serialized form of the `functions`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokePacked` instruction, etc.
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doesn't need to handle it.
*/
std::string GetBytecode() const;
/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globals and constants, etc.
*
* \return The statistics as a string.
*/
std::string Stats() const;
/*! \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk.
*
* \return The runtime module that contains the hardware dependent code.
*/
runtime::Module GetLib() const { return lib; }
virtual ~Executable() {}
const char* type_key() const final {
return "VMExecutable";
}
/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
*/
std::unordered_map<std::string, Index> primitive_map;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
private:
/*!
* \brief Save the globals.
*
* \param strm The stream the globals are saved to.
*/
void SaveGlobalSection(dmlc::Stream* strm);
/*!
* \brief Save the constant pool.
*
* \param strm The stream the constant pool is saved to.
*/
void SaveConstantSection(dmlc::Stream* strm);
/*!
* \brief Save primitive op names.
*
* \param strm The stream the primitive op names are saved to.
*/
void SavePrimitiveOpNames(dmlc::Stream* strm);
/*!
* \brief Save the vm functions.
*
* \param strm The stream the vm functions are saved to.
*/
void SaveCodeSection(dmlc::Stream* strm);
/*!
* \brief Load the globals.
*
* \param strm The input stream.
*/
void LoadGlobalSection(dmlc::Stream* strm);
/*!
* \brief Load the constant pool.
*
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);
/*!
* \brief Load primitive op names.
*
* \param strm The input stream.
*/
void LoadPrimitiveOpNames(dmlc::Stream* strm);
/*!
* \brief Load the vm functions.
*
* \param strm The input stream.
*/
void LoadCodeSection(dmlc::Stream* strm);
/*! \brief The serialized bytecode. */
std::string code_;
};
/*! \brief The virtual machine.
*
* The virtual machine contains all the current execution state,
* as well as the global view of functions, the global constant
* table, the compiled operators.
* as well as the executable.
*
* The goal is to have a single self-contained object,
* enabling one to easily pass around VMs, execute them on
* multiple threads, or serialized them to disk or over the
* multiple threads, or serialize them to disk or over the
* wire.
*/
class VirtualMachine : public runtime::ModuleNode {
......@@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode {
return "VirtualMachine";
}
/*! \brief The runtime module/library that contains generated code. */
runtime::Module lib;
VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}
/*! \brief load the executable for the virtual machine.
* \param exec The executable.
*/
void LoadExecutable(const Executable* exec);
protected:
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
/*! \brief The current stack of call frames. */
std::vector<VMFrame> frames;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief The fuction table index of the current function. */
Index func_index;
/*! \brief The current pointer to the code section. */
......@@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The special return register. */
ObjectRef return_register;
/*! \brief The executable the VM will operate on. */
const Executable* exec;
/*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs;
......@@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/
ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
/*! \brief Initialize the virtual machine for a set of contexts.
* \param contexts The set of TVM contexts.
*/
......@@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/
TVMContext GetParamsContext() const;
/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
*/
void LoadParams(const std::string& params);
/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list.
*/
std::unordered_map<std::string, Index> primitive_map;
private:
/*! \brief Invoke a global setting up the VM state to execute.
*
......
......@@ -37,8 +37,6 @@ from . import param_dict
from . import feature
from .backend import vm
from .backend import profiler_vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj
# Root operators
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine deserializer.
Python interface for deserializing a Relay VM.
"""
from tvm import module
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm
from . import vm as rly_vm
def _create_deserializer(code, lib):
    """Create a deserializer object.

    Parameters
    ----------
    code : bytearray
        The serialized virtual machine code. A ``bytes`` object is accepted
        and converted to ``bytearray``; ``bytearray`` and ``TVMByteArray``
        values are passed through unchanged.

    lib : :py:class:`~tvm.module.Module`
        The serialized runtime module/library that contains the hardware
        dependent binary code.

    Returns
    -------
    ret : Deserializer
        The created virtual machine deserializer, as returned by
        ``_vm._Deserializer``.

    Raises
    ------
    TypeError
        If `code` is not bytes-like (including a text ``str`` on Python 3)
        or `lib` is not a ``tvm.module.Module``.
    """
    # Only real byte strings can be turned into a bytearray without an
    # encoding. On Python 2 `bytes` is `str`, so the old behaviour is kept;
    # on Python 3 a text `str` now reaches the explicit TypeError below
    # instead of failing inside `bytearray()` with an unrelated message.
    if isinstance(code, bytes):
        code = bytearray(code)
    elif not isinstance(code, (bytearray, TVMByteArray)):
        raise TypeError("code is expected to be the type of bytearray or " +
                        "TVMByteArray, but received {}".format(type(code)))

    if not isinstance(lib, module.Module):
        raise TypeError("lib is expected to be the type of tvm.module.Module" +
                        ", but received {}".format(type(lib)))
    return _vm._Deserializer(code, lib)
class Deserializer:
    """Python front end that rebuilds a Relay VM from its serialized form.

    Parameters
    ----------
    code : bytearray
        The serialized virtual machine code.

    lib : :py:class:`~tvm.module.Module`
        The serialized runtime module/library that contains the hardware
        dependent binary code.
    """

    def __init__(self, code, lib):
        # Argument validation is done by the module factory.
        self.mod = _create_deserializer(code, lib)
        self._deserialize = self.mod["deserialize"]

    def deserialize(self):
        """Turn the serialized bytecode back into a Relay VM.

        Returns
        -------
        ret : VirtualMachine
            The deserialized Relay VM.
        """
        vm_module = self._deserialize()
        return rly_vm.VirtualMachine(vm_module)
......@@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None):
Returns
-------
vm : VirtualMachineProfiler
The profile VM runtime.
exec : Executable
The executable with profiling code.
"""
compiler = VMCompilerProfiler()
target = compiler.update_target(target)
......@@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None):
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return VirtualMachineProfiler(compiler._get_vm())
return vm.Executable(compiler._get_exec())
class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
......@@ -68,13 +68,17 @@ class VMCompilerProfiler(vm.VMCompiler):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]
class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
def __init__(self, mod):
super().__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"]
def get_stat(self):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine serializer.
Python interface for serializing a Relay VM.
"""
import tvm
from . import _vm
from . import vm as rly_vm
def _create_serializer(vm):
    """Build the serializer module for a virtual machine.

    Parameters
    ----------
    vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`]
        The virtual machine to be serialized. A ``VirtualMachine`` wrapper is
        unwrapped to its underlying ``module`` attribute first.

    Returns
    -------
    ret : Serializer
        The created virtual machine serializer.

    Raises
    ------
    TypeError
        If `vm` is neither a ``VirtualMachine`` nor a ``tvm.module.Module``.
    """
    if isinstance(vm, rly_vm.VirtualMachine):
        return _vm._Serializer(vm.module)
    if isinstance(vm, tvm.module.Module):
        return _vm._Serializer(vm)
    raise TypeError("vm is expected to be the type of VirtualMachine or " +
                    "tvm.Module, but received {}".format(type(vm)))
class Serializer:
    """Relay VM serializer.

    Parameters
    ----------
    vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`]
        The virtual machine to be serialized.
    """

    def __init__(self, vm):
        self.mod = _create_serializer(vm)
        # Cache the packed functions exposed by the serializer module.
        self._get_lib = self.mod["get_lib"]
        self._get_bytecode = self.mod["get_bytecode"]
        self._get_globals = self.mod["get_globals"]
        self._get_stats = self.mod["get_stats"]
        self._get_primitive_ops = self.mod["get_primitive_ops"]
        self._serialize = self.mod["serialize"]

    @property
    def stats(self):
        """Statistics of the Relay VM.

        Returns
        -------
        ret : String
            The serialized statistic information.
        """
        return self._get_stats()

    @property
    def primitive_ops(self):
        """Names of the primitive ops that are executed in the VM.

        Returns
        -------
        ret : List[:py:class:`~tvm.expr.StringImm`]
            The list of primitive ops.
        """
        names = []
        for prim_op in self._get_primitive_ops():
            names.append(prim_op.value)
        return names

    @property
    def bytecode(self):
        """Bytecode of the Relay VM.

        Returns
        -------
        ret : String
            The serialized bytecode.

        Notes
        -----
        The bytecode is in the following format:
          func_name reg_file_size num_instructions
          param1 param2 ... paramM
          instruction1
          instruction2
          ...
          instructionN

        Each instruction is printed in the following format:
          hash opcode field1 ... fieldX # The text format.

        The part starting from # is only used for visualization and debugging.
        The real serialized code doesn't contain it, therefore the deserializer
        doesn't need to deal with it as well.
        """
        return self._get_bytecode()

    @property
    def globals(self):
        """Globals used by the Relay VM.

        Returns
        -------
        ret : List[:py:class:`~tvm.expr.StringImm`]
            The serialized globals.
        """
        names = []
        for glb in self._get_globals():
            names.append(glb.value)
        return names

    def serialize(self):
        """Serialize the Relay VM.

        Returns
        -------
        code : bytearray
            The binary blob representing a serialized Relay VM. It can then be
            saved to disk and later deserialized into a new VM.

        lib : :py:class:`~tvm.module.Module`
            The runtime module that contains the generated code. It is
            basically a library that is composed of hardware dependent code.

        Notes
        -----
        The returned code is organized with the following sections in order.
         - Global section. This section contains the globals used by the
           virtual machine.
         - Constant section. This section is used to store the constant pool of
           a virtual machine.
         - Primitive name section. This section is introduced to accommodate
           the list of primitive operator names that will be invoked by the
           virtual machine.
         - Code section. The VM functions, including bytecode, are sitting in
           this section.

        Examples
        --------
        .. code-block:: python

            import numpy as np
            import tvm
            from tvm import relay

            # define a simple network.
            x = relay.var('x', shape=(10, 10))
            f = relay.Function([x], x + x)
            mod = relay.Module({"main": f})

            # create a Relay VM.
            ctx = tvm.cpu()
            target = "llvm"
            compiler = relay.vm.VMCompiler()
            vm = compiler.compile(mod, target)
            vm.init(ctx)

            # serialize.
            ser = relay.serializer.Serializer(vm)
            code, lib = ser.serialize()

            # save and load the code and lib file.
            tmp = tvm.contrib.util.tempdir()
            path_lib = tmp.relpath("lib.so")
            lib.export_library(path_lib)
            with open(tmp.relpath("code.bc"), "wb") as fo:
                fo.write(code)

            loaded_lib = tvm.module.load(path_lib)
            loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())

            # deserialize.
            deser = relay.deserializer.Deserializer(loaded_code, loaded_lib)
            des_vm = deser.deserialize()

            # execute the deserialized vm.
            des_vm.init(ctx)
            x_data = np.random.rand(10, 10).astype('float32')
            res = des_vm.run(x_data)
            print(res.asnumpy())
        """
        # Produce the bytecode blob first, then fetch the library module.
        code = self._serialize()
        lib = self._get_lib()
        return code, lib
......@@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
Module mod = args[0];
this->Compile(mod, args[1], args[2]);
});
} else if (name == "get_vm") {
} else if (name == "get_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = runtime::Module(vm_);
*rv = runtime::Module(exec_);
});
} else if (name == "set_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......@@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod,
// Next we get ready by allocating space for
// the global state.
vm_->functions.resize(context_.module->functions.size());
exec_->functions.resize(context_.module->functions.size());
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
......@@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod,
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < vm_->functions.size());
vm_->functions[func_index] = vm_func;
CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
}
#if USE_RELAY_DEBUG
for (auto vm_func : vm_->functions) {
for (auto vm_func : exec_->functions) {
DLOG(INFO) << vm_func << "-------------";
}
#endif // USE_RELAY_DEBUG
// populate constants
for (auto data : context_.constants) {
vm_->constants.push_back(runtime::vm::Tensor(data));
exec_->constants.push_back(runtime::vm::Tensor(data));
}
LibraryCodegen();
for (auto gv : context_.global_map) {
vm_->global_map.insert({gv.first->name_hint, gv.second});
exec_->global_map.insert({gv.first->name_hint, gv.second});
}
}
......@@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() {
// therefore target won't be used in the build function
runtime::Module mod = (*f)(funcs, Target(), target_host_);
CHECK(mod.operator->());
vm_->lib = mod;
exec_->lib = mod;
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
size_t primitive_index = 0;
for (auto cfunc : cached_funcs) {
vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
......
......@@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode {
return "VMCompiler";
}
std::shared_ptr<VirtualMachine> GetVirtualMachine() const {
return vm_;
}
virtual void InitVM() {
vm_ = std::make_shared<VirtualMachine>();
void InitVM() {
exec_ = std::make_shared<Executable>();
}
/*!
......@@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode {
tvm::Target target_host_;
/*! \brief Global shared meta data */
VMCompilerContext context_;
/*! \brief Compiled virtual machine. */
std::shared_ptr<VirtualMachine> vm_;
/*! \brief Compiled executable. */
std::shared_ptr<Executable> exec_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
};
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/deserializer.h
* \brief Define a deserializer for the serialized Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#include <dmlc/memory_io.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime::vm;
namespace runtime = tvm::runtime;
/*!
 * \brief Module that rebuilds a virtual machine from serialized code and a
 * runtime library.
 *
 * \note The destructor deletes `strm_`, so this object owns the stream. The
 * stream is created outside this declaration (presumably in `Init`; confirm
 * against the implementation file).
 */
class Deserializer : public runtime::ModuleNode {
 public:
  /*!
   * \brief Initialize the deserializer for creating a virtual machine object.
   *
   * \param code The serialized code.
   * \param lib The serialized runtime module/library that contains the
   * hardware dependent code.
   */
  inline void Init(const std::string& code, const runtime::Module& lib);

  /*!
   * \brief Return the member function to the frontend.
   *
   * \param name The name of the function.
   * \param sptr_to_self The pointer to the module node.
   *
   * \return The corresponding member function.
   */
  PackedFunc GetFunction(const std::string& name,
                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;

  const char* type_key() const final { return "Deserializer"; }

  /*! \brief Deserialize the serialized VM. */
  void Deserialize();

  /*! \brief Release the stream owned by this deserializer. */
  virtual ~Deserializer() { delete strm_; }

 private:
  /*! \brief Deserialize the globals in `vm_`. */
  void DeserializeGlobalSection();

  /*! \brief Deserialize the constant pool in `vm_`. */
  void DeserializeConstantSection();

  /*! \brief Deserialize primitive op names in `vm_`. */
  void DeserializePrimitiveOpNames();

  /*! \brief Deserialize the vm functions in `vm_`. */
  void DeserializeCodeSection();

  /*! \brief The serialized code to be deserialized. */
  std::string code_;

  /*!
   * \brief The stream used for deserialization.
   *
   * Defaults to nullptr so that destroying an object that was never
   * initialized does not `delete` an indeterminate pointer.
   */
  dmlc::Stream* strm_{nullptr};

  /*! \brief The VM to be created. */
  std::shared_ptr<VirtualMachine> vm_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
......@@ -33,7 +33,6 @@ namespace vm {
class VMCompilerDebug : public VMCompiler {
public:
VMCompilerDebug() {}
void InitVM() override { vm_ = std::make_shared<VirtualMachineDebug>(); }
virtual ~VMCompilerDebug() {}
};
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serializer.h
* \brief Define a serializer for the Relay VM.
*
* The following components of a Relay VM will be serialized:
* - The `constants`, e.g., the constant pool, that contains the
* constants used in a Relay program.
* - The `packed_funcs` that essentially contains the generated code for
* a specific target. We return it as a runtime module that can be exported as
* a library file (e.g., .so, .o, or .tar).
* - The `global_map` that contains the globals.
* - The `primitive_map` that contains the name of individual primitive operators.
* - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of
* a list of instructions/bytecode.
*
* Note that only the library is returned as a separate module. All other parts
* are stored in a single serialized code that is organized with the following
* sections in order.
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*
* The code section is again organized as follows for each VM function:
* func_name, register_file_size, num_instructions (N)
* param1, param2, ..., paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokePacked` instruction, etc.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/ir.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
/*!
 * \brief The Relay VM serializer.
 *
 * \note The destructor deletes `strm_`, so this object owns the stream. `vm_`
 * is a raw pointer that the destructor never releases; whoever supplies it
 * must keep the virtual machine alive while the serializer is in use.
 */
class Serializer : public runtime::ModuleNode {
 public:
  /*!
   * \brief Initialize the serializer for a virtual machine.
   *
   * \param vm The Relay virtual machine.
   */
  inline void Init(const VirtualMachine* vm);

  /*!
   * \brief Return the member function to the frontend.
   *
   * \param name The name of the function.
   * \param sptr_to_self The pointer to the module node.
   *
   * \return The corresponding member function.
   */
  PackedFunc GetFunction(const std::string& name,
                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;

  const char* type_key() const final { return "Serializer"; }

  /*!
   * \brief Print the detailed statistics of the given code, i.e. number of
   * globals and constants, etc.
   */
  std::string Stats() const;

  /*!
   * \brief Serialize the `vm_` into global section, constant section, and code
   * section.
   *
   * \return The binary representation of the VM.
   */
  TVMByteArray Serialize();

  /*!
   * \brief Get a list of the globals used by the `vm_`.
   *
   * \return The global map in the form of a list.
   */
  tvm::Array<tvm::Expr> GetGlobals() const;

  /*!
   * \brief Get the primitive operators that are contained in the Relay VM.
   *
   * \return The list of primitive operators.
   */
  tvm::Array<tvm::Expr> GetPrimitiveOps() const;

  /*!
   * \brief Get the serialized form of the `functions` in `vm_`. This is
   * essentially bytecode serialization.
   *
   * \return The serialized vm bytecode.
   *
   * \note The bytecode is in the following format:
   *  func_name reg_file_size num_instructions
   *  param1 param2 ... paramM
   *  instruction1
   *  instruction2
   *  ...
   *  instructionN
   *
   * Each instruction is printed in the following format:
   *  opcode num_fields field1 ... fieldX # The text format.
   *
   * The field starting from # is only used for debugging. The serialized code
   * doesn't contain it, therefore the deserializer doesn't need to handle it.
   */
  std::string GetBytecode() const;

  /*! \brief Get the `lib` module in vm_. Serialization of `runtime::module`
   * has already been supported by TVM. Therefore, we only return the runtime
   * module and let users have the flexibility to call `export_library` from
   * the frontend to save the library to disk.
   *
   * \return The runtime module that contains the hardware dependent code.
   */
  inline runtime::Module GetLib() const;

  /*! \brief Release the stream owned by this serializer. */
  virtual ~Serializer() { delete strm_; }

 private:
  /*! \brief Serialize the globals in vm_. */
  void SerializeGlobalSection();

  /*! \brief Serialize the constant pool in vm_. */
  void SerializeConstantSection();

  /*! \brief Serialize primitive op names in vm_. */
  void SerializePrimitiveOpNames();

  /*! \brief Serialize the vm functions in vm_. */
  void SerializeCodeSection();

  /*!
   * \brief The Relay virtual machine to be serialized.
   *
   * Defaults to nullptr so an uninitialized serializer holds a well-defined
   * value instead of an indeterminate pointer.
   */
  const VirtualMachine* vm_{nullptr};

  /*!
   * \brief The stream used for serialization.
   *
   * Defaults to nullptr so that destroying an object that was never
   * initialized does not `delete` an indeterminate pointer.
   */
  dmlc::Stream* strm_{nullptr};

  /*! \brief The serialized code. */
  std::string code_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_
......@@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction(
}
}
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
VirtualMachine::Init(ctxs);
for (auto kv : primitive_map) {
/*!
 * \brief Load the executable and set up the per-operator profiling tables.
 *
 * After delegating to VirtualMachine::LoadExecutable, every entry of the
 * executable's `primitive_map` is recorded in `packed_index_map` (index to
 * name) and its invocation counter in `op_invokes` is reset to zero.
 *
 * \param exec The executable.
 */
void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
  VirtualMachine::LoadExecutable(exec);
  CHECK(this->exec);
  // Bind by const reference: each entry holds a std::string key, so iterating
  // by value would copy the operator name on every step.
  for (const auto& kv : this->exec->primitive_map) {
    packed_index_map[kv.second] = kv.first;
    op_invokes[kv.second] = 0;
  }
}
/*!
 * \brief Initialize the debug virtual machine for a set of contexts.
 *
 * Simply forwards to VirtualMachine::Init.
 *
 * \param ctxs The set of TVM contexts.
 */
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
VirtualMachine::Init(ctxs);
}
void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args) {
auto ctx = VirtualMachine::GetParamsContext();
CHECK(this->exec);
auto ctx = this->GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
......@@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
op_invokes[packed_index] += 1;
}
/*!
 * \brief Build a debug virtual machine module around an executable.
 *
 * \param exec The executable to load into the new virtual machine.
 *
 * \return A runtime module wrapping the created VirtualMachineDebug.
 */
runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
  auto debug_vm = std::make_shared<VirtualMachineDebug>();
  debug_vm->LoadExecutable(exec);
  return runtime::Module(debug_vm);
}
// Frontend entry point: takes a runtime module as args[0], which must wrap an
// Executable, and returns a debug virtual machine loaded with it.
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  runtime::Module exec_module = args[0];
  // dynamic_cast yields nullptr when the module is not an Executable.
  const Executable* executable = dynamic_cast<Executable*>(exec_module.operator->());
  CHECK(executable) << "Virtual machine has not been defined yet."
                    << "\n";
  *rv = CreateVirtualMachineDebug(executable);
});
} // namespace vm
} // namespace runtime
} // namespace tvm
......@@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine {
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final;
void LoadExecutable(const Executable* exec);
~VirtualMachineDebug() {}
private:
......
......@@ -19,11 +19,11 @@
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serialize_util.h
* \file src/runtime/vm/serialize_util.h
* \brief Definitions of helpers for serializing and deserializing a Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
#ifndef TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
#define TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
#include <dmlc/common.h>
#include <dmlc/memory_io.h>
......@@ -34,7 +34,7 @@
#include <vector>
namespace tvm {
namespace relay {
namespace runtime {
namespace vm {
/*! \brief The magic number for the serialized VM bytecode file */
......@@ -158,7 +158,7 @@ struct VMInstructionSerializer {
};
} // namespace vm
} // namespace relay
} // namespace runtime
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
#endif // TVM_RUNTIME_VM_SERIALIZE_UTIL_H_
......@@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(exec) << "The executable is not created yet.";
std::string func_name = args[0];
auto gvit = this->global_map.find(func_name);
CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name;
auto gvit = exec->global_map.find(func_name);
CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = this->functions[func_index];
const auto& vm_func = exec->functions[func_index];
const auto& param_names = vm_func.params;
auto ctx = this->GetParamsContext();
......@@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
}
this->Init(contexts);
});
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0].operator std::string());
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
......@@ -628,6 +625,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
}
TVMContext VirtualMachine::GetParamsContext() const {
CHECK(!ctxs.empty()) << "Context has not been initialized yet."
<< "\n";
// Use the fallback device if no device index is available.
int fallback_device_type = static_cast<int>(ctxs[0].device_type);
// TODO(wweic): For heterogeneous execution, get device information from byte
......@@ -639,32 +639,6 @@ TVMContext VirtualMachine::GetParamsContext() const {
return (cit == ctxs.end() ? ctxs[0] : *cit);
}
/*!
 * \brief Load named parameters from a serialized blob into `params_`.
 *
 * The blob is read in this order: a uint64 header that must equal
 * kTVMNDArrayListMagic, a uint64 reserved field, the list of parameter names,
 * a uint64 count that must match the number of names, and then one NDArray
 * per name. Each array is wrapped in a Tensor, copied to the context returned
 * by GetParamsContext(), and stored under its name.
 *
 * Any failed or inconsistent read aborts through CHECK with
 * "Invalid parameter file".
 *
 * \param params The serialized parameters.
 */
void VirtualMachine::LoadParams(const std::string& params) {
  // MemoryStringStream is constructed from a non-const std::string*; this
  // function only calls Read on the resulting stream.
  dmlc::MemoryStringStream mss(const_cast<std::string*>(&params));
  dmlc::Stream* strm = &mss;

  uint64_t header = 0;
  uint64_t reserved = 0;
  CHECK(strm->Read(&header)) << "Invalid parameter file";
  CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file";
  CHECK(strm->Read(&reserved)) << "Invalid parameter file";

  std::vector<std::string> names;
  CHECK(strm->Read(&names)) << "Invalid parameter file";

  // The element count used to be read without checking the result, so a
  // truncated blob left `sz` uninitialized before the comparison below.
  uint64_t sz = 0;
  CHECK(strm->Read(&sz)) << "Invalid parameter file";
  size_t size = static_cast<size_t>(sz);
  CHECK(size == names.size()) << "Invalid parameter file";

  auto ctx = GetParamsContext();
  for (size_t i = 0; i < size; i++) {
    NDArray arr;
    CHECK(arr.Load(strm)) << "Invalid parameter file";
    ObjectRef obj = Tensor(arr);
    auto copy = CopyTo(obj, ctx);
    params_.emplace(std::make_pair(names[i], copy));
  }
}
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
frames.push_back(frame);
......@@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec
InvokeGlobal(func, args);
RunLoop();
// TODO(wweic) ctx could be obtained from the ctxs list.
auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
return return_register;
}
ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
auto func_index = this->global_map[name];
CHECK(exec) << "The executable has not been created yet.";
auto func_index = exec->global_map.at(name);
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args);
return Invoke(exec->functions[func_index], args);
}
void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
......@@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
}
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
this->ctxs = ctxs;
void VirtualMachine::LoadExecutable(const Executable* exec) {
CHECK(exec) << "The executable is not created yet.";
this->exec = exec;
runtime::Module lib = this->exec->lib;
// Get the list of packed functions.
CHECK(primitive_map.empty() || lib.operator->())
CHECK(exec->primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions"
<< "\n";
for (const auto& it : primitive_map) {
for (const auto& it : this->exec->primitive_map) {
const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second);
if (packed_funcs.size() <= packed_index) {
......@@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
}
}
/*!
 * \brief Record the contexts this virtual machine runs with.
 *
 * \param ctxs The contexts; copied into the `ctxs` member. GetParamsContext
 *        checks that this list is non-empty, and the interpreter loop uses
 *        `ctxs[0]` when copying constants and allocating tensors.
 */
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
this->ctxs = ctxs;
}
/*!
 * \brief Store a value into a register of the most recently pushed frame.
 *
 * \param r The index into the frame's register file.
 * \param val The object to store.
 */
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
  auto& current_frame = frames.back();
  current_frame.register_file[r] = val;
}
......@@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
void VirtualMachine::RunLoop() {
CHECK(this->code);
CHECK(this->exec);
this->pc = 0;
Index frame_start = frames.size();
while (true) {
......@@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() {
throw std::runtime_error("VM encountered fatal error");
}
case Opcode::LoadConst: {
auto constant_obj = this->constants[instr.const_index];
auto constant_obj = exec->constants[instr.const_index];
// TODO(wweic) ctx could be obtained from the ctxs list.
auto device_obj = CopyTo(constant_obj, ctxs[0]);
WriteRegister(instr.dst, device_obj);
pc++;
......@@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_args; ++i) {
args.push_back(ReadRegister(instr.invoke_args_registers[i]));
}
InvokeGlobal(this->functions[instr.func_index], args);
InvokeGlobal(exec->functions[instr.func_index], args);
frames.back().caller_return_register = instr.dst;
goto main_loop;
}
......@@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_closure_args; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
InvokeGlobal(this->functions[closure->func_index], args);
InvokeGlobal(exec->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst;
goto main_loop;
}
......@@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() {
for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) {
shape[i] = instr.alloc_tensor.shape[i];
}
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]);
auto obj = Tensor(data);
......@@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() {
auto num_dims = shape_tensor->shape[0];
auto shape = std::vector<int64_t>(shape_tensor->shape[0]);
shape.assign(dims, dims + num_dims);
// TODO(wweic) ctx could be obtained from the ctxs list.
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]);
auto obj = Tensor(data);
......@@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() {
}
}
/*!
 * \brief Build a virtual machine around an executable and wrap it in a module.
 *
 * \param exec The executable handed to VirtualMachine::LoadExecutable.
 *
 * \return A runtime::Module constructed from the newly created VirtualMachine.
 */
runtime::Module CreateVirtualMachine(const Executable* exec) {
  auto machine = std::make_shared<VirtualMachine>();
  machine->LoadExecutable(exec);
  return runtime::Module(machine);
}
// Global entry point "relay._vm._VirtualMachine": takes a module as its first
// argument, requires that its node is an Executable, and returns the module
// produced by CreateVirtualMachine for that executable.
TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  runtime::Module exec_module = args[0];
  // dynamic_cast yields nullptr when the module node is not an Executable.
  const Executable* executable = dynamic_cast<Executable*>(exec_module.operator->());
  CHECK(executable) << "The virtual machine executable has not been defined yet."
                    << "\n";
  *rv = CreateVirtualMachine(executable);
});
} // namespace vm
} // namespace runtime
} // namespace tvm
......@@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
vm = relay.vm.compile(mod, target)
vm.init(tvm.cpu())
exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
return vm.invoke("main", *args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
vm = relay.vm.compile(mod, target)
vm.init(tvm.cpu())
exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
ret = vm.invoke("main", *args)
return ret
......@@ -573,25 +575,6 @@ def test_add_op_broadcast():
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)
# Checks parameter loading on the VM: builds bias_add(dense(x, w), b) with
# x of shape (10, 5), w of shape (6, 5) and b of shape (6,), compiles it for
# 'llvm', supplies only 'w' through load_params, passes x and b to run(), and
# compares the output against np.dot(x, w.T) + b.
def test_set_params():
mod = relay.Module()
x = relay.var('x', shape=(10, 5))
w = relay.var('w', shape=(6, 5))
b = relay.var('b', shape=(6,))
y = relay.nn.bias_add(relay.nn.dense(x, w), b)
mod["main"] = relay.Function([x, w, b], y)
vm = relay.vm.compile(mod, 'llvm')
vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32')
w_np = np.random.uniform(size=(6, 5)).astype('float32')
b_np = np.random.uniform(size=(6,)).astype('float32')
ref_np = np.dot(x_np, w_np.T) + b_np
params = {'w': w_np}
vm.load_params(params)
out = vm.run(x_np, b_np)
tvm.testing.assert_allclose(out.asnumpy(), ref_np)
if __name__ == "__main__":
test_id()
......@@ -626,4 +609,3 @@ if __name__ == "__main__":
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
test_set_params()
......@@ -22,29 +22,25 @@ import tvm
from tvm import relay
from tvm.relay.module import Module as rly_module
from tvm.relay import vm as _vm
from tvm.relay import serializer, deserializer
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
from tvm.contrib import util
from tvm.relay import testing
def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None):
def create_exec(f, target="llvm", params=None):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
vm = _vm.compile(mod, target=target, params=params)
vm.init(ctx)
return vm
executable = _vm.compile(mod, target=target, params=params)
return executable
else:
assert isinstance(f, relay.Module), "expected mod as relay.Module"
vm = _vm.compile(f, target=target, params=params)
vm.init(ctx)
return vm
executable = _vm.compile(f, target=target, params=params)
return executable
# Runs an already-constructed VM: asserts that `vm` is a _vm.VirtualMachine,
# (re)initializes it on `ctx`, and returns the result of vm.run(*args).
def veval(vm, *args, ctx=tvm.cpu()):
assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine"
vm.init(ctx)
ret = vm.run(*args)
return ret
......@@ -59,13 +55,11 @@ def run_network(mod,
return result.asnumpy().astype(dtype)
def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
vm = create_vm(mod, ctx, target, params=params)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
exe = create_exec(mod, target, params=params)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(ctx)
des_vm.load_params(params)
result = des_vm.run(data)
return result.asnumpy().astype(dtype)
......@@ -99,26 +93,25 @@ def test_serializer():
main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1))
mod["main"] = main
vm = create_vm(mod)
ser = serializer.Serializer(vm)
exe = create_exec(mod)
glbs = ser.globals
glbs = exe.globals
assert len(glbs) == 3
assert "f1" in glbs
assert "f2" in glbs
assert "main" in glbs
prim_ops = ser.primitive_ops
prim_ops = exe.primitive_ops
assert any(item.startswith('fused_add') for item in prim_ops)
assert any(item.startswith('fused_subtract') for item in prim_ops)
assert any(item.startswith('fused_multiply') for item in prim_ops)
code = ser.bytecode
code = exe.bytecode
assert "main 5 2 5" in code
assert "f1 2 1 3" in code
assert "f2 2 1 3" in code
code, lib = ser.serialize()
code, lib = exe.save()
assert isinstance(code, bytearray)
assert isinstance(lib, tvm.module.Module)
......@@ -129,24 +122,24 @@ def test_save_load():
x_data = np.random.rand(10, 10).astype('float32')
# serialize.
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
vm = create_exec(f)
code, lib = vm.save()
assert isinstance(code, bytearray)
# save and load the code and lib file.
tmp = util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo:
with open(tmp.relpath("code.ro"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())
loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read())
# deserialize.
deser = deserializer.Deserializer(loaded_code, loaded_lib)
des_vm = deser.deserialize()
des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
......@@ -156,12 +149,12 @@ def test_const():
c = relay.const(1.0, "float32")
x = relay.var('x', shape=(10, 10), dtype='float32')
f = relay.Function([x], x + c)
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
exe = create_exec(f)
code, lib = exe.save()
assert isinstance(code, bytearray)
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
x_data = np.random.rand(10, 10).astype('float32')
res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
......@@ -177,11 +170,11 @@ def test_if():
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
exe = create_exec(f)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
# same
res = veval(des_vm, x_data, x_data)
......@@ -213,11 +206,11 @@ def test_loop():
aarg = relay.var('accum', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
vm = create_vm(mod)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
exe = create_exec(mod)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
result = veval(des_vm, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
......@@ -230,11 +223,11 @@ def test_tuple():
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
exe = create_exec(f)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
result = veval(des_vm, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)
......@@ -251,11 +244,11 @@ def test_adt_list():
f = relay.Function([], l321)
mod["main"] = f
vm = create_vm(mod)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
exe = create_exec(mod)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
result = veval(des_vm)
assert len(result) == 2
......@@ -297,11 +290,11 @@ def test_adt_compose():
f = relay.Function([y], add_two_body)
mod["main"] = f
vm = create_vm(mod)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
exe = create_exec(mod)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
x_data = np.array(np.random.rand()).astype('float32')
result = veval(des_vm, x_data)
......@@ -317,11 +310,11 @@ def test_closure():
clo = ff(relay.const(1.0))
main = clo(relay.const(2.0))
vm = create_vm(main)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
exe = create_exec(main)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
des_vm = _vm.VirtualMachine(des_exec)
des_vm.init(tvm.cpu())
res = veval(des_vm)
tvm.testing.assert_allclose(res.asnumpy(), 3.0)
......
......@@ -26,9 +26,9 @@ def test_basic():
mod, params = resnet.get_workload()
target = 'llvm'
ctx = tvm.cpu()
vm = relay.profiler_vm.compile(mod, target)
exe = relay.profiler_vm.compile(mod, target, params=params)
vm = relay.profiler_vm.VirtualMachineProfiler(exe)
vm.init(ctx)
vm.load_params(params)
data = np.random.rand(1, 3, 224, 224).astype('float32')
res = vm.invoke("main", [data])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment