/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file src/relay/backend/vm/serializer.cc
 * \brief Implementation of serializing APIs for the Relay VM.
 */
#include "serializer.h"

#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "serialize_util.h"

namespace tvm {
namespace relay {
namespace vm {

/*!
 * \brief Bind the serializer to a virtual machine and set up the output stream.
 * \param vm The virtual machine whose contents will be serialized.
 *
 * The stream writes into the `code_` member string.
 */
void Serializer::Init(const VirtualMachine* vm) {
  vm_ = vm;
  // Initialize the stream object.
  // NOTE(review): the stream is allocated with `new` and no matching `delete`
  // is visible in this file; confirm that the owner of `strm_` releases it and
  // that `Init` is only called once per serializer.
  strm_ = new dmlc::MemoryStringStream(&code_);
}

/*!
 * \brief Look up one of the serializer's packed functions by name.
 * \param name One of "get_lib", "get_primitive_ops", "get_bytecode",
 *        "get_globals", "get_stats" or "serialize".
 * \param sptr_to_self Shared pointer to this module; captured by the returned
 *        closure together with `this`.
 * \return A PackedFunc forwarding to the matching member function. An unknown
 *         name is reported through LOG(FATAL).
 */
runtime::PackedFunc Serializer::GetFunction(
    const std::string& name,
    const std::shared_ptr<ModuleNode>& sptr_to_self) {
  if (name == "get_lib") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->GetLib();
    });
  } else if (name == "get_primitive_ops") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->GetPrimitiveOps();
    });
  } else if (name == "get_bytecode") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->GetBytecode();
    });
  } else if (name == "get_globals") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->GetGlobals();
    });
  } else if (name == "get_stats") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->Stats();
    });
  } else if (name == "serialize") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->Serialize();
    });
  } else {
    LOG(FATAL) << "Unknown packed function: " << name;
    return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
  }
}

/*!
 * \brief Collect the primitive op names, placed at the slot given by their index.
 * \return An array in which entry `i` holds the name (as a StringImm) mapped to
 *         index `i` in `vm_->primitive_map`. Slots for indices that do not
 *         appear in the map are left default-constructed.
 */
tvm::Array<tvm::Expr> Serializer::GetPrimitiveOps() const {
  std::vector<tvm::Expr> ret;
  for (const auto& it : vm_->primitive_map) {
    auto packed_name = tvm::ir::StringImm::make(it.first);
    auto packed_index = static_cast<size_t>(it.second);
    // The map is not iterated in index order, so grow the vector on demand.
    if (ret.size() <= packed_index) {
      ret.resize(packed_index + 1);
    }
    ret[packed_index] = packed_name;
  }
  return ret;
}

/*!
 * \brief Build a human-readable summary of the virtual machine.
 * \return Text listing the constant shapes, the globals with their indices,
 *         and the primitive op names.
 */
std::string Serializer::Stats() const {
  std::ostringstream oss;
  oss << "Relay VM statistics:" << std::endl;

  // Get the number of constants and the shape of each of them.
  oss << " Constant shapes (# " << vm_->constants.size() << "): [";
  for (const auto& it : vm_->constants) {
    auto cell = it.AsTensor();
    CHECK(cell.operator->());
    runtime::NDArray data = cell->data;
    const auto& shape = data.Shape();

    // Scalar
    if (shape.empty()) {
      oss << "scalar, ";
      continue;
    }

    oss << "[";
    for (auto s : shape) {
      oss << s << ", ";
    }
    // Step back over the trailing ", " so the closing bracket overwrites it.
    oss.seekp(-2, oss.cur);
    oss << "], " << std::endl;
  }
  // NOTE(review): this steps back two characters regardless of whether the
  // last entry ended with ", " (scalar) or ", \n" (non-scalar); confirm the
  // resulting text is what is wanted in the non-scalar case.
  if (!vm_->constants.empty()) oss.seekp(-2, oss.cur);
  oss << "]" << std::endl;

  // Get the number of globals and the name of each of them.
  oss << " Globals (#" << vm_->global_map.size() << "): [";
  for (const auto& it : vm_->global_map) {
    oss << "(\"" << it.first << "\", " << it.second << ")" << ", ";
  }
  if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur);
  oss << "]" << std::endl;

  // Get the number of primitive ops and the name of each of them.
  oss << " Primitive ops (#" << vm_->primitive_map.size() << "): [";
  const auto& prim_ops = GetPrimitiveOps();
  for (const auto& it : prim_ops) {
    oss << it << ", ";
  }
  if (!prim_ops.empty()) oss.seekp(-2, oss.cur);
  oss << "]" << std::endl;

  return oss.str();
}

/*!
 * \brief Serialize the virtual machine into the internal byte buffer.
 *
 * Writes, in order: the magic header, the TVM version string, the global
 * section, the constant section, the primitive op names, and the code section.
 *
 * \return A byte array that points into the `code_` member; it does not own
 *         the data, so it is only usable while `code_` is alive and unmodified.
 *
 * NOTE(review): nothing here resets the stream, so calling this more than once
 * presumably appends a second copy to `code_` -- confirm callers invoke it once.
 */
TVMByteArray Serializer::Serialize() {
  uint64_t header = kTVMVMBytecodeMagic;
  strm_->Write(header);
  std::string version = TVM_VERSION;
  strm_->Write(version);

  // Global section.
  SerializeGlobalSection();

  // Constant section.
  SerializeConstantSection();

  // Primitive names.
  SerializePrimitiveOpNames();

  // Code section.
  SerializeCodeSection();

  TVMByteArray arr;
  arr.data = code_.c_str();
  arr.size = code_.length();
  return arr;
}

/*!
 * \brief Write the global function names, in the order returned by GetGlobals().
 */
void Serializer::SerializeGlobalSection() {
  auto globals = GetGlobals();
  std::vector<std::string> glbs;
  for (const auto& it : globals) {
    glbs.push_back(it.as<tvm::ir::StringImm>()->value);
  }
  strm_->Write(glbs);
}

/*!
 * \brief Write the number of constants followed by each constant tensor.
 */
void Serializer::SerializeConstantSection() {
  std::vector<DLTensor*> arrays;
  for (const auto& obj : vm_->constants) {
    auto cell = obj.AsTensor();
    // Same guard as in Stats(): fail loudly instead of dereferencing a null cell.
    CHECK(cell.operator->());
    runtime::NDArray data = cell->data;
    arrays.push_back(const_cast<DLTensor*>(data.operator->()));
  }
  strm_->Write(static_cast<uint64_t>(vm_->constants.size()));
  for (const auto& it : arrays) {
    runtime::SaveDLTensor(strm_, it);
  }
}

/*!
 * \brief Write the primitive op names, in the order returned by GetPrimitiveOps().
 */
void Serializer::SerializePrimitiveOpNames() {
  auto names = GetPrimitiveOps();
  std::vector<std::string> primitive_names;
  for (const auto& it : names) {
    primitive_names.push_back(it.as<tvm::ir::StringImm>()->value);
  }
  strm_->Write(primitive_names);
}

// Serialize a virtual machine instruction. It creates a list that contains the
// hash, opcode, and all fields of an instruction.
//
// For example, the function signature used to create an `AllocTensor`
// instruction is:
//   Instruction AllocTensor(std::vector<Index> shape, DLDataType dtype, RegName dst)
//
// The serialized form will be:
//   `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn`
//
// where hash is the hash of serialized instruction that is computed internally
// by the `VMInstructionSerializer`. It is used for sanity check before decoding.
// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)`
// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register`
// is the destination register, and the rest of it together indicates the shape
// of the tensor to be allocated.
VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
  // Operand fields of the instruction; the opcode itself is passed separately
  // to the `VMInstructionSerializer` constructor at the end.
  std::vector<Index> fields;
  DLOG(INFO) << "Serializing: " << instr << std::endl;
  switch (instr.op) {
    case Opcode::Move: {
      // Number of fields = 2
      fields.assign({instr.from, instr.dst});
      break;
    }
    case Opcode::Ret: {
      // Number of fields = 1
      fields.push_back(instr.result);
      break;
    }
    case Opcode::Fatal: {
      // Number of fields = 0
      break;
    }
    case Opcode::InvokePacked: {
      // Number of fields = 3 + instr.arity
      // Note that arity includes both input arguments and outputs. We will
      // put all the `arity` number of fields in the end for serialization.
      fields.assign({instr.packed_index, instr.arity, instr.output_size});
      // Save the args.
      fields.insert(fields.end(), instr.packed_args,
                    instr.packed_args + instr.arity);
      break;
    }
    case Opcode::AllocTensor: {
      // Number of fields = 5 + instr.alloc_tensor.ndim
      // Save `DLDataType` and the dst register.
      const auto& dtype = instr.alloc_tensor.dtype;
      fields.assign({dtype.code, dtype.bits, dtype.lanes});

      // The number of dimensions is not needed for constructing an
      // `AllocTensor` instruction as it equals to the length of the `shape`
      // vector. However, we save it to conveniently deserialize the instruction
      // because we will know how many fields are needed by the `shape` argument.
      fields.push_back(instr.alloc_tensor.ndim);
      fields.push_back(instr.dst);

      // Save the shape of the tensor.
      // Note that this field is rotated to the end of the list.
      fields.insert(fields.end(), instr.alloc_tensor.shape,
                    instr.alloc_tensor.shape + instr.alloc_tensor.ndim);
      break;
    }
    case Opcode::AllocTensorReg: {
      // Number of fields = 5
      fields.push_back(instr.alloc_tensor_reg.shape_register);
      // Save `DLDataType` and the dst register.
      // The fields are appended one by one: `assign` would replace the
      // vector's contents and drop the shape register pushed above. The dtype
      // is read from the same `alloc_tensor_reg` payload as the shape register.
      const auto& dtype = instr.alloc_tensor_reg.dtype;
      fields.push_back(dtype.code);
      fields.push_back(dtype.bits);
      fields.push_back(dtype.lanes);
      fields.push_back(instr.dst);
      break;
    }
    case Opcode::AllocDatatype: {
      // Number of fields = 3 + instr.num_fields
      fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});

      // Save the fields.
      fields.insert(fields.end(), instr.datatype_fields,
                    instr.datatype_fields + instr.num_fields);
      break;
    }
    case Opcode::AllocClosure: {
      // Number of fields = 3 + instr.num_freevar
      fields.assign({instr.clo_index, instr.num_freevar, instr.dst});

      // Save the free vars.
      fields.insert(fields.end(), instr.free_vars,
                    instr.free_vars + instr.num_freevar);
      break;
    }
    case Opcode::If: {
      // Number of fields = 4
      fields.assign({instr.if_op.test,
                     instr.if_op.target,
                     instr.if_op.true_offset,
                     instr.if_op.false_offset});
      break;
    }
    case Opcode::Invoke: {
      // Number of fields = 3 + instr.num_args
      fields.assign({instr.func_index, instr.num_args, instr.dst});

      // Save the args.
      fields.insert(fields.end(), instr.invoke_args_registers,
                    instr.invoke_args_registers + instr.num_args);
      break;
    }
    case Opcode::InvokeClosure: {
      // Number of fields = 3 + instr.num_closure_args
      fields.assign({instr.closure, instr.num_closure_args, instr.dst});

      // Save the args.
      fields.insert(fields.end(), instr.closure_args,
                    instr.closure_args + instr.num_closure_args);
      break;
    }
    case Opcode::LoadConst: {
      // Number of fields = 2
      fields.assign({instr.const_index, instr.dst});
      break;
    }
    case Opcode::LoadConsti: {
      // Number of fields = 2
      fields.assign({instr.load_consti.val, instr.dst});
      break;
    }
    case Opcode::GetField: {
      // Number of fields = 3
      fields.assign({instr.object, instr.field_index, instr.dst});
      break;
    }
    case Opcode::GetTag: {
      // Number of fields = 2
      fields.assign({instr.get_tag.object, instr.dst});
      break;
    }
    case Opcode::Goto: {
      // Number of fields = 1
      fields.push_back(instr.pc_offset);
      break;
    }
    default:
      LOG(FATAL) << "Invalid opcode " << static_cast<int>(instr.op);
      break;
  }

  return VMInstructionSerializer(static_cast<Index>(instr.op), fields);
}

/*!
 * \brief Write the number of functions, then each function's header followed
 *        by its serialized instructions.
 */
void Serializer::SerializeCodeSection() {
  // Save the number of functions.
  strm_->Write(static_cast<uint64_t>(vm_->functions.size()));
  for (const auto& func : vm_->functions) {
    // Serialize the function info.
    VMFunctionSerializer func_format(func.name,
                                     func.register_file_size,
                                     func.instructions.size(),
                                     func.params);
    func_format.Save(strm_);

    // Serialize each instruction.
    for (const auto& instr : func.instructions) {
      const auto& serialized_instr = SerializeInstruction(instr);
      serialized_instr.Save(strm_);
    }
  }
}

/*!
 * \brief Collect the global names sorted by their index.
 * \return An array of StringImm names from `vm_->global_map`, ordered by the
 *         ascending value each name maps to.
 */
tvm::Array<tvm::Expr> Serializer::GetGlobals() const {
  tvm::Array<tvm::Expr> ret;
  std::vector<std::pair<std::string, Index> > globals(vm_->global_map.begin(),
                                                      vm_->global_map.end());
  auto comp = [](const std::pair<std::string, Index>& a,
                 const std::pair<std::string, Index>& b) {
    return a.second < b.second;
  };
  std::sort(globals.begin(), globals.end(), comp);
  for (const auto& it : globals) {
    ret.push_back(tvm::ir::StringImm::make(it.first));
  }
  return ret;
}

/*!
 * \brief Render the bytecode of every function as text.
 * \return For each function: a header line, its parameters, and one line per
 *         instruction containing the hash, opcode, fields and, after "#", the
 *         instruction in text form.
 */
std::string Serializer::GetBytecode() const {
  std::ostringstream oss;

  for (const auto& func : vm_->functions) {
    // Print the header of the function format.
    oss << "# func name, reg file size, param count, inst count:"
        << std::endl;
    oss << func.name << " "
        << func.register_file_size << " "
        << func.params.size() << " "
        << func.instructions.size() << std::endl;

    // Print params of a `VMFunction`.
    oss << "# Parameters:"<< std::endl;
    for (const auto& param : func.params) {
      oss << param << " ";
    }
    oss << std::endl;

    // Print the instructions of a `VMFunction`.
    // The part after "#" is the instruction in text format.
    oss << "hash, opcode, fields # inst(text):"<< std::endl;
    for (const auto& instr : func.instructions) {
      const auto& serialized_instr = SerializeInstruction(instr);
      oss << std::hex << "0x" << serialized_instr.Hash() << " "
          << std::dec << serialized_instr.opcode << " ";
      for (auto it : serialized_instr.fields) {
        oss << it << " ";
      }
      // Render the instruction on its own so that checking for a trailing
      // newline does not copy the whole accumulated output on every iteration.
      std::ostringstream instr_text;
      instr_text << instr;
      const std::string text = instr_text.str();
      oss << " # " << text;
      if (text.empty() || text.back() != '\n') oss << std::endl;
    }
  }

  return oss.str();
}

/*!
 * \brief Return the runtime module held by the virtual machine (`vm_->lib`).
 */
runtime::Module Serializer::GetLib() const {
  return vm_->lib;
}

/*!
 * \brief Create a serializer module bound to the given virtual machine.
 * \param vm The virtual machine to serialize.
 * \return A runtime module wrapping the initialized Serializer.
 */
runtime::Module CreateSerializer(const VirtualMachine* vm) {
  std::shared_ptr<Serializer> exec = std::make_shared<Serializer>();
  exec->Init(vm);
  return runtime::Module(exec);
}

// Expose serializer creation as a global packed function. The single argument
// must be a module whose node is a `VirtualMachine`.
TVM_REGISTER_GLOBAL("relay._vm._Serializer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  runtime::Module mod = args[0];
  const auto* vm = dynamic_cast<VirtualMachine*>(mod.operator->());
  CHECK(vm) << "Virtual machine has not been defined yet."
            << "\n";
  *rv = CreateSerializer(vm);
});

}  // namespace vm
}  // namespace relay
}  // namespace tvm