Commit 32be34a0 by Andrew Tulloch Committed by Tianqi Chen

[Runtime] Allow for parameter sharing in GraphRuntime (#3384)

Summary:

In multi-threaded applications where we have multiple inferences on the
same model in parallel (consider e.g. a TTS system handling multiple
requests), it can be useful to share the parameters of a model amongst
these multiple instances. This improves the cache utilization behaviour
of the system, as multiple cores can use the same set of weights instead
of evicting the identical copies of weights in a shared cache.

As the underlying `NDArray` instances in `data_entry_` implement a
ref-counted based sharing system, this is a simple modification of the
`GraphRuntime::LoadParams` logic to instead copy parameters from an
existing GraphRuntime instance. This is a little ugly in that we need
both the pre-existing GraphRuntime instance, as well as the 'serialized'
params (since we need to know the set of names we should copy), but
without imposing additional assumptions (i.e. storing the set of param
names in GraphRuntime, and enforcing that shared param names are
identical to the parameters set in the preceding `LoadParams` call),
this seems unavoidable.

Test Plan:

Unit test added.
parent e97c0101
...@@ -129,6 +129,7 @@ class GraphModule(object): ...@@ -129,6 +129,7 @@ class GraphModule(object):
self._get_input = module["get_input"] self._get_input = module["get_input"]
self._get_num_outputs = module["get_num_outputs"] self._get_num_outputs = module["get_num_outputs"]
self._load_params = module["load_params"] self._load_params = module["load_params"]
self._share_params = module["share_params"]
def set_input(self, key=None, value=None, **params): def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs """Set inputs to the module via kwargs
...@@ -234,6 +235,19 @@ class GraphModule(object): ...@@ -234,6 +235,19 @@ class GraphModule(object):
""" """
self._load_params(bytearray(params_bytes)) self._load_params(bytearray(params_bytes))
def share_params(self, other, params_bytes):
"""Share parameters from pre-existing GraphRuntime instance.
Parameters
----------
other: GraphRuntime
The parent GraphRuntime from which this instance should share
it's parameters.
params_bytes : bytearray
The serialized parameter dict (used only for the parameter names).
"""
self._share_params(other.module, bytearray(params_bytes))
def __getitem__(self, key): def __getitem__(self, key):
"""Get internal module function """Get internal module function
......
...@@ -184,6 +184,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { ...@@ -184,6 +184,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
} }
} }
void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
std::vector<std::string> names;
CHECK(strm->Read(&names)) << "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz);
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) << "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
int in_idx = GetInputIndex(names[i]);
CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
CHECK_LT(eid, data_entry_.size());
CHECK_EQ(data_entry_[eid].use_count(), 1);
data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
CHECK_GT(data_entry_[eid].use_count(), 1);
}
this->SetupOpExecs();
}
void GraphRuntime::SetupStorage() { void GraphRuntime::SetupStorage() {
// Grab saved optimization plan from graph. // Grab saved optimization plan from graph.
std::vector<TVMType> vtype; std::vector<TVMType> vtype;
...@@ -372,6 +398,14 @@ PackedFunc GraphRuntime::GetFunction( ...@@ -372,6 +398,14 @@ PackedFunc GraphRuntime::GetFunction(
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0].operator std::string()); this->LoadParams(args[0].operator std::string());
}); });
} else if (name == "share_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
const auto& module = args[0].operator Module();
CHECK_EQ(module.operator->()->type_key(), "GraphRuntime");
const auto& param_blob = args[1].operator std::string();
dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm);
});
} else { } else {
return PackedFunc(); return PackedFunc();
} }
......
...@@ -147,10 +147,19 @@ class GraphRuntime : public ModuleNode { ...@@ -147,10 +147,19 @@ class GraphRuntime : public ModuleNode {
* \param param_blob A binary blob of parameter. * \param param_blob A binary blob of parameter.
*/ */
void LoadParams(const std::string& param_blob); void LoadParams(const std::string& param_blob);
/*!
* \brief Get total number of nodes. /*!
* \return Total number of nodes. * \brief Share parameters from pre-existing GraphRuntime instance.
*/ * \param other A GraphRuntime instance, previously with |LoadParams| called with the
* identical input |param_blob|.
* \param strm The input stream.
*/
void ShareParams(const GraphRuntime& other, dmlc::Stream* strm);
/*!
* \brief Get total number of nodes.
* \return Total number of nodes.
*/
uint32_t GetNumOfNodes() const { uint32_t GetNumOfNodes() const {
return static_cast<uint32_t>(nodes_.size()); return static_cast<uint32_t>(nodes_.size());
} }
......
...@@ -81,8 +81,46 @@ def test_graph_simple(): ...@@ -81,8 +81,46 @@ def test_graph_simple():
out = mod.get_output(0, out) out = mod.get_output(0, out)
np.testing.assert_equal(out.asnumpy(), a + 1) np.testing.assert_equal(out.asnumpy(), a + 1)
def check_sharing():
from tvm import relay
x = relay.var('x', shape=(1, 10))
y = relay.var('y', shape=(1, 10))
z = relay.add(x, y)
func = relay.Function([x, y], z)
x_in = np.ones((1, 10)).astype("float32")
params = {'x': x_in}
graph, lib, params = relay.build(func, target="llvm", params=params)
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
mod_shared.load_params(relay.save_param_dict(params))
num_mods = 10
mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
for _ in range(num_mods)]
for mod in mods:
mod.share_params(mod_shared, relay.save_param_dict(params))
a = np.random.uniform(size=(1, 10)).astype("float32")
for mod in mods:
mod.run(y=a)
out = mod.get_output(0, tvm.nd.empty((1, 10)))
np.testing.assert_equal(out.asnumpy(), x_in + a)
# Explicitly delete the shared module and verify correctness.
del mod_shared
for mod in mods:
mod.run(y=a)
out = mod.get_output(0, tvm.nd.empty((1, 10)))
np.testing.assert_equal(out.asnumpy(), x_in + a)
del mod
check_verify() check_verify()
check_remote() check_remote()
check_sharing()
if __name__ == "__main__": if __name__ == "__main__":
test_graph_simple() test_graph_simple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment