diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py
index 4d0698a..0c9ce40 100644
--- a/python/tvm/contrib/graph_runtime.py
+++ b/python/tvm/contrib/graph_runtime.py
@@ -129,6 +129,7 @@ class GraphModule(object):
         self._get_input = module["get_input"]
         self._get_num_outputs = module["get_num_outputs"]
         self._load_params = module["load_params"]
+        self._share_params = module["share_params"]
 
     def set_input(self, key=None, value=None, **params):
         """Set inputs to the module via kwargs
@@ -234,6 +235,19 @@ class GraphModule(object):
         """
         self._load_params(bytearray(params_bytes))
 
+    def share_params(self, other, params_bytes):
+        """Share parameters from a pre-existing GraphRuntime instance.
+
+        Parameters
+        ----------
+        other : GraphRuntime
+            The parent GraphRuntime from which this instance should share
+            its parameters.
+        params_bytes : bytearray
+            The serialized parameter dict (used only for the parameter names).
+        """
+        self._share_params(other.module, bytearray(params_bytes))
+
     def __getitem__(self, key):
         """Get internal module function
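For context (not part of the diff), here is a minimal end-to-end sketch of the new Python API. It assumes an LLVM-enabled build and mirrors the names used in the unit test below:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Build a trivial module with one bound parameter 'x'.
x = relay.var('x', shape=(1, 10))
y = relay.var('y', shape=(1, 10))
func = relay.Function([x, y], relay.add(x, y))
params = {'x': np.ones((1, 10), dtype="float32")}
graph, lib, params = relay.build(func, target="llvm", params=params)
params_bytes = relay.save_param_dict(params)

# The parent instance materializes the weights once ...
parent = graph_runtime.create(graph, lib, tvm.cpu(0))
parent.load_params(params_bytes)

# ... and each further instance aliases the parent's parameter NDArrays
# instead of holding its own copy.
child = graph_runtime.create(graph, lib, tvm.cpu(0))
child.share_params(parent, params_bytes)
child.run(y=np.zeros((1, 10), dtype="float32"))
```

The runtime-side implementation follows: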
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 960d509..cc37a85 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -184,6 +184,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
   }
 }
 
+void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
+  uint64_t header, reserved;
+  CHECK(strm->Read(&header))
+      << "Invalid parameters file format";
+  CHECK(header == kTVMNDArrayListMagic)
+      << "Invalid parameters file format";
+  CHECK(strm->Read(&reserved))
+      << "Invalid parameters file format";
+  std::vector<std::string> names;
+  CHECK(strm->Read(&names)) << "Invalid parameters file format";
+  uint64_t sz;
+  CHECK(strm->Read(&sz)) << "Invalid parameters file format";
+  size_t size = static_cast<size_t>(sz);
+  CHECK(size == names.size()) << "Invalid parameters file format";
+  for (size_t i = 0; i < size; ++i) {
+    int in_idx = GetInputIndex(names[i]);
+    CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
+    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
+    CHECK_LT(eid, data_entry_.size());
+    CHECK_EQ(data_entry_[eid].use_count(), 1);
+    data_entry_[eid] = other.GetInput(in_idx);
+    CHECK_GT(data_entry_[eid].use_count(), 1);
+  }
+  this->SetupOpExecs();
+}
+
 void GraphRuntime::SetupStorage() {
   // Grab saved optimization plan from graph.
   std::vector<TVMType> vtype;
@@ -372,6 +398,14 @@ PackedFunc GraphRuntime::GetFunction(
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         this->LoadParams(args[0].operator std::string());
       });
+  } else if (name == "share_params") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        const auto& module = args[0].operator Module();
+        CHECK_EQ(module.operator->()->type_key(), std::string("GraphRuntime"));
+        const auto& param_blob = args[1].operator std::string();
+        dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
+        this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm);
+      });
   } else {
     return PackedFunc();
   }
diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h
index 5298f22..e3f5815 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -147,10 +147,19 @@ class GraphRuntime : public ModuleNode {
    * \param param_blob A binary blob of parameter.
    */
   void LoadParams(const std::string& param_blob);
-  /*!
-   * \brief Get total number of nodes.
-   * \return Total number of nodes.
-   */
+
+  /*!
+   * \brief Share parameters from a pre-existing GraphRuntime instance.
+   * \param other A GraphRuntime instance that previously had |LoadParams|
+   *        called with an identical |param_blob|.
+   * \param strm The input stream of the serialized parameter dict.
+   */
+  void ShareParams(const GraphRuntime& other, dmlc::Stream* strm);
+
+  /*!
+   * \brief Get total number of nodes.
+   * \return Total number of nodes.
+   */
   uint32_t GetNumOfNodes() const {
     return static_cast<uint32_t>(nodes_.size());
   }
diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py
index 20af8a0..f331f5b 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -81,8 +81,46 @@ def test_graph_simple():
         out = mod.get_output(0, out)
         np.testing.assert_equal(out.asnumpy(), a + 1)
 
+    def check_sharing():
+        from tvm import relay
+        if not tvm.module.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        x = relay.var('x', shape=(1, 10))
+        y = relay.var('y', shape=(1, 10))
+        z = relay.add(x, y)
+        func = relay.Function([x, y], z)
+
+        x_in = np.ones((1, 10)).astype("float32")
+        params = {'x': x_in}
+        graph, lib, params = relay.build(func, target="llvm", params=params)
+
+        mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
+        mod_shared.load_params(relay.save_param_dict(params))
+        num_mods = 10
+        mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
+                for _ in range(num_mods)]
+
+        for mod in mods:
+            mod.share_params(mod_shared, relay.save_param_dict(params))
+
+        a = np.random.uniform(size=(1, 10)).astype("float32")
+        for mod in mods:
+            mod.run(y=a)
+            out = mod.get_output(0, tvm.nd.empty((1, 10)))
+            np.testing.assert_equal(out.asnumpy(), x_in + a)
+
+        # Explicitly delete the shared module and verify correctness.
+        del mod_shared
+        for mod in mods:
+            mod.run(y=a)
+            out = mod.get_output(0, tvm.nd.empty((1, 10)))
+            np.testing.assert_equal(out.asnumpy(), x_in + a)
+        del mod
+
     check_verify()
     check_remote()
+    check_sharing()
 
 if __name__ == "__main__":
     test_graph_simple()
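One consequence of the aliasing is worth noting (again a sketch, not part of the diff, reusing `parent` and `child` from the snippet above): `ShareParams` rebinds the child's `data_entry_` slots to the parent's parameter NDArrays, and `set_input` copies into existing storage rather than rebinding it, so updating a parameter through the parent should be visible to every sharing child:

```python
# Overwrite the shared parameter 'x' via the parent; the child observes
# the new value because both modules point at the same NDArray storage.
new_x = np.full((1, 10), 2.0, dtype="float32")
parent.set_input(x=new_x)

zeros = np.zeros((1, 10), dtype="float32")
child.run(y=zeros)
out = child.get_output(0, tvm.nd.empty((1, 10)))
np.testing.assert_equal(out.asnumpy(), new_x + zeros)
```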