/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file stackvm_module.cc
 * \brief Runtime module whose functions are StackVM programs.
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <dmlc/memory_io.h>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <unordered_map>
#include "stackvm_module.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Module node holding StackVM programs keyed by function name.
 *
 * Serialized layout (written by SaveToFile, read back by Load):
 *   1. the function map (name -> StackVM),
 *   2. the entry function name,
 *   3. a uint64 count of imported modules,
 *   4. for each import: its type key, then that module's SaveToBinary payload.
 */
class StackVMModuleNode : public runtime::ModuleNode {
 public:
  const char* type_key() const final {
    return "stackvm";
  }

  /*!
   * \brief Look up a function of this module.
   * \param name Function name; runtime::symbol::tvm_module_main is redirected
   *        to the module's entry function.
   * \param sptr_to_self Owning pointer to this node; it is captured by the
   *        returned closure so the module outlives the function.
   * \return A PackedFunc that runs the named StackVM program, or an empty
   *         PackedFunc when no function with that name exists.
   */
  PackedFunc GetFunction(
      const std::string& name,
      const ObjectPtr<Object>& sptr_to_self) final {
    if (name == runtime::symbol::tvm_module_main) {
      return GetFunction(entry_func_, sptr_to_self);
    }
    auto it = fmap_.find(name);
    if (it == fmap_.end()) return PackedFunc();
    const StackVM& vm = it->second;
    // capture sptr_to_self to keep module node alive.
    // The StackVM itself is captured by value, so the closure owns a copy.
    return PackedFunc([vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        vm.Run(args, this);
      });
  }

  /*!
   * \brief Dump every function as "Function: <name>" followed by the
   *        streamed StackVM program.
   * \param format Ignored.
   */
  std::string GetSource(const std::string& format) final {
    std::ostringstream os;
    for (const auto& kv : fmap_) {
      os << "Function: " << kv.first << '\n';
      os << kv.second;
    }
    return os.str();
  }

  /*!
   * \brief Serialize this module (and its direct imports) into a file.
   * \param file_name Destination path.
   * \param format Ignored.
   * \note Imports must not have imports of their own (one-level hierarchy).
   */
  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string data;
    dmlc::MemoryStringStream writer(&data);
    dmlc::Stream* strm = &writer;
    strm->Write(fmap_);
    strm->Write(entry_func_);
    // also save imports
    uint64_t num_imports = static_cast<uint64_t>(imports_.size());
    strm->Write(num_imports);

    // Bind by reference: no need to bump the import's ref count per iteration.
    for (runtime::Module& im : imports_) {
      CHECK_EQ(im->imports().size(), 0U)
          << "Only support simply one-level hierarchy";
      std::string tkey = im->type_key();
      strm->Write(tkey);
      LOG(INFO) << "save " << tkey;
      im->SaveToBinary(strm);
      LOG(INFO) << "Finish save " << tkey;
    }
    SaveBinaryToFile(file_name, data);
  }

  /*!
   * \brief Build a module from a function map and an entry function name.
   * \param fmap Function name -> StackVM program (taken by value, moved in).
   * \param entry_func Name returned for runtime::symbol::tvm_module_main.
   */
  static Module Create(std::unordered_map<std::string, StackVM> fmap,
                       std::string entry_func) {
    auto n = make_object<StackVMModuleNode>();
    n->fmap_ = std::move(fmap);
    n->entry_func_ = std::move(entry_func);
    return Module(n);
  }

  /*!
   * \brief Deserialize a module written by SaveToFile.
   *
   * Each import is reconstructed by the global function
   * "runtime.module.loadbinary_<type_key>", which receives the stream
   * pointer as a void*.
   * \param strm Input stream positioned at the start of the payload.
   */
  static Module Load(dmlc::Stream* strm) {
    std::unordered_map<std::string, StackVM> fmap;
    std::string entry_func;
    // Check every read: a truncated or corrupt payload must fail loudly
    // instead of leaving fmap/entry_func/num_imports unset.
    CHECK(strm->Read(&fmap)) << "Failed to read StackVM function map";
    CHECK(strm->Read(&entry_func)) << "Failed to read StackVM entry function";
    auto n = make_object<StackVMModuleNode>();
    n->fmap_ = std::move(fmap);
    n->entry_func_ = std::move(entry_func);

    uint64_t num_imports = 0;
    CHECK(strm->Read(&num_imports)) << "Failed to read number of imports";
    for (uint64_t i = 0; i < num_imports; ++i) {
      std::string tkey;
      CHECK(strm->Read(&tkey));
      std::string fkey = "runtime.module.loadbinary_" + tkey;
      const PackedFunc* f = Registry::Get(fkey);
      CHECK(f != nullptr)
          << "Loader of " << tkey << "("
          << fkey << ") is not presented.";
      Module m = (*f)(static_cast<void*>(strm));
      n->imports_.emplace_back(std::move(m));
    }
    return Module(n);
  }

  /*!
   * \brief Read a whole file into memory and deserialize it with Load.
   * \param file_name Source path.
   * \param format Ignored.
   */
  static Module LoadFromFile(std::string file_name,
                             std::string format) {
    std::string data;
    LoadBinaryFromFile(file_name, &data);
    dmlc::MemoryStringStream reader(&data);
    return Load(&reader);
  }

 private:
  // internal function map
  std::unordered_map<std::string, StackVM> fmap_;
  // entry function.
  std::string entry_func_;
};

Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
                           std::string entry_func) {
  // Both parameters are owned copies already; move them instead of copying
  // the whole function map a second time.
  return StackVMModuleNode::Create(std::move(fmap), std::move(entry_func));
}

TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm")
.set_body_typed(StackVMModuleNode::LoadFromFile);

}  // namespace runtime
}  // namespace tvm