Commit 3e765edc by Logan Weber Committed by Tianqi Chen

[Relay] Port param dict save/load from NNVM (#2620)

parent c7f65ce2
......@@ -136,7 +136,7 @@ def load_json(json_str):
def save_json(node):
"""Load tvm object as json string.
"""Save tvm object as json string.
Parameters
----------
......
......@@ -13,6 +13,7 @@ from .build_module import build, build_config, create_executor, optimize
from . import prelude
from . import parser
from . import debug
from . import param_dict
# Root operators
from .op import Op
......@@ -85,3 +86,7 @@ ExprMutator = expr_functor.ExprMutator
# Parser
fromtext = parser.fromtext
# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict
# pylint: disable=invalid-name
"""Helper utility to save parameter dicts."""
import tvm
_save_param_dict = tvm.get_global_func("tvm.relay._save_param_dict")
_load_param_dict = tvm.get_global_func("tvm.relay._load_param_dict")
def save_param_dict(params):
"""Save parameter dictionary to binary bytes.
The result binary bytes can be loaded by the
GraphModule with API "load_params".
Parameters
----------
params : dict of str to NDArray
The parameter dictionary.
Returns
-------
param_bytes: bytearray
Serialized parameters.
Examples
--------
.. code-block:: python
# compile and save the modules to file.
graph, lib, params = tvm.relay.build(func, target=target, params=params)
module = graph_runtime.create(graph, lib, tvm.gpu(0))
# save the parameters as byte array
param_bytes = tvm.relay.save_param_dict(params)
# We can serialize the param_bytes and load it back later.
# Pass in byte array to module to directly set parameters
module.load_params(param_bytes)
"""
args = []
for k, v in params.items():
args.append(k)
args.append(tvm.nd.array(v))
return _save_param_dict(*args)
def load_param_dict(param_bytes):
"""Load parameter dictionary to binary bytes.
Parameters
----------
param_bytes: bytearray
Serialized parameters.
Returns
-------
params : dict of str to NDArray
The parameter dictionary.
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
load_arr = _load_param_dict(param_bytes)
return {v.name : v.array for v in load_arr}
......@@ -578,5 +578,10 @@ TVM_REGISTER_API("relay.backend.CreateInterpreter")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CreateInterpreter(args[0], args[1], args[2]);
});
TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode);
TVM_REGISTER_NODE_TYPE(TensorValueNode);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file param_dict.cc
* \brief Implementation and registration of parameter dictionary
* serializing/deserializing functions.
*/
#include "param_dict.h"
#include <dmlc/memory_io.h>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
// `args` is in the form "key, value, key, value, ..."
size_t num_params = args.size() / 2;
std::vector<std::string> names;
names.reserve(num_params);
std::vector<DLTensor*> arrays;
arrays.reserve(num_params);
for (size_t i = 0; i < num_params * 2; i += 2) {
names.emplace_back(args[i].operator std::string());
arrays.emplace_back(args[i + 1].operator DLTensor*());
}
std::string bytes;
dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm;
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(header);
fo->Write(reserved);
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(sz);
for (size_t i = 0; i < sz; ++i) {
tvm::runtime::SaveDLTensor(fo, arrays[i]);
}
}
TVMByteArray arr;
arr.data = bytes.c_str();
arr.size = bytes.length();
*rv = arr;
});
TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string bytes = args[0];
std::vector<std::string> names;
dmlc::MemoryStringStream memstrm(&bytes);
dmlc::Stream* strm = &memstrm;
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
tvm::Array<NamedNDArray> ret;
for (size_t i = 0; i < size; ++i) {
tvm::runtime::NDArray temp;
temp.Load(strm);
auto n = tvm::make_node<NamedNDArrayNode>();
n->name = std::move(names[i]);
n->array = temp;
ret.push_back(NamedNDArray(n));
}
*rv = ret;
});
TVM_REGISTER_NODE_TYPE(NamedNDArrayNode);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file param_dict.h
* \brief Definitions for serializing and deserializing parameter dictionaries.
*/
#ifndef TVM_RELAY_BACKEND_PARAM_DICT_H_
#define TVM_RELAY_BACKEND_PARAM_DICT_H_
#include <tvm/node/node.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
/*!
* \brief Wrapper node for naming `NDArray`s.
*/
struct NamedNDArrayNode : public ::tvm::Node {
std::string name;
tvm::runtime::NDArray array;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("array", &array);
}
static constexpr const char* _type_key = "NamedNDArray";
TVM_DECLARE_NODE_TYPE_INFO(NamedNDArrayNode, Node);
};
TVM_DEFINE_NODE_REF(NamedNDArray, NamedNDArrayNode);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BACKEND_PARAM_DICT_H_
import os
import numpy as np
import tvm
import json
import base64
from tvm._ffi.base import py_str
from tvm.relay.op import add
from tvm import relay
from tvm import rpc
from tvm.contrib import util, graph_runtime
def test_save_load():
x = np.ones((10, 2)).astype("float32")
y = np.ones((1, 2, 3)).astype("float32")
params = {"x": x, "y": y}
param_bytes = relay.save_param_dict(params)
assert isinstance(param_bytes, bytearray)
param2 = relay.load_param_dict(param_bytes)
assert len(param2) == 2
np.testing.assert_equal(param2["x"].asnumpy(), x)
np.testing.assert_equal(param2["y"].asnumpy(), y)
def test_ndarray_reflection():
# Make two `NDArrayWrapper`s that point to the same underlying array.
np_array = np.random.uniform(size=(10, 2)).astype("float32")
tvm_array = tvm.nd.array(np_array)
param_dict = {'x': tvm_array, 'y': tvm_array}
assert param_dict['x'].same_as(param_dict['y'])
# Serialize then deserialize `param_dict`.
deser_param_dict = relay.load_param_dict(relay.save_param_dict(param_dict))
# Make sure the data matches the original data and `x` and `y` contain the same data.
np.testing.assert_equal(deser_param_dict['x'].asnumpy(), tvm_array.asnumpy())
# Make sure `x` and `y` contain the same data.
np.testing.assert_equal(deser_param_dict['x'].asnumpy(), deser_param_dict['y'].asnumpy())
def test_bigendian_rpc_param():
"""Test big endian rpc when there is a PowerPC RPC server available"""
host = os.environ.get("TVM_POWERPC_TEST_HOST", None)
port = os.environ.get("TVM_POWERPC_TEST_PORT", 9090)
if host is None:
return
def verify_graph_runtime(remote, target, shape, dtype):
x = relay.var('x')
y = relay.const(1)
z = relay.add(x, y)
func = relay.Function([x], z)
x_in = np.ones(shape).astype(dtype)
params = {'x': x_in}
graph, lib, params = relay.build(func, target=target, params=params)
temp = util.tempdir()
path_dso = temp.relpath("dev_lib.o")
lib.save(path_dso)
remote.upload(path_dso)
lib = remote.load_module("dev_lib.o")
ctx = remote.cpu(0)
mod = graph_runtime.create(graph, lib, ctx)
mod.load_params(relay.save_param_dict(params))
mod.run()
out = mod.get_output(0, tvm.nd.empty(shape, dtype=dtype, ctx=ctx))
tvm.testing.assert_allclose(x_in + 1, out.asnumpy())
print("Test RPC connection to PowerPC...")
remote = rpc.connect(host, port)
target = "llvm -mtriple=powerpc-linux-gnu"
for dtype in ["float32", "float64", "int32", "int8"]:
verify_graph_runtime(remote, target, (10,), dtype)
if __name__ == "__main__":
test_save_load()
test_ndarray_reflection()
test_bigendian_rpc_param()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment