Commit 90455121 by Zhi Committed by Jared Roesch

[Relay][VM] Relay VM serialization (#3647)

* relay vm serialization

* fix lint

* load params, fix stream

* lint

* fix typo
parent 0365e50a
...@@ -204,7 +204,7 @@ InvokeClosure ...@@ -204,7 +204,7 @@ InvokeClosure
**Arguments**: **Arguments**:
:: ::
RegName closure RegName closure
size_t closure_args_num size_t num_closure_args
RegName* closure_args RegName* closure_args
Invokes `closure`, consuming the number of arguments declared in the closure's VMFunction. Invokes `closure`, consuming the number of arguments declared in the closure's VMFunction.
......
...@@ -36,6 +36,9 @@ namespace tvm { ...@@ -36,6 +36,9 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
/*! \brief A register name. */ /*! \brief A register name. */
using RegName = int64_t; using RegName = int64_t;
...@@ -103,7 +106,7 @@ struct Instruction { ...@@ -103,7 +106,7 @@ struct Instruction {
/*! \brief The register containing the closure. */ /*! \brief The register containing the closure. */
RegName closure; RegName closure;
/*! \brief The number of arguments to the closure. */ /*! \brief The number of arguments to the closure. */
Index closure_args_num; Index num_closure_args;
/*! \brief The closure arguments as an array. */ /*! \brief The closure arguments as an array. */
RegName* closure_args; RegName* closure_args;
}; };
...@@ -115,7 +118,7 @@ struct Instruction { ...@@ -115,7 +118,7 @@ struct Instruction {
/*! \brief The source register for a move operation. */ /*! \brief The source register for a move operation. */
RegName from; RegName from;
}; };
struct /* Packed Operands */ { struct /* InvokePacked Operands */ {
/*! \brief The index into the packed function table. */ /*! \brief The index into the packed function table. */
Index packed_index; Index packed_index;
/*! \brief The arity of the packed function. */ /*! \brief The arity of the packed function. */
...@@ -149,7 +152,7 @@ struct Instruction { ...@@ -149,7 +152,7 @@ struct Instruction {
}; };
struct /* LoadConsti Operands */ { struct /* LoadConsti Operands */ {
/* \brief The index into the constant pool. */ /* \brief The index into the constant pool. */
size_t val; Index val;
} load_consti; } load_consti;
struct /* Jump Operands */ { struct /* Jump Operands */ {
/*! \brief The jump offset. */ /*! \brief The jump offset. */
...@@ -284,7 +287,7 @@ struct Instruction { ...@@ -284,7 +287,7 @@ struct Instruction {
* \param dst The destination register. * \param dst The destination register.
* \return The load_constanti instruction. * \return The load_constanti instruction.
*/ */
static Instruction LoadConsti(size_t val, RegName dst); static Instruction LoadConsti(Index val, RegName dst);
/*! \brief Construct a move instruction. /*! \brief Construct a move instruction.
* \param src The source register. * \param src The source register.
* \param dst The destination register. * \param dst The destination register.
...@@ -379,6 +382,8 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -379,6 +382,8 @@ class VirtualMachine : public runtime::ModuleNode {
return "VirtualMachine"; return "VirtualMachine";
} }
/*! \brief The runtime module/library that contains generated code. */
runtime::Module lib;
/*! \brief The virtual machine's packed function table. */ /*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs; std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */ /*! \brief The virtual machine's function table. */
...@@ -448,16 +453,30 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -448,16 +453,30 @@ class VirtualMachine : public runtime::ModuleNode {
void Init(const std::vector<TVMContext>& contexts); void Init(const std::vector<TVMContext>& contexts);
void Run(); void Run();
/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
*/
void LoadParams(const std::string& params);
/*! \brief A map from globals (as strings) to their index in the function map. /*! \brief A map from globals (as strings) to their index in the function map.
*/ */
std::unordered_map<std::string, Index> global_map; std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list.
*/
std::unordered_map<std::string, Index> primitive_map;
private: private:
/*! \brief Invoke a global setting up the VM state to execute. /*! \brief Invoke a global setting up the VM state to execute.
* *
* This does not begin execution of the VM. * This does not begin execution of the VM.
*/ */
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args); void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
}; };
} // namespace vm } // namespace vm
......
...@@ -34,6 +34,8 @@ from . import debug ...@@ -34,6 +34,8 @@ from . import debug
from . import param_dict from . import param_dict
from . import feature from . import feature
from .backend import vm from .backend import vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj from .backend import vmobj
# Root operators # Root operators
......
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine deserializer.
Python interface for deserializing a Relay VM.
"""
from tvm import module
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm
from . import vm as rly_vm
def _create_deserializer(code, lib):
"""Create a deserializer object.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
Returns
-------
ret : Deserializer
The created virtual machine deserializer.
"""
if isinstance(code, (bytes, str)):
code = bytearray(code)
elif not isinstance(code, (bytearray, TVMByteArray)):
raise TypeError("vm is expected to be the type of bytearray or " +
"TVMByteArray, but received {}".format(type(code)))
if not isinstance(lib, module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib)))
return _vm._Deserializer(code, lib)
class Deserializer:
"""Relay VM deserializer.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
"""
def __init__(self, code, lib):
self.mod = _create_deserializer(code, lib)
self._deserialize = self.mod["deserialize"]
def deserialize(self):
"""Deserialize the serialized bytecode into a Relay VM.
Returns
-------
ret : VirtualMachine
The deserialized Relay VM.
"""
return rly_vm.VirtualMachine(self._deserialize())
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine serializer.
Python interface for serializing a Relay VM.
"""
import tvm
from . import _vm
from . import vm as rly_vm
def _create_serializer(vm):
"""Create a VM serializer.
Parameters
----------
vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`]
The virtual machine to be serialized.
Returns
-------
ret : Serializer
The created virtual machine serializer.
"""
if isinstance(vm, rly_vm.VirtualMachine):
vm = vm.module
elif not isinstance(vm, tvm.module.Module):
raise TypeError("vm is expected to be the type of VirtualMachine or " +
"tvm.Module, but received {}".format(type(vm)))
return _vm._Serializer(vm)
class Serializer:
"""Relay VM serializer."""
def __init__(self, vm):
self.mod = _create_serializer(vm)
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_globals = self.mod["get_globals"]
self._get_stats = self.mod["get_stats"]
self._get_primitive_ops = self.mod["get_primitive_ops"]
self._serialize = self.mod["serialize"]
@property
def stats(self):
"""Get the statistics of the Relay VM.
Returns
-------
ret : String
The serialized statistic information.
"""
return self._get_stats()
@property
def primitive_ops(self):
"""Get the name of the primitive ops that are executed in the VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The list of primitive ops.
"""
return [prim_op.value for prim_op in self._get_primitive_ops()]
@property
def bytecode(self):
"""Get the bytecode of the Relay VM.
Returns
-------
ret : String
The serialized bytecode.
Notes
-----
The bytecode is in the following format:
func_name reg_file_size num_instructions
param1 param2 ... paramM
instruction1
instruction2
...
instructionN
Each instruction is printed in the following format:
hash opcode field1 ... fieldX # The text format.
The part starting from # is only used for visualization and debugging.
The real serialized code doesn't contain it, therefore the deserializer
doesn't need to deal with it as well.
"""
return self._get_bytecode()
@property
def globals(self):
"""Get the globals used by the Relay VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The serialized globals.
"""
return [glb.value for glb in self._get_globals()]
def serialize(self):
"""Serialize the Relay VM.
Returns
-------
code : bytearray
The binary blob representing a serialized Relay VM. It can then be
saved to disk and later deserialized into a new VM.
lib : :py:class:`~tvm.module.Module`
The runtime module that contains the generated code. It is
basically a library that is composed of hardware dependent code.
Notes
-----
The returned code is organized with the following sections in order.
- Global section. This section contains the globals used by the
virtual machine.
- Constant section. This section is used to store the constant pool of
a virtual machine.
- Primitive name section. This section is introduced to accommodate
the list of primitive operator names that will be invoked by the
virtual machine.
- Code section. The VM functions, including bytecode, are sitting in
this section.
Examples
--------
.. code-block:: python
import numpy as np
import tvm
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
mod = relay.Module({"main": f})
# create a Relay VM.
ctx = tvm.cpu()
target = "llvm"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(ctx)
# serialize.
ser = relay.serializer.Serializer(vm)
code, lib = ser.serialize()
# save and load the code and lib file.
tmp = tvm.contrib.util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())
# deserialize.
deser = relay.deserializer.Deserializer(loaded_code, loaded_lib)
des_vm = deser.deserialize()
# execute the deserialized vm.
des_vm.init(ctx)
x_data = np.random.rand(10, 10).astype('float32')
res = des_vm.run(x_data)
print(res.asnumpy())
"""
return self._serialize(), self._get_lib()
...@@ -16,13 +16,14 @@ ...@@ -16,13 +16,14 @@
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
""" """
The Relay Virtual Vachine. The Relay Virtual Machine.
Implements a Python interface to compiling and executing on the Relay VM. Implements a Python interface to compiling and executing on the Relay VM.
""" """
import numpy as np import numpy as np
import tvm import tvm
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm from . import _vm
from . import vmobj as _obj from . import vmobj as _obj
from .interpreter import Executor from .interpreter import Executor
...@@ -71,6 +72,7 @@ class VirtualMachine(object): ...@@ -71,6 +72,7 @@ class VirtualMachine(object):
def __init__(self, mod): def __init__(self, mod):
self.mod = mod self.mod = mod
self._init = self.mod["init"] self._init = self.mod["init"]
self._load_params = self.mod["load_params"]
self._invoke = self.mod["invoke"] self._invoke = self.mod["invoke"]
def init(self, ctx): def init(self, ctx):
...@@ -84,6 +86,23 @@ class VirtualMachine(object): ...@@ -84,6 +86,23 @@ class VirtualMachine(object):
args = [ctx.device_type, ctx.device_id] args = [ctx.device_type, ctx.device_id]
self._init(*args) self._init(*args)
def load_params(self, params):
"""Load parameters for the VM.
Parameters
----------
params : Union[bytearray, Dict]
The dictionary that contains serialized parameters.
"""
if isinstance(params, dict):
params = tvm.relay.save_param_dict(params)
elif isinstance(params, (bytes, str)):
params = bytearray(params)
if not isinstance(params, (bytearray, TVMByteArray)):
raise TypeError("params must be a bytearray")
self._load_params(bytearray(params))
def invoke(self, func_name, *args): def invoke(self, func_name, *args):
"""Invoke a function. """Invoke a function.
...@@ -118,6 +137,11 @@ class VirtualMachine(object): ...@@ -118,6 +137,11 @@ class VirtualMachine(object):
""" """
return self.invoke("main", *args) return self.invoke("main", *args)
@property
def module(self):
"""Return the runtime module contained in a virtual machine."""
return self.mod
class VMCompiler(object): class VMCompiler(object):
"""Build Relay module to run on VM runtime.""" """Build Relay module to run on VM runtime."""
......
...@@ -745,7 +745,7 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -745,7 +745,7 @@ class VMCompiler : public runtime::ModuleNode {
} }
#endif // USE_RELAY_DEBUG #endif // USE_RELAY_DEBUG
PopulatePackedFuncMap(); LibraryCodegen();
for (auto gv : context_.global_map) { for (auto gv : context_.global_map) {
vm_->global_map.insert({gv.first->name_hint, gv.second}); vm_->global_map.insert({gv.first->name_hint, gv.second});
...@@ -775,26 +775,28 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -775,26 +775,28 @@ class VMCompiler : public runtime::ModuleNode {
} }
} }
void PopulatePackedFuncMap() { void LibraryCodegen() {
auto const& lowered_funcs = context_.lowered_funcs; auto const& lowered_funcs = context_.lowered_funcs;
if (lowered_funcs.size() == 0) { if (lowered_funcs.size() == 0) {
return; return;
} }
runtime::Module mod;
// TODO(@icemelon9): support heterogeneous targets // TODO(@icemelon9): support heterogeneous targets
Target target; Target target;
for (auto kv : targets_) { for (auto kv : targets_) {
target = kv.second; target = kv.second;
} }
if (const auto* f = runtime::Registry::Get("relay.backend.build")) { if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), runtime::Module mod =
target, target_host_); (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target,
target_host_);
CHECK(mod.operator->());
vm_->lib = mod;
} else { } else {
LOG(FATAL) << "relay.backend.build is not registered"; LOG(FATAL) << "relay.backend.build is not registered";
} }
CHECK(mod.operator->()); size_t primitive_index = 0;
for (auto lfunc : lowered_funcs) { for (auto lfunc : lowered_funcs) {
vm_->packed_funcs.push_back(mod.GetFunction(lfunc->name)); vm_->primitive_map.insert({lfunc->name, primitive_index++});
} }
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/deserializer.h
* \brief Define a deserializer for the serialized Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#include <dmlc/memory_io.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime::vm;
namespace runtime = tvm::runtime;
class Deserializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the deserializer for creating a virtual machine object.
*
* \param code The serialized code.
* \param lib The serialized runtime module/library that contains the
* hardware dependent code.
*/
inline void Init(const std::string& code, const runtime::Module& lib);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Deserializer"; }
/*! \brief Deserialize the serialized VM. */
void Deserialize();
virtual ~Deserializer() { delete strm_; }
private:
/*! \brief Deserialize the globals in `vm_`. */
void DeserializeGlobalSection();
/*! \brief Deserialize the constant pool in `vm_`. */
void DeserializeConstantSection();
/*! \brief Deserialize primitive op names in `vm_`. */
void DeserializePrimitiveOpNames();
/*! \brief Deserialize the vm functions in `vm_`. */
void DeserializeCodeSection();
/*! \brief The code to be serialized. */
std::string code_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The VM to be created. */
std::shared_ptr<VirtualMachine> vm_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serialize_util.h
* \brief Definitions of helpers for serializing and deserializing a Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
#include <dmlc/common.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/vm.h>
#include <functional>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
/*! \brief The magic number for the serialized VM bytecode file */
constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D;
template <typename T>
static inline size_t VectorHash(size_t key, const std::vector<T>& values) {
for (const auto& it : values) {
key = dmlc::HashCombine(key, it);
}
return key;
}
// A struct to hold the funciton info in the code section.
struct VMFunctionSerializer {
/*! \brief The name of the VMFunction. */
std::string name;
/*! \brief The number of registers used by the VMFunction. */
Index register_file_size;
/*! \brief The number of instructions in the VMFunction. */
size_t num_instructions;
/*! \brief The parameters of the VMFunction. */
std::vector<std::string> params;
VMFunctionSerializer() = default;
VMFunctionSerializer(const std::string& name,
Index register_file_size,
size_t num_instructions,
const std::vector<std::string>& params)
: name(name),
register_file_size(register_file_size),
num_instructions(num_instructions),
params(params) {}
/*!
* \brief Load the serialized function header.
* \param strm The stream used to load data.
* \return True if successful. Otherwise, false.
*/
bool Load(dmlc::Stream* strm) {
std::vector<std::string> func_info;
if (!strm->Read(&func_info)) return false;
CHECK_EQ(func_info.size(), 3U) << "Failed to decode the vm function."
<< "\n";
name = func_info[0];
register_file_size = std::stoll(func_info[1]);
// Get the number of instructions.
num_instructions = static_cast<size_t>(std::stoll(func_info[2]));
return strm->Read(&params);
}
/*!
* \brief Save the VM function header into the serialized form.
* \param strm The stream used to save data.
*/
void Save(dmlc::Stream* strm) const {
std::vector<std::string> func_info;
func_info.push_back(name);
func_info.push_back(std::to_string(register_file_size));
func_info.push_back(std::to_string(num_instructions));
strm->Write(func_info);
strm->Write(params);
}
};
struct VMInstructionSerializer {
/*! \brief The opcode of the instruction. */
Index opcode;
/*! \brief The fields of the instruction. */
std::vector<Index> fields;
VMInstructionSerializer() = default;
VMInstructionSerializer(Index opcode, const std::vector<Index>& fields) :
opcode(opcode), fields(fields) {}
/*!
* \brief Compute the hash of the serialized instruction.
* \return The hash that combines the opcode and all fields of the VM
* instruction.
*/
Index Hash() const {
size_t key = static_cast<size_t>(opcode);
key = VectorHash(key, fields);
return key;
}
/*!
* \brief Load the serialized instruction.
* \param strm The stream used to load data.
* \return True if successful. Otherwise, false.
*/
bool Load(dmlc::Stream* strm) {
std::vector<Index> instr;
if (!strm->Read(&instr)) return false;
CHECK_GE(instr.size(), 2U);
Index loaded_hash = instr[0];
opcode = instr[1];
for (size_t i = 2; i < instr.size(); i++) {
fields.push_back(instr[i]);
}
Index hash = Hash();
CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: "
<< opcode << "\n";
return true;
}
/*!
* \brief Save the instruction into the serialized form.
* \param strm The stream used to save data.
*/
void Save(dmlc::Stream* strm) const {
Index hash = Hash();
std::vector<Index> serialized({hash, opcode});
serialized.insert(serialized.end(), fields.begin(), fields.end());
strm->Write(serialized);
}
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serializer.h
* \brief Define a serializer for the Relay VM.
*
* The following components of a Relay VM will be serialized:
* - The `constants`, e.g., the constant pool, that contains the
* constants used in a Relay program.
* - The `packed_funcs` that essentially contains the generated code for
* a specific target. We return it as a runtime module that can be exported as
* a library file (e.g., .so, .o, or .tar).
* - The `global_map` that contains the globals.
* - The `primitive_map` that contains the name of individual primitive operators.
* - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of
* a list of instructions/bytecode.
*
* Note that only the library is returned as a separate module. All othere parts
* are stored in a single serialized code that is organized with the following
* sections in order.
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*
* The code section is again organized as follows for each VM function:
* func_name, register_file_size, num_instructions (N)
* param1, param2, ..., paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/ir.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
/*!
* \brief The Relay VM serializer.
*/
class Serializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the serializer for a virtual machine.
*
* \param vm The Relay virtual machine.
*/
inline void Init(const VirtualMachine* vm);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Serializer"; }
/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
*/
std::string Stats() const;
/*!
* \brief Serialize the `vm_` into global section, constant section, and code
* section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Serialize();
/*!
* \brief Get a list of the globals used by the `_vm`.
*
* \return The global map in the form a list.
*/
tvm::Array<tvm::Expr> GetGlobals() const;
/*!
* \brief Get the primitive operators that are contained in the Relay VM.
*
* \return The list of primitve operators.
*/
tvm::Array<tvm::Expr> GetPrimitiveOps() const;
/*!
* \brief Get the serialized form of the `functions` in `vm_`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
std::string GetBytecode() const;
/*! \brief Get the `lib` module in vm_. Serialization of `runtime::module`
* has already been supported by TVM. Therefore, we only return the runtime
* module and let users have the flexibility to call `export_library` from
* the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
*/
inline runtime::Module GetLib() const;
virtual ~Serializer() { delete strm_; }
private:
/*! \brief Serialize the globals in vm_. */
void SerializeGlobalSection();
/*! \brief Serialize the constant pool in vm_. */
void SerializeConstantSection();
/*! \brief Serialize primitive op names in vm_. */
void SerializePrimitiveOpNames();
/*! \brief Serialize the vm functions in vm_. */
void SerializeCodeSection();
/*! \brief The Relay virtual machine for to be serialized. */
const VirtualMachine* vm_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The serialized code. */
std::string code_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment