/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file vulkan_module.cc
 */
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <array>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "vulkan_common.h"
#include "vulkan_module.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Serialize the shader: `flag` first, then the `data` word vector.
 * \param writer Destination stream.
 * \note The field order must stay in sync with VulkanShader::Load.
 */
void VulkanShader::Save(dmlc::Stream* writer) const {
  writer->Write(flag);
  writer->Write(data);
}

/*!
 * \brief Deserialize a shader previously written by VulkanShader::Save.
 * \param reader Source stream.
 * \return true if both `flag` and `data` were read; false on the first
 *         failed read (the remaining field is then left untouched).
 */
bool VulkanShader::Load(dmlc::Stream* reader) {
  // Short-circuit: `data` is only read if `flag` was read successfully.
  return reader->Read(&flag) && reader->Read(&data);
}

// Multi-device enabled module.
class VulkanModuleNode final :public runtime::ModuleNode { public: // Pipeline cache states struct PipelineEntry { VkShaderModule shader{VK_NULL_HANDLE}; VkPipelineLayout pipeline_layout{VK_NULL_HANDLE}; VkDescriptorSetLayout descriptor_layout{VK_NULL_HANDLE}; VkPipeline pipeline{VK_NULL_HANDLE}; }; // constructor explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap, std::unordered_map<std::string, FunctionInfo> fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) { } ~VulkanModuleNode() { // cleanup vulkan related caches. for (DeviceEntry& e : finfo_) { if (e.device == nullptr) continue; for (auto &kv : e.smap) { PipelineEntry& pe = kv.second; vkDestroyShaderModule(e.device, pe.shader, nullptr); vkDestroyDescriptorSetLayout(e.device, pe.descriptor_layout, nullptr); vkDestroyPipelineLayout(e.device, pe.pipeline_layout, nullptr); vkDestroyPipeline(e.device, pe.pipeline, nullptr); } } } const char* type_key() const final { return "vulkan"; } PackedFunc GetFunction( const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) final; void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); CHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); std::string data_bin; dmlc::MemoryStringStream fs(&data_bin); dmlc::Stream* stream = &fs; uint32_t magic = kVulkanModuleMagic; stream->Write(magic); stream->Write(smap_); SaveBinaryToFile(file_name, data_bin); } void SaveToBinary(dmlc::Stream* stream) final { stream->Write(fmt_); stream->Write(fmap_); stream->Write(smap_); } std::string GetSource(const std::string& format) final { // can only return source code. 
return source_; } // get a from primary context in device_id PipelineEntry GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args) { vulkan::VulkanWorkspace* w = vulkan::VulkanWorkspace::Global().get(); CHECK_LT(device_id, w->context_.size()); // start lock scope. std::lock_guard<std::mutex> lock(mutex_); if (finfo_.size() <= device_id) { finfo_.resize(device_id + 1, DeviceEntry()); } DeviceEntry& e = finfo_[device_id]; auto it = e.smap.find(func_name); if (it != e.smap.end()) return it->second; PipelineEntry pe; if (e.device == nullptr) { e.device = w->context_[device_id].device; } { // create shader auto sit = smap_.find(func_name); CHECK(sit != smap_.end()); const std::vector<uint32_t>& data = sit->second.data; VkShaderModuleCreateInfo shader_cinfo; shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; shader_cinfo.pNext = nullptr; shader_cinfo.flags = 0; shader_cinfo.codeSize = data.size() * sizeof(uint32_t); shader_cinfo.pCode = data.data(); VULKAN_CALL(vkCreateShaderModule( e.device, &shader_cinfo, nullptr, &(pe.shader))); } std::vector<VkDescriptorSetLayoutBinding> arg_binding; uint32_t num_pod = 0, num_buffer = 0; { auto fit = fmap_.find(func_name); CHECK(fit != fmap_.end()); for (TVMType arg_type : fit->second.arg_types) { if (arg_type.code == kHandle) { VkDescriptorSetLayoutBinding bd; bd.binding = num_buffer; bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; bd.descriptorCount = 1; bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; bd.pImmutableSamplers = nullptr; arg_binding.push_back(bd); ++num_buffer; } else { ++num_pod; } } } VkDescriptorSetLayoutCreateInfo descrip_cinfo; descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; descrip_cinfo.pNext = nullptr; descrip_cinfo.flags = 0; descrip_cinfo.bindingCount = arg_binding.size(); descrip_cinfo.pBindings = arg_binding.data(); VULKAN_CALL(vkCreateDescriptorSetLayout( e.device, &descrip_cinfo, nullptr, &(pe.descriptor_layout))); VkPushConstantRange 
crange; crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; crange.offset = 0; crange.size = sizeof(ArgUnion) * num_pack_args; VkPipelineLayoutCreateInfo playout_cinfo; playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; playout_cinfo.pNext = nullptr; playout_cinfo.flags = 0; playout_cinfo.setLayoutCount = 1; playout_cinfo.pSetLayouts = &(pe.descriptor_layout); if (num_pack_args != 0) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; CHECK_LE(crange.size, w->context_[device_id].phy_device_prop.limits.maxPushConstantsSize); } else { playout_cinfo.pushConstantRangeCount = 0; playout_cinfo.pPushConstantRanges = nullptr; } VULKAN_CALL(vkCreatePipelineLayout( e.device, &playout_cinfo, nullptr, &(pe.pipeline_layout))); VkComputePipelineCreateInfo pipeline_cinfo; pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; pipeline_cinfo.pNext = nullptr; pipeline_cinfo.flags = 0; pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; pipeline_cinfo.stage.pNext = nullptr; pipeline_cinfo.stage.flags = 0; pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; pipeline_cinfo.stage.module = pe.shader; pipeline_cinfo.stage.pName = func_name.c_str(); pipeline_cinfo.stage.pSpecializationInfo = nullptr; pipeline_cinfo.layout = pe.pipeline_layout; pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; pipeline_cinfo.basePipelineIndex = 0; VULKAN_CALL(vkCreateComputePipelines( e.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe.pipeline))); e.smap[func_name] = pe; return pe; } private: // device specific entry struct DeviceEntry { VkDevice device{nullptr}; std::unordered_map<std::string, PipelineEntry> smap; }; // the binary data std::vector<uint32_t> data_; // function information table. std::unordered_map<std::string, VulkanShader> smap_; // function information table. 
std::unordered_map<std::string, FunctionInfo> fmap_; // The format std::string fmt_{"vulkan"}; // The source std::string source_; // device local pipeline information. std::vector<DeviceEntry> finfo_; // internal mutex when updating the module std::mutex mutex_; }; // a wrapped function class to get packed func. class VulkanWrappedFunc { public: // initialize the VULKAN function. void Init(VulkanModuleNode* m, std::shared_ptr<ModuleNode> sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, const std::vector<std::string>& thread_axis_tags) { w_ = vulkan::VulkanWorkspace::Global().get(); m_ = m; sptr_ = sptr; func_name_ = func_name; num_buffer_args_ = num_buffer_args; num_pack_args_ = num_pack_args; thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); } // invoke the function with void arguments void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { vulkan::VulkanThreadEntry* tls = vulkan::VulkanThreadEntry::ThreadLocal(); int device_id = tls->context.device_id; CHECK_LT(device_id, kVulkanMaxNumDevice); const vulkan::VulkanContext& vctx = w_->context_[device_id]; VulkanModuleNode::PipelineEntry& pe = scache_[device_id]; if (pe.pipeline == VK_NULL_HANDLE) { pe = m_->GetPipeline(device_id, func_name_, num_pack_args_); } ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); vulkan::VulkanCommandBuffer* cmd = tls->CommandPool(device_id)->Alloc( &(pe.descriptor_layout)); cmd->write_descriptor_set.dstSet = cmd->descriptor_set; // setup descriptors for (uint32_t i = 0; i < num_buffer_args_; ++i) { void* buf = args[static_cast<int>(i)]; VkDescriptorBufferInfo binfo; binfo.buffer = static_cast<vulkan::VulkanBuffer*>(buf)->buffer; binfo.offset = 0; binfo.range = VK_WHOLE_SIZE; cmd->write_descriptor_set.dstBinding = i; cmd->write_descriptor_set.pBufferInfo = &binfo; vkUpdateDescriptorSets( vctx.device, 1, &(cmd->write_descriptor_set), 0, nullptr); } // dispatch VkCommandBufferBeginInfo cb_begin; 
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; cb_begin.pNext = nullptr; cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; cb_begin.pInheritanceInfo = 0; VkSubmitInfo cb_submit; cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; cb_submit.pNext = nullptr; cb_submit.waitSemaphoreCount = 0; cb_submit.pWaitSemaphores = nullptr; cb_submit.pWaitDstStageMask = 0; cb_submit.commandBufferCount = 1; cb_submit.pCommandBuffers = &(cmd->cmd_buffer); cb_submit.signalSemaphoreCount = 0; cb_submit.pSignalSemaphores = nullptr; // 0: begin VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin)); // 1: dispatch vkCmdBindPipeline( cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pe.pipeline); vkCmdBindDescriptorSets( cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pe.pipeline_layout, 0, 1, &(cmd->descriptor_set), 0, nullptr); // bind push constant if necessary if (num_pack_args_ != 0) { vkCmdPushConstants( cmd->cmd_buffer, pe.pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion), pack_args); } vkCmdDispatch( cmd->cmd_buffer, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); // 2: barrier(compute->compute|transfer) VkMemoryBarrier barrier_info; barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; barrier_info.pNext = nullptr; barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT; barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); vkCmdPipelineBarrier( cmd->cmd_buffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, nullptr); // 3: end VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer)); // 4: submit with cmd->fence VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence)); } private: // Reference to global workspace. 
vulkan::VulkanWorkspace* w_; // internal module VulkanModuleNode* m_; // the resource holder std::shared_ptr<ModuleNode> sptr_; // The name of the function. std::string func_name_; // Number of buffer arguments size_t num_buffer_args_; // number of packed arguments. size_t num_pack_args_; // Device state cache per device. // mark as mutable, to enable lazy initialization mutable std::array<VulkanModuleNode::PipelineEntry, kVulkanMaxNumDevice> scache_; // thread axis configuration ThreadAxisConfig thread_axis_cfg_; }; PackedFunc VulkanModuleNode::GetFunction( const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; VulkanWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, info.thread_axis_tags); return PackFuncNonBufferArg(f, info.arg_types); } Module VulkanModuleCreate( std::unordered_map<std::string, VulkanShader> smap, std::unordered_map<std::string, FunctionInfo> fmap, std::string source) { vulkan::VulkanWorkspace::Global()->Init(); std::shared_ptr<VulkanModuleNode> n = std::make_shared<VulkanModuleNode>(smap, fmap, source); return Module(n); } // Load module from module. 
/*!
 * \brief Read a module written by VulkanModuleNode::SaveToFile.
 *
 * Function metadata comes from the companion meta file. The binary file holds
 * the magic number followed by the shader table. Fails with a CHECK if the
 * file is truncated or the magic does not match.
 */
Module VulkanModuleLoadFile(const std::string& file_name,
                            const std::string& format) {
  std::string data;
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  dmlc::MemoryStringStream fs(&data);
  dmlc::Stream* stream = &fs;
  // Initialized so a failed read can never leave an indeterminate value.
  uint32_t magic = 0;
  CHECK(stream->Read(&magic))
      << "VulkanModule: failed to read magic from " << file_name;
  CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
  CHECK(stream->Read(&smap))
      << "VulkanModule: failed to read shader table from " << file_name;
  return VulkanModuleCreate(smap, fmap, "");
}

/*!
 * \brief Read a module written by VulkanModuleNode::SaveToBinary.
 * \param strm Pointer to a dmlc::Stream positioned at the serialized module.
 *
 * The layout is: format string, function table, shader table. Fails with a
 * CHECK if any field cannot be read.
 */
Module VulkanModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;
  // The format tag is read only to advance the stream; its value is not used.
  std::string fmt;
  CHECK(stream->Read(&fmt)) << "VulkanModule: failed to read format";
  CHECK(stream->Read(&fmap)) << "VulkanModule: failed to read function table";
  CHECK(stream->Read(&smap)) << "VulkanModule: failed to read shader table";
  return VulkanModuleCreate(smap, fmap, "");
}

TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
.set_body_typed(VulkanModuleLoadFile);

TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
.set_body_typed(VulkanModuleLoadBinary);
}  // namespace runtime
}  // namespace tvm