Commit 90455121 by Zhi Committed by Jared Roesch

[Relay][VM] Relay VM serialization (#3647)

* relay vm serialization

* fix lint

* load params, fix stream

* lint

* fix typo
parent 0365e50a
......@@ -204,7 +204,7 @@ InvokeClosure
**Arguments**:
::
RegName closure
size_t closure_args_num
size_t num_closure_args
RegName* closure_args
Invokes `closure`, consuming the number of arguments declared in the closure's VMFunction.
......
......@@ -36,6 +36,9 @@ namespace tvm {
namespace runtime {
namespace vm {
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
/*! \brief A register name. */
using RegName = int64_t;
......@@ -103,7 +106,7 @@ struct Instruction {
/*! \brief The register containing the closure. */
RegName closure;
/*! \brief The number of arguments to the closure. */
Index closure_args_num;
Index num_closure_args;
/*! \brief The closure arguments as an array. */
RegName* closure_args;
};
......@@ -115,7 +118,7 @@ struct Instruction {
/*! \brief The source register for a move operation. */
RegName from;
};
struct /* Packed Operands */ {
struct /* InvokePacked Operands */ {
/*! \brief The index into the packed function table. */
Index packed_index;
/*! \brief The arity of the packed function. */
......@@ -149,7 +152,7 @@ struct Instruction {
};
struct /* LoadConsti Operands */ {
/* \brief The index into the constant pool. */
size_t val;
Index val;
} load_consti;
struct /* Jump Operands */ {
/*! \brief The jump offset. */
......@@ -284,7 +287,7 @@ struct Instruction {
* \param dst The destination register.
* \return The load_constanti instruction.
*/
static Instruction LoadConsti(size_t val, RegName dst);
static Instruction LoadConsti(Index val, RegName dst);
/*! \brief Construct a move instruction.
* \param src The source register.
* \param dst The destination register.
......@@ -379,6 +382,8 @@ class VirtualMachine : public runtime::ModuleNode {
return "VirtualMachine";
}
/*! \brief The runtime module/library that contains generated code. */
runtime::Module lib;
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
......@@ -448,16 +453,30 @@ class VirtualMachine : public runtime::ModuleNode {
void Init(const std::vector<TVMContext>& contexts);
void Run();
/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
*/
void LoadParams(const std::string& params);
/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map;
/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list.
*/
std::unordered_map<std::string, Index> primitive_map;
private:
/*! \brief Invoke a global setting up the VM state to execute.
*
* This does not begin execution of the VM.
*/
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
};
} // namespace vm
......
......@@ -34,6 +34,8 @@ from . import debug
from . import param_dict
from . import feature
from .backend import vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj
# Root operators
......
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine deserializer.
Python interface for deserializing a Relay VM.
"""
from tvm import module
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm
from . import vm as rly_vm
def _create_deserializer(code, lib):
"""Create a deserializer object.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
Returns
-------
ret : Deserializer
The created virtual machine deserializer.
"""
if isinstance(code, (bytes, str)):
code = bytearray(code)
elif not isinstance(code, (bytearray, TVMByteArray)):
raise TypeError("vm is expected to be the type of bytearray or " +
"TVMByteArray, but received {}".format(type(code)))
if not isinstance(lib, module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib)))
return _vm._Deserializer(code, lib)
class Deserializer:
"""Relay VM deserializer.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
"""
def __init__(self, code, lib):
self.mod = _create_deserializer(code, lib)
self._deserialize = self.mod["deserialize"]
def deserialize(self):
"""Deserialize the serialized bytecode into a Relay VM.
Returns
-------
ret : VirtualMachine
The deserialized Relay VM.
"""
return rly_vm.VirtualMachine(self._deserialize())
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine serializer.
Python interface for serializing a Relay VM.
"""
import tvm
from . import _vm
from . import vm as rly_vm
def _create_serializer(vm):
"""Create a VM serializer.
Parameters
----------
vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`]
The virtual machine to be serialized.
Returns
-------
ret : Serializer
The created virtual machine serializer.
"""
if isinstance(vm, rly_vm.VirtualMachine):
vm = vm.module
elif not isinstance(vm, tvm.module.Module):
raise TypeError("vm is expected to be the type of VirtualMachine or " +
"tvm.Module, but received {}".format(type(vm)))
return _vm._Serializer(vm)
class Serializer:
"""Relay VM serializer."""
def __init__(self, vm):
self.mod = _create_serializer(vm)
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_globals = self.mod["get_globals"]
self._get_stats = self.mod["get_stats"]
self._get_primitive_ops = self.mod["get_primitive_ops"]
self._serialize = self.mod["serialize"]
@property
def stats(self):
"""Get the statistics of the Relay VM.
Returns
-------
ret : String
The serialized statistic information.
"""
return self._get_stats()
@property
def primitive_ops(self):
"""Get the name of the primitive ops that are executed in the VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The list of primitive ops.
"""
return [prim_op.value for prim_op in self._get_primitive_ops()]
@property
def bytecode(self):
"""Get the bytecode of the Relay VM.
Returns
-------
ret : String
The serialized bytecode.
Notes
-----
The bytecode is in the following format:
func_name reg_file_size num_instructions
param1 param2 ... paramM
instruction1
instruction2
...
instructionN
Each instruction is printed in the following format:
hash opcode field1 ... fieldX # The text format.
The part starting from # is only used for visualization and debugging.
The real serialized code doesn't contain it, therefore the deserializer
doesn't need to deal with it as well.
"""
return self._get_bytecode()
@property
def globals(self):
"""Get the globals used by the Relay VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The serialized globals.
"""
return [glb.value for glb in self._get_globals()]
def serialize(self):
"""Serialize the Relay VM.
Returns
-------
code : bytearray
The binary blob representing a serialized Relay VM. It can then be
saved to disk and later deserialized into a new VM.
lib : :py:class:`~tvm.module.Module`
The runtime module that contains the generated code. It is
basically a library that is composed of hardware dependent code.
Notes
-----
The returned code is organized with the following sections in order.
- Global section. This section contains the globals used by the
virtual machine.
- Constant section. This section is used to store the constant pool of
a virtual machine.
- Primitive name section. This section is introduced to accommodate
the list of primitive operator names that will be invoked by the
virtual machine.
- Code section. The VM functions, including bytecode, are sitting in
this section.
Examples
--------
.. code-block:: python
import numpy as np
import tvm
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
mod = relay.Module({"main": f})
# create a Relay VM.
ctx = tvm.cpu()
target = "llvm"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(ctx)
# serialize.
ser = relay.serializer.Serializer(vm)
code, lib = ser.serialize()
# save and load the code and lib file.
tmp = tvm.contrib.util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())
# deserialize.
deser = relay.deserializer.Deserializer(loaded_code, loaded_lib)
des_vm = deser.deserialize()
# execute the deserialized vm.
des_vm.init(ctx)
x_data = np.random.rand(10, 10).astype('float32')
res = des_vm.run(x_data)
print(res.asnumpy())
"""
return self._serialize(), self._get_lib()
......@@ -16,13 +16,14 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
"""
The Relay Virtual Vachine.
The Relay Virtual Machine.
Implements a Python interface to compiling and executing on the Relay VM.
"""
import numpy as np
import tvm
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
......@@ -71,6 +72,7 @@ class VirtualMachine(object):
def __init__(self, mod):
self.mod = mod
self._init = self.mod["init"]
self._load_params = self.mod["load_params"]
self._invoke = self.mod["invoke"]
def init(self, ctx):
......@@ -84,6 +86,23 @@ class VirtualMachine(object):
args = [ctx.device_type, ctx.device_id]
self._init(*args)
def load_params(self, params):
"""Load parameters for the VM.
Parameters
----------
params : Union[bytearray, Dict]
The dictionary that contains serialized parameters.
"""
if isinstance(params, dict):
params = tvm.relay.save_param_dict(params)
elif isinstance(params, (bytes, str)):
params = bytearray(params)
if not isinstance(params, (bytearray, TVMByteArray)):
raise TypeError("params must be a bytearray")
self._load_params(bytearray(params))
def invoke(self, func_name, *args):
"""Invoke a function.
......@@ -118,6 +137,11 @@ class VirtualMachine(object):
"""
return self.invoke("main", *args)
@property
def module(self):
"""Return the runtime module contained in a virtual machine."""
return self.mod
class VMCompiler(object):
"""Build Relay module to run on VM runtime."""
......
......@@ -745,7 +745,7 @@ class VMCompiler : public runtime::ModuleNode {
}
#endif // USE_RELAY_DEBUG
PopulatePackedFuncMap();
LibraryCodegen();
for (auto gv : context_.global_map) {
vm_->global_map.insert({gv.first->name_hint, gv.second});
......@@ -775,26 +775,28 @@ class VMCompiler : public runtime::ModuleNode {
}
}
void PopulatePackedFuncMap() {
void LibraryCodegen() {
auto const& lowered_funcs = context_.lowered_funcs;
if (lowered_funcs.size() == 0) {
return;
}
runtime::Module mod;
// TODO(@icemelon9): support heterogeneous targets
Target target;
for (auto kv : targets_) {
target = kv.second;
}
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()),
target, target_host_);
runtime::Module mod =
(*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target,
target_host_);
CHECK(mod.operator->());
vm_->lib = mod;
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
CHECK(mod.operator->());
size_t primitive_index = 0;
for (auto lfunc : lowered_funcs) {
vm_->packed_funcs.push_back(mod.GetFunction(lfunc->name));
vm_->primitive_map.insert({lfunc->name, primitive_index++});
}
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/deserializer.cc
* \brief Implementation of APIs to deserialize the serialized VM bytecode.
*/
#include "deserializer.h"
#include <tvm/runtime/registry.h>
#include <memory>
#include <sstream>
#include "serialize_util.h"
namespace tvm {
namespace relay {
namespace vm {
#define STREAM_CHECK(val, section) \
CHECK(val) << "Invalid VM file format in the " << section << " section." \
<< "\n";
void Deserializer::Init(const std::string& code, const runtime::Module& lib) {
code_ = code;
vm_ = std::make_shared<VirtualMachine>();
vm_->lib = lib;
strm_ = new dmlc::MemoryStringStream(&code_);
}
runtime::PackedFunc Deserializer::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "deserialize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Deserialize();
*rv = runtime::Module(vm_);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
void Deserializer::Deserialize() {
// Check header.
uint64_t header;
STREAM_CHECK(strm_->Read(&header), "header");
STREAM_CHECK(header == kTVMVMBytecodeMagic, "header");
// Check version.
std::string version;
STREAM_CHECK(strm_->Read(&version), "version");
STREAM_CHECK(version == TVM_VERSION, "version");
// Global section.
DeserializeGlobalSection();
// Constant section.
DeserializeConstantSection();
// Primitive names that will be invoked by `InvokePacked` instructions.
DeserializePrimitiveOpNames();
// Code section.
DeserializeCodeSection();
}
void Deserializer::DeserializeGlobalSection() {
std::vector<std::string> globals;
STREAM_CHECK(strm_->Read(&globals), "global");
for (size_t i = 0; i < globals.size(); i++) {
vm_->global_map.insert({globals[i], i});
}
}
void Deserializer::DeserializeConstantSection() {
uint64_t sz;
// Load the number of constants.
STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant");
size_t size = static_cast<size_t>(sz);
// Load each of the constants.
for (size_t i = 0; i < size; i++) {
runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm_), "constant");
runtime::Object obj = runtime::Object::Tensor(constant);
vm_->constants.push_back(obj);
}
}
void Deserializer::DeserializePrimitiveOpNames() {
std::vector<std::string> primitive_names;
STREAM_CHECK(strm_->Read(&primitive_names), "primitive name");
for (size_t i = 0; i < primitive_names.size(); i++) {
vm_->primitive_map.insert({primitive_names[i], i});
}
}
// Extract the `cnt` number of fields started at `start` from the list
// `instr_fields`.
inline std::vector<Index> ExtractFields(const std::vector<Index>& instr_fields,
Index start,
Index cnt) {
CHECK_LE(static_cast<size_t>(start + cnt), instr_fields.size());
std::vector<Index> ret;
for (auto i = start; i < start + cnt; i++) {
ret.push_back(instr_fields[i]);
}
return ret;
}
Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
Opcode opcode = static_cast<Opcode>(instr.opcode);
switch (opcode) {
case Opcode::Move: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::Move(instr.fields[0], instr.fields[1]);
}
case Opcode::Ret: {
// Number of fields = 1
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Ret(instr.fields[0]);
}
case Opcode::Fatal: {
// Number of fields = 0
DCHECK(instr.fields.empty());
return Instruction::Fatal();
}
case Opcode::InvokePacked: {
// Number of fields = 3 + instr.arity
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index packed_index = instr.fields[0];
Index arity = instr.fields[1];
Index output_size = instr.fields[2];
std::vector<RegName> args = ExtractFields(instr.fields, 3, arity);
return Instruction::InvokePacked(packed_index, arity, output_size, args);
}
case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim
DCHECK_GE(instr.fields.size(), 5U);
DCHECK_EQ(instr.fields.size(), 5U + static_cast<size_t>(instr.fields[3]));
DLDataType dtype;
dtype.code = instr.fields[0];
dtype.bits = instr.fields[1];
dtype.lanes = instr.fields[2];
Index ndim = instr.fields[3];
RegName dst = instr.fields[4];
std::vector<Index> shape = ExtractFields(instr.fields, 5, ndim);
return Instruction::AllocTensor(shape, dtype, dst);
}
case Opcode::AllocTensorReg: {
// Number of fields = 5
DCHECK_EQ(instr.fields.size(), 5U);
Index shape_register = instr.fields[0];
DLDataType dtype;
dtype.code = instr.fields[1];
dtype.bits = instr.fields[2];
dtype.lanes = instr.fields[3];
RegName dst = instr.fields[4];
return Instruction::AllocTensorReg(shape_register, dtype, dst);
}
case Opcode::AllocDatatype: {
// Number of fields = 3 + instr.num_fields
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index constructor_tag = instr.fields[0];
Index num_fields = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields);
return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst);
}
case Opcode::AllocClosure: {
// Number of fields = 3 + instr.num_freevar
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index clo_index = instr.fields[0];
Index num_freevar = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> free_vars = ExtractFields(instr.fields, 3, num_freevar);
return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst);
}
case Opcode::If: {
// Number of fields = 4
DCHECK_EQ(instr.fields.size(), 4U);
Index test = instr.fields[0];
Index target = instr.fields[1];
Index true_offset = instr.fields[2];
Index false_offset = instr.fields[3];
return Instruction::If(test, target, true_offset, false_offset);
}
case Opcode::Invoke: {
// Number of fields = 3 + instr.num_args
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index func_index = instr.fields[0];
Index num_args = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> args = ExtractFields(instr.fields, 3, num_args);
return Instruction::Invoke(func_index, args, dst);
}
case Opcode::InvokeClosure: {
// Number of fields = 3 + instr.num_closure_args
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Index closure = instr.fields[0];
Index num_closure_args = instr.fields[1];
RegName dst = instr.fields[2];
std::vector<Index> args = ExtractFields(instr.fields, 3, num_closure_args);
return Instruction::InvokeClosure(closure, args, dst);
}
case Opcode::LoadConst: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::LoadConst(instr.fields[0], instr.fields[1]);
}
case Opcode::LoadConsti: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::LoadConsti(instr.fields[0], instr.fields[1]);
}
case Opcode::GetField: {
// Number of fields = 3
DCHECK_EQ(instr.fields.size(), 3U);
return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]);
}
case Opcode::GetTag: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::GetTag(instr.fields[0], instr.fields[1]);
}
case Opcode::Goto: {
// Number of fields = 1
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Goto(instr.fields[0]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
}
}
void Deserializer::DeserializeCodeSection() {
// Load the number of functions.
uint64_t sz;
STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code");
size_t num_funcs = static_cast<size_t>(sz);
vm_->functions.resize(num_funcs);
for (size_t i = 0; i < num_funcs; i++) {
// Load the function info.
VMFunctionSerializer loaded_func;
STREAM_CHECK(loaded_func.Load(strm_), "code/function");
// Load the instructions.
std::vector<Instruction> instructions;
for (size_t j = 0; j < loaded_func.num_instructions; j++) {
VMInstructionSerializer instr;
std::vector<Index> instr_fields;
STREAM_CHECK(instr.Load(strm_), "code/instruction");
instructions.push_back(DeserializeInstruction(instr));
}
// Create the VM function.
VMFunction vm_func = VMFunction(loaded_func.name,
loaded_func.params,
instructions,
loaded_func.register_file_size);
auto it = vm_->global_map.find(loaded_func.name);
CHECK(it != vm_->global_map.end());
CHECK_LE(it->second, vm_->global_map.size());
vm_->functions[it->second] = vm_func;
}
}
runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Deserializer> exec = std::make_shared<Deserializer>();
exec->Init(code, lib);
return runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("relay._vm._Deserializer")
.set_body_typed(CreateDeserializer);
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/deserializer.h
* \brief Define a deserializer for the serialized Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
#include <dmlc/memory_io.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime::vm;
namespace runtime = tvm::runtime;
class Deserializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the deserializer for creating a virtual machine object.
*
* \param code The serialized code.
* \param lib The serialized runtime module/library that contains the
* hardware dependent code.
*/
inline void Init(const std::string& code, const runtime::Module& lib);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Deserializer"; }
/*! \brief Deserialize the serialized VM. */
void Deserialize();
virtual ~Deserializer() { delete strm_; }
private:
/*! \brief Deserialize the globals in `vm_`. */
void DeserializeGlobalSection();
/*! \brief Deserialize the constant pool in `vm_`. */
void DeserializeConstantSection();
/*! \brief Deserialize primitive op names in `vm_`. */
void DeserializePrimitiveOpNames();
/*! \brief Deserialize the vm functions in `vm_`. */
void DeserializeCodeSection();
/*! \brief The code to be serialized. */
std::string code_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The VM to be created. */
std::shared_ptr<VirtualMachine> vm_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serialize_util.h
* \brief Definitions of helpers for serializing and deserializing a Relay VM.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
#include <dmlc/common.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/vm.h>
#include <functional>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
/*! \brief The magic number for the serialized VM bytecode file */
constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D;
template <typename T>
static inline size_t VectorHash(size_t key, const std::vector<T>& values) {
for (const auto& it : values) {
key = dmlc::HashCombine(key, it);
}
return key;
}
// A struct to hold the funciton info in the code section.
struct VMFunctionSerializer {
/*! \brief The name of the VMFunction. */
std::string name;
/*! \brief The number of registers used by the VMFunction. */
Index register_file_size;
/*! \brief The number of instructions in the VMFunction. */
size_t num_instructions;
/*! \brief The parameters of the VMFunction. */
std::vector<std::string> params;
VMFunctionSerializer() = default;
VMFunctionSerializer(const std::string& name,
Index register_file_size,
size_t num_instructions,
const std::vector<std::string>& params)
: name(name),
register_file_size(register_file_size),
num_instructions(num_instructions),
params(params) {}
/*!
* \brief Load the serialized function header.
* \param strm The stream used to load data.
* \return True if successful. Otherwise, false.
*/
bool Load(dmlc::Stream* strm) {
std::vector<std::string> func_info;
if (!strm->Read(&func_info)) return false;
CHECK_EQ(func_info.size(), 3U) << "Failed to decode the vm function."
<< "\n";
name = func_info[0];
register_file_size = std::stoll(func_info[1]);
// Get the number of instructions.
num_instructions = static_cast<size_t>(std::stoll(func_info[2]));
return strm->Read(&params);
}
/*!
* \brief Save the VM function header into the serialized form.
* \param strm The stream used to save data.
*/
void Save(dmlc::Stream* strm) const {
std::vector<std::string> func_info;
func_info.push_back(name);
func_info.push_back(std::to_string(register_file_size));
func_info.push_back(std::to_string(num_instructions));
strm->Write(func_info);
strm->Write(params);
}
};
struct VMInstructionSerializer {
/*! \brief The opcode of the instruction. */
Index opcode;
/*! \brief The fields of the instruction. */
std::vector<Index> fields;
VMInstructionSerializer() = default;
VMInstructionSerializer(Index opcode, const std::vector<Index>& fields) :
opcode(opcode), fields(fields) {}
/*!
* \brief Compute the hash of the serialized instruction.
* \return The hash that combines the opcode and all fields of the VM
* instruction.
*/
Index Hash() const {
size_t key = static_cast<size_t>(opcode);
key = VectorHash(key, fields);
return key;
}
/*!
* \brief Load the serialized instruction.
* \param strm The stream used to load data.
* \return True if successful. Otherwise, false.
*/
bool Load(dmlc::Stream* strm) {
std::vector<Index> instr;
if (!strm->Read(&instr)) return false;
CHECK_GE(instr.size(), 2U);
Index loaded_hash = instr[0];
opcode = instr[1];
for (size_t i = 2; i < instr.size(); i++) {
fields.push_back(instr[i]);
}
Index hash = Hash();
CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: "
<< opcode << "\n";
return true;
}
/*!
* \brief Save the instruction into the serialized form.
* \param strm The stream used to save data.
*/
void Save(dmlc::Stream* strm) const {
Index hash = Hash();
std::vector<Index> serialized({hash, opcode});
serialized.insert(serialized.end(), fields.begin(), fields.end());
strm->Write(serialized);
}
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serializer.cc
* \brief Implementation of serializing APIs for the Relay VM.
*/
#include "serializer.h"
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <algorithm>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "serialize_util.h"
namespace tvm {
namespace relay {
namespace vm {
void Serializer::Init(const VirtualMachine* vm) {
vm_ = vm;
// Initialize the stream object.
strm_ = new dmlc::MemoryStringStream(&code_);
}
runtime::PackedFunc Serializer::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "get_lib") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetLib();
});
} else if (name == "get_primitive_ops") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetPrimitiveOps();
});
} else if (name == "get_bytecode") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetBytecode();
});
} else if (name == "get_globals") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetGlobals();
});
} else if (name == "get_stats") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Stats();
});
} else if (name == "serialize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Serialize();
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
tvm::Array<tvm::Expr> Serializer::GetPrimitiveOps() const {
std::vector<tvm::Expr> ret;
for (const auto& it : vm_->primitive_map) {
auto packed_name = tvm::ir::StringImm::make(it.first);
auto packed_index = static_cast<size_t>(it.second);
if (ret.size() <= packed_index) {
ret.resize(packed_index + 1);
}
ret[packed_index] = packed_name;
}
return ret;
}
std::string Serializer::Stats() const {
std::ostringstream oss;
oss << "Relay VM statistics:" << std::endl;
// Get the number of constants and the shape of each of them.
oss << " Constant shapes (# " << vm_->constants.size() << "): [";
for (const auto& it : vm_->constants) {
auto cell = it.AsTensor();
CHECK(cell.operator->());
runtime::NDArray data = cell->data;
const auto& shape = data.Shape();
// Scalar
if (shape.empty()) {
oss << "scalar, ";
continue;
}
oss << "[";
for (auto s : shape) {
oss << s << ", ";
}
oss.seekp(-2, oss.cur);
oss << "], " << std::endl;
}
if (!vm_->constants.empty()) oss.seekp(-2, oss.cur);
oss << "]" << std::endl;
// Get the number of globals and the name of each of them.
oss << " Globals (#" << vm_->global_map.size() << "): [";
for (const auto& it : vm_->global_map) {
oss << "(\"" << it.first << "\", " << it.second << ")" << ", ";
}
if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur);
oss << "]" << std::endl;
// Get the number of primitive ops and the name of each of them.
oss << " Primitive ops (#" << vm_->primitive_map.size() << "): [";
const auto& prim_ops = GetPrimitiveOps();
for (const auto& it : prim_ops) {
oss << it << ", ";
}
if (!prim_ops.empty()) oss.seekp(-2, oss.cur);
oss << "]" << std::endl;
return oss.str();
}
TVMByteArray Serializer::Serialize() {
uint64_t header = kTVMVMBytecodeMagic;
strm_->Write(header);
std::string version = TVM_VERSION;
strm_->Write(version);
// Global section.
SerializeGlobalSection();
// Constant section.
SerializeConstantSection();
// Primitive names.
SerializePrimitiveOpNames();
// Code section.
SerializeCodeSection();
TVMByteArray arr;
arr.data = code_.c_str();
arr.size = code_.length();
return arr;
}
void Serializer::SerializeGlobalSection() {
auto globals = GetGlobals();
std::vector<std::string> glbs;
for (const auto& it : globals) {
glbs.push_back(it.as<tvm::ir::StringImm>()->value);
}
strm_->Write(glbs);
}
void Serializer::SerializeConstantSection() {
std::vector<DLTensor*> arrays;
for (const auto& obj : vm_->constants) {
auto cell = obj.AsTensor();
runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->()));
}
strm_->Write(static_cast<uint64_t>(vm_->constants.size()));
for (const auto& it : arrays) {
runtime::SaveDLTensor(strm_, it);
}
}
void Serializer::SerializePrimitiveOpNames() {
auto names = GetPrimitiveOps();
std::vector<std::string> primitive_names;
for (const auto& it : names) {
primitive_names.push_back(it.as<tvm::ir::StringImm>()->value);
}
strm_->Write(primitive_names);
}
// Serialize a virtual machine instruction. It creates a list that contains the
// hash, opcode, and all fields of an instruction.
//
// For example, the function signature used to create an `AllocTensor`
// instruction is:
// Instruction AllocTensor(std::vector<Index> shape, DLDataType dtype, RegName dst)
//
// The serialized form will be:
// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn`
//
// where hash is the hash of serialized instruction that is computed internally
// by the `VMInstructionSerializer`. It is used for sanity check before decoding.
// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)`
// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register`
// is the destination register, and the rest of it together indicates the shape
// of the tensor to be allocated.
VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
std::vector<Index> fields;
// Save the opcode.
DLOG(INFO) << "Serializing: " << instr << std::endl;
switch (instr.op) {
case Opcode::Move: {
// Number of fields = 2
fields.assign({instr.from, instr.dst});
break;
}
case Opcode::Ret: {
// Number of fields = 1
fields.push_back(instr.result);
break;
}
case Opcode::Fatal: {
// Number of fields = 0
break;
}
case Opcode::InvokePacked: {
// Number of fields = 3 + instr.arity
// Note that arity includes both input arguments and outputs. We will
// put all the `arity` number of fields in the end for serialization.
fields.assign({instr.packed_index, instr.arity, instr.output_size});
// Save the args.
fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity);
break;
}
case Opcode::AllocTensor: {
// Number of fields = 5 + instr.alloc_tensor.ndim
// Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype;
fields.assign({dtype.code, dtype.bits, dtype.lanes});
// The number of dimensions is not needed for constructing an
// `AllocTensor` instruction as it equals to the length of the `shape`
// vector. However, we save it to conveniently deserialize the instruction
// because we will know how many fields are needed by the `shape` argument.
fields.push_back(instr.alloc_tensor.ndim);
fields.push_back(instr.dst);
// Save the shape of the tensor.
// Note that this field is rotated to the end of the list.
fields.insert(fields.end(), instr.alloc_tensor.shape,
instr.alloc_tensor.shape + instr.alloc_tensor.ndim);
break;
}
case Opcode::AllocTensorReg: {
// Number of fields = 5
fields.push_back(instr.alloc_tensor_reg.shape_register);
// Save `DLDataType` and the dst register.
const auto& dtype = instr.alloc_tensor.dtype;
fields.assign({dtype.code, dtype.bits, dtype.lanes});
fields.push_back(instr.dst);
break;
}
case Opcode::AllocDatatype: {
// Number of fields = 3 + instr.num_fields
fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});
// Save the fields.
fields.insert(fields.end(), instr.datatype_fields,
instr.datatype_fields + instr.num_fields);
break;
}
case Opcode::AllocClosure: {
// Number of fields = 3 + instr.num_freevar
fields.assign({instr.clo_index, instr.num_freevar, instr.dst});
// Save the free vars.
fields.insert(fields.end(), instr.free_vars,
instr.free_vars + instr.num_freevar);
break;
}
case Opcode::If: {
// Number of fields = 4
fields.assign({instr.if_op.test,
instr.if_op.target,
instr.if_op.true_offset,
instr.if_op.false_offset});
break;
}
case Opcode::Invoke: {
// Number of fields = 3 + instr.num_args
fields.assign({instr.func_index, instr.num_args, instr.dst});
// Save the args.
fields.insert(fields.end(), instr.invoke_args_registers,
instr.invoke_args_registers + instr.num_args);
break;
}
case Opcode::InvokeClosure: {
// Number of fields = 3 + instr.num_closure_args
fields.assign({instr.closure, instr.num_closure_args, instr.dst});
// Save the args.
fields.insert(fields.end(), instr.closure_args,
instr.closure_args + instr.num_closure_args);
break;
}
case Opcode::LoadConst: {
// Number of fields = 2
fields.assign({instr.const_index, instr.dst});
break;
}
case Opcode::LoadConsti: {
// Number of fields = 2
fields.assign({instr.load_consti.val, instr.dst});
break;
}
case Opcode::GetField: {
// Number of fields = 3
fields.assign({instr.object, instr.field_index, instr.dst});
break;
}
case Opcode::GetTag: {
// Number of fields = 2
fields.assign({instr.get_tag.object, instr.dst});
break;
}
case Opcode::Goto: {
// Number of fields = 1
fields.push_back(instr.pc_offset);
break;
}
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
}
return VMInstructionSerializer(static_cast<Index>(instr.op), fields);
}
void Serializer::SerializeCodeSection() {
// Save the number of functions.
strm_->Write(static_cast<uint64_t>(vm_->functions.size()));
for (const auto& func : vm_->functions) {
// Serialize the function info.
VMFunctionSerializer func_format(func.name,
func.register_file_size,
func.instructions.size(),
func.params);
func_format.Save(strm_);
// Serialize each instruction.
for (const auto& instr : func.instructions) {
const auto& serialized_instr = SerializeInstruction(instr);
serialized_instr.Save(strm_);
}
}
}
tvm::Array<tvm::Expr> Serializer::GetGlobals() const {
tvm::Array<tvm::Expr> ret;
std::vector<std::pair<std::string, Index> > globals(vm_->global_map.begin(),
vm_->global_map.end());
auto comp = [](const std::pair<std::string, Index>& a,
const std::pair<std::string, Index>& b) {
return a.second < b.second;
};
std::sort(globals.begin(), globals.end(), comp);
for (const auto& it : globals) {
ret.push_back(tvm::ir::StringImm::make(it.first));
}
return ret;
}
std::string Serializer::GetBytecode() const {
std::ostringstream oss;
for (const auto& func : vm_->functions) {
// Print the header of the function format.
oss << "# func name, reg file size, param count, inst count:"
<< std::endl;
oss << func.name << " "
<< func.register_file_size << " "
<< func.params.size() << " "
<< func.instructions.size() << std::endl;
// Print pramams of a `VMFunction`.
oss << "# Parameters:"<< std::endl;
for (const auto& param : func.params) {
oss << param << " ";
}
oss << std::endl;
// Print the instructions of a `VMFunction`.
// The part after ";" is the instruction in text format.
oss << "hash, opcode, fields # inst(text):"<< std::endl;
for (const auto& instr : func.instructions) {
const auto& serialized_instr = SerializeInstruction(instr);
oss << std::hex << "0x" << serialized_instr.Hash() << " "
<< std::dec << serialized_instr.opcode << " ";
for (auto it : serialized_instr.fields) {
oss << it << " ";
}
oss << " # " << instr;
if (oss.str().back() != '\n') oss << std::endl;
}
}
return oss.str();
}
runtime::Module Serializer::GetLib() const {
return vm_->lib;
}
runtime::Module CreateSerializer(const VirtualMachine* vm) {
std::shared_ptr<Serializer> exec = std::make_shared<Serializer>();
exec->Init(vm);
return runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("relay._vm._Serializer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* vm = dynamic_cast<VirtualMachine*>(mod.operator->());
CHECK(vm) << "Virtual machine has not been defined yet."
<< "\n";
*rv = CreateSerializer(vm);
});
} // namespace vm
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/serializer.h
* \brief Define a serializer for the Relay VM.
*
* The following components of a Relay VM will be serialized:
* - The `constants`, e.g., the constant pool, that contains the
* constants used in a Relay program.
* - The `packed_funcs` that essentially contains the generated code for
* a specific target. We return it as a runtime module that can be exported as
* a library file (e.g., .so, .o, or .tar).
* - The `global_map` that contains the globals.
* - The `primitive_map` that contains the name of individual primitive operators.
* - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of
* a list of instructions/bytecode.
*
* Note that only the library is returned as a separate module. All othere parts
* are stored in a single serialized code that is organized with the following
* sections in order.
* - Global section, containing all globals.
* - Constant section, storing the constant pool.
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
*
* The code section is again organized as follows for each VM function:
* func_name, register_file_size, num_instructions (N)
* param1, param2, ..., paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Serializing an `Instruction` requires us to deal with the bytecode. Each line
* of the instructions could be serialized as the following format:
* hash, opcode, f1, f2, ..., fX, field with variable length
* 1. hash: the hash of the instruction. This number will be used to help us
* validate if an instruction is well-formed during deserialization.
* 2. opcode: the opcode code of the instruction.
* 3. f1, f2, ..., fX. These fields together represent the fixed fields in
* an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
* example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
* 4. The rest of the line indicates the field with variable length, e.g.,
* the shape of a tensor, the args used by an `InvokPacked` instruction, etc.
*/
#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/ir.h>
#include <tvm/node/container.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/vm.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
/*!
* \brief The Relay VM serializer.
*/
class Serializer : public runtime::ModuleNode {
public:
/*!
* \brief Initialize the serializer for a virtual machine.
*
* \param vm The Relay virtual machine.
*/
inline void Init(const VirtualMachine* vm);
/*!
* \brief Return the member function to the frontend.
*
* \param name The name of the function.
* \param sptr_to_self The pointer to the module node.
*
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const char* type_key() const final { return "Serializer"; }
/*!
* \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc.
*/
std::string Stats() const;
/*!
* \brief Serialize the `vm_` into global section, constant section, and code
* section.
*
* \return The binary representation of the VM.
*/
TVMByteArray Serialize();
/*!
* \brief Get a list of the globals used by the `_vm`.
*
* \return The global map in the form a list.
*/
tvm::Array<tvm::Expr> GetGlobals() const;
/*!
* \brief Get the primitive operators that are contained in the Relay VM.
*
* \return The list of primitve operators.
*/
tvm::Array<tvm::Expr> GetPrimitiveOps() const;
/*!
* \brief Get the serialized form of the `functions` in `vm_`. This is
* essentially bytecode serialization.
*
* \return The serialized vm bytecode.
*
* \note The bytecode is in the following format:
* func_name reg_file_size num_instructions
* param1 param2 ... paramM
* instruction1
* instruction2
* ...
* instructionN
*
* Each instruction is printed in the following format:
* opcode num_fields field1 ... fieldX # The text format.
*
* The field starting from # is only used for debugging. The serialized code
* doesn't contain it, therefore the deserializer doens't need to handle it.
*/
std::string GetBytecode() const;
/*! \brief Get the `lib` module in vm_. Serialization of `runtime::module`
* has already been supported by TVM. Therefore, we only return the runtime
* module and let users have the flexibility to call `export_library` from
* the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
*/
inline runtime::Module GetLib() const;
virtual ~Serializer() { delete strm_; }
private:
/*! \brief Serialize the globals in vm_. */
void SerializeGlobalSection();
/*! \brief Serialize the constant pool in vm_. */
void SerializeConstantSection();
/*! \brief Serialize primitive op names in vm_. */
void SerializePrimitiveOpNames();
/*! \brief Serialize the vm functions in vm_. */
void SerializeCodeSection();
/*! \brief The Relay virtual machine for to be serialized. */
const VirtualMachine* vm_;
/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*! \brief The serialized code. */
std::string code_;
};
} // namespace vm
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_
......@@ -23,6 +23,7 @@
* \brief The Relay virtual machine.
*/
#include <dmlc/memory_io.h>
#include <tvm/logging.h>
#include <tvm/runtime/vm.h>
......@@ -91,8 +92,8 @@ Instruction::Instruction(const Instruction& instr) {
return;
case Opcode::InvokeClosure:
this->closure = instr.closure;
this->closure_args_num = instr.closure_args_num;
this->closure_args = Duplicate<RegName>(instr.closure_args, instr.closure_args_num);
this->num_closure_args = instr.num_closure_args;
this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
return;
case Opcode::Invoke:
this->func_index = instr.func_index;
......@@ -179,9 +180,9 @@ Instruction& Instruction::operator=(const Instruction& instr) {
return *this;
case Opcode::InvokeClosure:
this->closure = instr.closure;
this->closure_args_num = instr.closure_args_num;
this->num_closure_args = instr.num_closure_args;
FreeIf(this->closure_args);
this->closure_args = Duplicate<RegName>(instr.closure_args, instr.closure_args_num);
this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
return *this;
case Opcode::Invoke:
this->func_index = instr.func_index;
......@@ -262,7 +263,9 @@ Instruction Instruction::Fatal() {
return instr;
}
Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size,
Instruction Instruction::InvokePacked(Index packed_index,
Index arity,
Index output_size,
const std::vector<RegName>& args) {
Instruction instr;
instr.op = Opcode::InvokePacked;
......@@ -380,7 +383,7 @@ Instruction Instruction::InvokeClosure(RegName closure, const std::vector<RegNam
instr.op = Opcode::InvokeClosure;
instr.dst = dst;
instr.closure = closure;
instr.closure_args_num = args.size();
instr.num_closure_args = args.size();
instr.closure_args = new RegName[args.size()];
for (size_t i = 0; i < args.size(); ++i) {
instr.closure_args[i] = args[i];
......@@ -396,7 +399,7 @@ Instruction Instruction::LoadConst(Index const_index, RegName dst) {
return instr;
}
Instruction Instruction::LoadConsti(size_t val, RegName dst) {
Instruction Instruction::LoadConsti(Index val, RegName dst) {
Instruction instr;
instr.op = Opcode::LoadConsti;
instr.dst = dst;
......@@ -432,7 +435,7 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
}
template<typename T>
std::string StrJoin(T* items, int offset, int cnt, std::string delim = ",") {
std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
if (cnt == 0) {
return "";
}
......@@ -447,11 +450,11 @@ std::string StrJoin(T* items, int offset, int cnt, std::string delim = ",") {
void InstructionPrint(std::ostream& os, const Instruction& instr) {
switch (instr.op) {
case Opcode::Move: {
os << "move $" << instr.dst << " $" << instr.from;
os << "move $" << instr.dst << " $" << instr.from << std::endl;
break;
}
case Opcode::Ret: {
os << "ret $" << instr.result;
os << "ret $" << instr.result << std::endl;
break;
}
case Opcode::Fatal: {
......@@ -459,74 +462,86 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break;
}
case Opcode::InvokePacked: {
os << "invoke_packed PackedFunc[" << instr.packed_index << "](in: $"
<< StrJoin<RegName>(instr.packed_args, 0, instr.arity - instr.output_size, ",$")
os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $"
<< StrJoin<RegName>(instr.packed_args, 0,
instr.arity - instr.output_size, ", $")
<< ", out: $"
<< StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
instr.output_size, ",$")
<< ")";
instr.output_size, ", $")
<< ")" << std::endl;
break;
}
case Opcode::AllocTensor: {
os << "alloc_tensor $" << instr.dst << " ["
<< StrJoin<int64_t>(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim)
<< StrJoin<int64_t>(instr.alloc_tensor.shape, 0,
instr.alloc_tensor.ndim)
<< "] ";
DLDatatypePrint(os, instr.alloc_tensor.dtype);
os << std::endl;
break;
}
case Opcode::AllocTensorReg: {
os << "alloc_tensor_reg $" << instr.dst << " $"
<< instr.alloc_tensor_reg.shape_register << " ";
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
os << std::endl;
break;
}
case Opcode::AllocDatatype: {
os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
<< StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
<< StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"
<< std::endl;
break;
}
case Opcode::AllocClosure: {
os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
<< "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
<< ")";
<< ")"
<< std::endl;
break;
}
case Opcode::If: {
os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " "
<< instr.if_op.true_offset << " " << instr.if_op.false_offset;
<< instr.if_op.true_offset << " " << instr.if_op.false_offset
<< std::endl;
break;
}
case Opcode::Invoke: {
os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
<< StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
<< ")";
<< ")"
<< std::endl;
break;
}
case Opcode::InvokeClosure: {
os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
<< StrJoin<RegName>(instr.closure_args, 0, instr.closure_args_num, ",$")
<< ")";
<< StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$")
<< ")"
<< std::endl;
break;
}
case Opcode::LoadConst: {
os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"
<< std::endl;
break;
}
case Opcode::LoadConsti: {
os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]";
os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"
<< std::endl;
break;
}
case Opcode::GetField: {
os << "get_field $" << instr.dst << " $" << instr.object << "["
<< instr.field_index << "]";
<< instr.field_index << "]"
<< std::endl;
break;
}
case Opcode::GetTag: {
os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
os << "get_tag $" << instr.dst << " $" << instr.get_tag.object << std::endl;
break;
}
case Opcode::Goto: {
os << "goto " << instr.pc_offset;
os << "goto " << instr.pc_offset << std::endl;
break;
}
default:
......@@ -564,6 +579,23 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
Object obj = args[i];
func_args.push_back(obj);
}
auto it = std::find_if(functions.begin(), functions.end(),
[func_name](const VMFunction& func) {
return func.name == func_name;
});
CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n";
CHECK_EQ(func_args.size() + params_.size(), it->params.size())
<< "The number of provided parameters doesn't match the number of arguments"
<< "\n";
if (!params_.empty()) {
for (const auto& p : it->params) {
const auto& pit = params_.find(p);
if (pit != params_.end()) {
func_args.push_back(pit->second);
}
}
CHECK_EQ(func_args.size(), it->params.size());
}
*rv = this->Invoke(func_name, func_args);
});
} else if (name == "init") {
......@@ -579,12 +611,40 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
}
this->Init(contexts);
});
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0].operator std::string());
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
}
void VirtualMachine::LoadParams(const std::string& params) {
dmlc::MemoryStringStream mss(const_cast<std::string*>(&params));
dmlc::Stream* strm = &mss;
uint64_t header, reserved;
CHECK(strm->Read(&header)) << "Invalid parameter file";
CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file";
CHECK(strm->Read(&reserved)) << "Invalid parameter file";
std::vector<std::string> names;
CHECK(strm->Read(&names)) << "Invalid parameter file";
uint64_t sz;
strm->Read(&sz);
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) << "Invalid parameter file";
for (size_t i = 0; i < size; i++) {
NDArray arr;
CHECK(arr.Load(strm)) << "Invalid parameter file";
runtime::Object obj = runtime::Object::Tensor(arr);
params_.emplace(std::make_pair(names[i], obj));
}
}
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
frames.push_back(frame);
......@@ -662,7 +722,22 @@ void InvokePacked(const PackedFunc& func, Index arg_count, Index output_size,
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
}
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { this->ctxs = ctxs; }
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
this->ctxs = ctxs;
// Get the list of packed functions.
CHECK(primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions"
<< "\n";
for (const auto& it : primitive_map) {
const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second);
if (packed_funcs.size() <= packed_index) {
packed_funcs.resize(packed_index + 1);
}
packed_funcs[packed_index] = lib.GetFunction(packed_name);
}
}
inline void VirtualMachine::WriteRegister(Index r, const Object& val) {
frames.back().register_file[r] = val;
......@@ -716,8 +791,8 @@ void VirtualMachine::Run() {
goto main_loop;
}
case Opcode::LoadConsti: {
auto tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tensor->data)[0] = instr.load_consti.val;
auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
WriteRegister(instr.dst, Object::Tensor(tensor));
pc++;
goto main_loop;
......@@ -753,7 +828,7 @@ void VirtualMachine::Run() {
for (auto free_var : closure->free_vars) {
args.push_back(free_var);
}
for (Index i = 0; i < instr.closure_args_num; ++i) {
for (Index i = 0; i < instr.num_closure_args; ++i) {
args.push_back(ReadRegister(instr.closure_args[i]));
}
InvokeGlobal(this->functions[closure->func_index], args);
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, missing-docstring, no-else-return
"""Unit tests for the Relay VM serialization and deserialization."""
import numpy as np
import tvm
from tvm import relay
from tvm.relay.module import Module as rly_module
from tvm.relay import vm as _vm
from tvm.relay import serializer, deserializer
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude
from tvm.contrib import util
from tvm.relay import testing
def create_vm(f, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(ctx)
return vm
else:
assert isinstance(f, relay.Module), "expected mod as relay.Module"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(f, target)
vm.init(ctx)
return vm
def veval(vm, *args, ctx=tvm.cpu()):
assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine"
vm.init(ctx)
ret = vm.run(*args)
return ret
def run_network(mod,
params,
data_shape=(1, 3, 224, 224),
dtype='float32'):
def get_vm_output(mod, data, params, target, ctx, dtype='float32'):
ex = relay.create_executor('vm', mod=mod, ctx=ctx)
result = ex.evaluate()(data, **params)
return result.asnumpy().astype(dtype)
def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
vm = create_vm(mod, ctx, target)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
des_vm.init(ctx)
des_vm.load_params(params)
result = des_vm.run(data)
return result.asnumpy().astype(dtype)
data = np.random.uniform(size=data_shape).astype(dtype)
target = "llvm"
ctx = tvm.cpu(0)
tvm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype)
vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype)
tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_serializer():
mod = rly_module({})
a = relay.const(1.0, "float32")
x = relay.var('x', shape=(10, 10), dtype='float32')
f1 = relay.Function([x], x + a)
glb_f1 = relay.GlobalVar("f1")
mod[glb_f1] = f1
b = relay.const(2.0, "float32")
y = relay.var('y', shape=(10, 10), dtype='float32')
f2 = relay.Function([y], y - b)
glb_f2 = relay.GlobalVar("f2")
mod[glb_f2] = f2
x1 = relay.var('x1', shape=(10, 10), dtype='float32')
y1 = relay.var('y1', shape=(10, 10), dtype='float32')
main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1))
mod["main"] = main
vm = create_vm(mod)
ser = serializer.Serializer(vm)
stats = ser.stats
assert "scalar" in stats
glbs = ser.globals
assert len(glbs) == 3
assert "f1" in glbs
assert "f2" in glbs
assert "main" in glbs
prim_ops = ser.primitive_ops
assert any(item.startswith('fused_add') for item in prim_ops)
assert any(item.startswith('fused_subtract') for item in prim_ops)
assert any(item.startswith('fused_multiply') for item in prim_ops)
code = ser.bytecode
assert "main 5 2 5" in code
assert "f1 3 1 4" in code
assert "f2 3 1 4" in code
code, lib = ser.serialize()
assert isinstance(code, bytearray)
assert isinstance(lib, tvm.module.Module)
def test_save_load():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
x_data = np.random.rand(10, 10).astype('float32')
# serialize.
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
assert isinstance(code, bytearray)
# save and load the code and lib file.
tmp = util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())
# deserialize.
deser = deserializer.Deserializer(loaded_code, loaded_lib)
des_vm = deser.deserialize()
res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
def test_const():
c = relay.const(1.0, "float32")
x = relay.var('x', shape=(10, 10), dtype='float32')
f = relay.Function([x], x + c)
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
assert isinstance(code, bytearray)
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
x_data = np.random.rand(10, 10).astype('float32')
res = veval(des_vm, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
def test_if():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
equal = relay.op.equal(x, y)
equal = relay.op.nn.batch_flatten(equal)
f = relay.Function([x, y], relay.If(relay.op.min(equal, axis=[0, 1]), x,
y))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(10, 10).astype('float32')
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
# same
res = veval(des_vm, x_data, x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data)
# diff
res = veval(des_vm, x_data, y_data)
tvm.testing.assert_allclose(res.asnumpy(), y_data)
def test_loop():
mod = relay.module.Module({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
sb.ret(accum)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1, 'int32'))
new_accum = relay.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
mod[sum_up] = func
loop_bound = 0
i_data = np.array(loop_bound, dtype='int32')
accum_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32')
aarg = relay.var('accum', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
vm = create_vm(mod)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
result = veval(des_vm, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
def test_tuple():
ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
tup = relay.var('tup', type_annotation=ttype)
f = relay.Function([tup], relay.TupleGetItem(tup, 1))
i_data = np.random.rand(41).astype('float32')
j_data = np.random.rand(10).astype('float32')
vm = create_vm(f)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
result = veval(des_vm, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)
def test_adt_list():
mod = relay.Module()
p = Prelude(mod)
l1 = p.cons(relay.const(1), p.nil())
l21 = p.cons(relay.const(2), l1)
l321 = p.cons(relay.const(3), l21)
f = relay.Function([], l321)
mod["main"] = f
vm = create_vm(mod)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
result = veval(des_vm)
assert len(result) == 2
assert len(result[1]) == 2
assert len(result[1][1]) == 2
res = []
res.append(result[0].asnumpy().tolist())
res.append(result[1][0].asnumpy().tolist())
res.append(result[1][1][0].asnumpy().tolist())
tvm.testing.assert_allclose(res, np.array([3, 2, 1]))
def test_adt_compose():
mod = relay.Module()
p = Prelude(mod)
compose = p.compose
# add_one = fun x -> x + 1
sb = relay.ScopeBuilder()
x = relay.var('x', 'float32')
x1 = sb.let('x1', x)
xplusone = x1 + relay.const(1.0, 'float32')
sb.ret(xplusone)
body = sb.get()
add_one = relay.GlobalVar("add_one")
add_one_func = relay.Function([x], body)
# add_two = compose(add_one, add_one)
sb = relay.ScopeBuilder()
y = relay.var('y', 'float32')
add_two_func = sb.let('add_two', compose(add_one_func, add_one_func))
add_two_res = add_two_func(y)
sb.ret(add_two_res)
add_two_body = sb.get()
mod[add_one] = add_one_func
f = relay.Function([y], add_two_body)
mod["main"] = f
vm = create_vm(mod)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
x_data = np.array(np.random.rand()).astype('float32')
result = veval(des_vm, x_data)
tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
def test_closure():
x = relay.var('x', shape=())
y = relay.var('y', shape=())
f = relay.Function([x], x + y)
ff = relay.Function([y], f)
clo = ff(relay.const(1.0))
main = clo(relay.const(2.0))
vm = create_vm(main)
ser = serializer.Serializer(vm)
code, lib = ser.serialize()
deser = deserializer.Deserializer(code, lib)
des_vm = deser.deserialize()
res = veval(des_vm)
tvm.testing.assert_allclose(res.asnumpy(), 3.0)
def test_resnet():
mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18)
run_network(mod, params)
def test_mobilenet():
mod, params = testing.mobilenet.get_workload(batch_size=1)
run_network(mod, params)
if __name__ == "__main__":
test_serializer()
test_save_load()
test_const()
test_if()
test_loop()
test_tuple()
test_adt_list()
test_adt_compose()
test_closure()
test_resnet()
test_mobilenet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment