Commit 2536465c by Andrew Tulloch Committed by Tianqi Chen

Vulkan2 Runtime API (#3849)

parent 06aecc60
......@@ -62,8 +62,7 @@
#endif
#ifdef TVM_VULKAN_RUNTIME
#include "../src/runtime/vulkan/vulkan_device_api.cc"
#include "../src/runtime/vulkan/vulkan_module.cc"
#include "../src/runtime/vulkan/vulkan.cc"
#endif
#ifdef USE_SORT
......
......@@ -18,6 +18,14 @@
# Be compatible with older version of CMake
find_vulkan(${USE_VULKAN})
# Extra Vulkan runtime options, exposed for advanced users.
tvm_option(USE_VULKAN_IMMEDIATE_MODE "Use Vulkan Immediate mode
(KHR_push_descriptor extension)" ON IF USE_VULKAN)
tvm_option(USE_VULKAN_DEDICATED_ALLOCATION "Use Vulkan dedicated allocations" ON
IF USE_VULKAN)
tvm_option(USE_VULKAN_VALIDATION "Enable Vulkan API validation layers" OFF
IF USE_VULKAN)
if(Vulkan_FOUND)
# always set the includedir
# avoid global retrigger of cmake
......@@ -28,12 +36,24 @@ if(USE_VULKAN)
if(NOT Vulkan_FOUND)
message(FATAL_ERROR "Cannot find Vulkan, USE_VULKAN=" ${USE_VULKAN})
endif()
message(STATUS "Build with VULKAN support")
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
message(STATUS "Build with Vulkan support")
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/vulkan.cc)
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY})
if(USE_VULKAN_IMMEDIATE_MODE)
message(STATUS "Build with Vulkan immediate mode")
add_definitions(-DUSE_VULKAN_IMMEDIATE_MODE=1)
endif()
if(USE_VULKAN_DEDICATED_ALLOCATION)
message(STATUS "Build with Vulkan dedicated allocation")
add_definitions(-DUSE_VULKAN_DEDICATED_ALLOCATION=1)
endif()
if(USE_VULKAN_VALIDATION)
message(STATUS "Build with Vulkan API validation")
add_definitions(-DUSE_VULKAN_VALIDATION=1)
endif()
endif(USE_VULKAN)
......@@ -29,6 +29,8 @@
#include "codegen_spirv.h"
#include "../build_common.h"
#include "../../runtime/vulkan/vulkan_shader.h"
#include "../../runtime/vulkan/vulkan_module.h"
namespace tvm {
......
......@@ -33,7 +33,10 @@ namespace spirv {
void IRBuilder::InitHeader() {
CHECK_EQ(header_.size(), 0U);
header_.push_back(spv::MagicNumber);
header_.push_back(spv::Version);
// Use SPIR-V v1.0. This needs to be kept in sync (or at least behind)
// `VkApplicationInfo.apiVersion` in `vulkan.cc` to ensure Vulkan API
// validation passes.
header_.push_back(0x10000);
// generator: set to 0, unknown
header_.push_back(0U);
// Bound: set during Finalize
......
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
## Components
### VulkanDeviceAPI
Implements the TVM DeviceAPI interface. Owns the core Vulkan datastructures. Is
responsible for initializing the Vulkan instance and devices, querying for
possible extensions.
### VulkanThreadEntry
Thread-local state for the Vulkan runtime. Maintains a staging buffer (for
copies), and a VulkanStream per device.
### VulkanWrappedFunc
Responsible for launching computation kernels. Responsible for obtaining a
VulkanPipeline instance (from the VulkanModuleNode), and launches the kernel
(via immediate or deferred mode) on the active VulkanStream instance.
## Stream execution in the Vulkan programming model.
The natural model for TVM DeviceAPI implementation and runtime follows the CUDA
API model. That is, we launch "kernels" onto a (implicit or explicit) "stream"
(which execute asynchronously with respect to the host, but ordered with respect
to the stream), and explicitly synchronize the stream with respect to the host.
We simulate this behaviour in the Vulkan model by maintaining a thread-local
`vkCommandBuffer` instance, and queueing up (or eagerly executing, depending on
the availability of the `VK_KHR_push_descriptor` extension). When we synchronize
the stream, we end the command buffer recording, submit it to the device queue,
and wait on the corresponding fence.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <vulkan/vulkan.h>
#include <dmlc/memory_io.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <array>
#include <cstring>
#include "../file_util.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../workspace_pool.h"
#include "vulkan_common.h"
#include "vulkan_module.h"
#include "vulkan_shader.h"
#include "vulkan_stream.h"
namespace tvm {
namespace runtime {
namespace vulkan {
/*! \brief Maximum number of GPU supported in VulkanModule. */
static constexpr const int kVulkanMaxNumDevice = 8;
/*! \brief TVM Vulkan binary pack magic number */
static constexpr const int kVulkanModuleMagic = 0x02700027;
class VulkanThreadEntry {
public:
VulkanThreadEntry();
static VulkanThreadEntry* ThreadLocal();
~VulkanThreadEntry() {
// Because the thread entry refers to Device API
// The command buffer always will be destroyed before
// the instance and device get destroyed.
// The destruction need to be manually called
// to ensure the destruction order.
streams_.clear();
for (const auto& kv : staging_buffers_) {
if (!kv.second) {
continue;
}
auto& buf = *(kv.second);
if (buf.host_addr != nullptr) {
vkUnmapMemory(buf.device, buf.memory);
}
if (buf.memory != VK_NULL_HANDLE) {
vkFreeMemory(buf.device, buf.memory, nullptr);
}
if (buf.buffer != VK_NULL_HANDLE) {
vkDestroyBuffer(buf.device, buf.buffer, nullptr);
}
}
}
TVMContext ctx;
WorkspacePool pool;
VulkanStream* Stream(size_t device_id);
VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
private:
std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
};
struct VulkanBuffer {
VkBuffer buffer{VK_NULL_HANDLE};
VkDeviceMemory memory{VK_NULL_HANDLE};
};
struct VulkanPipeline {
VulkanContext* vctx_{nullptr};
VkShaderModule shader{VK_NULL_HANDLE};
VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE};
VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};
VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
VkPipeline pipeline{VK_NULL_HANDLE};
VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
};
typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
class VulkanDeviceAPI final : public DeviceAPI {
public:
VulkanDeviceAPI();
~VulkanDeviceAPI() {
for (auto& vctx : context_) {
vkDestroyDevice(vctx.device, nullptr);
}
if (instance_) {
vkDestroyInstance(instance_, nullptr);
}
}
void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
const auto& vctx = context(ctx.device_id);
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = nbytes;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
// create buffer
VkBuffer buffer;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
// bind to memory
VkBufferMemoryRequirementsInfo2KHR req_info2;
req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
req_info2.pNext = 0;
req_info2.buffer = buffer;
VkMemoryRequirements2KHR req2;
req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
req2.pNext = 0;
VkMemoryDedicatedRequirementsKHR dedicated_req;
dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
dedicated_req.pNext = 0;
req2.pNext = &dedicated_req;
bool dedicated_allocation = false;
if (vctx.get_buffer_memory_requirements_2_functions) {
vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
vctx.device, &req_info2, &req2);
dedicated_allocation =
dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
}
VkDeviceMemory memory;
if (!dedicated_allocation) {
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = nbytes;
minfo.memoryTypeIndex = vctx.compute_mtype_index;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
} else {
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = req2.memoryRequirements.size;
minfo.memoryTypeIndex = vctx.compute_mtype_index;
VkMemoryDedicatedAllocateInfoKHR mdinfo;
mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
mdinfo.pNext = 0;
mdinfo.image = 0;
mdinfo.buffer = buffer;
minfo.pNext = &mdinfo;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
}
VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
VulkanBuffer* pbuf = new VulkanBuffer();
pbuf->memory = memory;
pbuf->buffer = buffer;
return pbuf;
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
const auto& vctx = context(ctx.device_id);
auto* pbuf = static_cast<VulkanBuffer*>(ptr);
vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
vkFreeMemory(vctx.device, pbuf->memory, nullptr);
delete pbuf;
}
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
TVMStreamHandle stream) final {
CHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kDLCPU) {
ctx = ctx_to;
}
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
VulkanThreadEntry::ThreadLocal()
->Stream(ctx_from.device_id)
->Launch([=](VulkanStreamState* state) {
// 1: copy
const auto* from_buf = static_cast<const VulkanBuffer*>(from);
auto* to_buf = static_cast<VulkanBuffer*>(to);
VkBufferCopy copy_info;
copy_info.srcOffset = from_offset;
copy_info.dstOffset = to_offset;
copy_info.size = size;
vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, &copy_info);
// 2: barrier(transfer-> compute|transfer)
CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Vulkan disallow cross device copy.";
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
barrier_info.dstAccessMask =
(VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(
state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1,
&barrier_info, 0, nullptr, 0, nullptr);
});
} else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
const auto* from_buf = static_cast<const VulkanBuffer*>(from);
const auto& vctx = context(ctx_from.device_id);
auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_from.device_id, size);
VulkanThreadEntry::ThreadLocal()
->Stream(ctx_from.device_id)
->Launch([&](VulkanStreamState* state) {
VkBufferCopy copy_info;
copy_info.srcOffset = from_offset;
copy_info.dstOffset = 0;
copy_info.size = size;
vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, &copy_info);
});
VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
if (!vctx.coherent_staging) {
VkMappedMemoryRange mrange;
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
mrange.pNext = nullptr;
mrange.memory = temp->memory;
mrange.offset = 0;
mrange.size = VK_WHOLE_SIZE; // size;
VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
}
memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>(temp->host_addr), size);
} else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
const auto& vctx = context(ctx_to.device_id);
const auto* to_buf = static_cast<const VulkanBuffer*>(to);
VulkanStagingBuffer* temp =
VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_to.device_id, size);
memcpy(temp->host_addr, static_cast<const char*>(from) + from_offset, size);
// host side flush if access is not coherent.
// so writes from CPU is visible to GPU
if (!vctx.coherent_staging) {
VkMappedMemoryRange mrange;
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
mrange.pNext = nullptr;
mrange.memory = temp->memory;
mrange.offset = 0;
mrange.size = VK_WHOLE_SIZE; // size;
VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
}
VulkanThreadEntry::ThreadLocal()
->Stream(ctx_from.device_id)
->Launch([&](VulkanStreamState* state) {
// 0: barrier(host->transfer)
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = 0;
barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0,
nullptr);
// 1: copy
VkBufferCopy copy_info;
copy_info.srcOffset = 0;
copy_info.dstOffset = to_offset;
copy_info.size = size;
vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, &copy_info);
});
// TODO(tulloch): should we instead make the staging buffer a property of the
// Stream? This would allow us to elide synchronizations here.
VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
} else {
LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan"
<< ", from=" << from_dev_type << ", to=" << to_dev_type;
}
}
// Always use the default stream
TVMStreamHandle CreateStream(TVMContext ctx) {
LOG(FATAL) << "Not implemented";
return nullptr;
}
void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
LOG(FATAL) << "Not implemented";
return;
}
void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
LOG(FATAL) << "Not implemented";
return;
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
CHECK(stream == nullptr);
VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
}
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
LOG(FATAL) << "Not implemented";
return;
}
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void FreeWorkspace(TVMContext ctx, void* data) final {
VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
static const std::shared_ptr<VulkanDeviceAPI>& Global() {
static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
return inst;
}
const VulkanContext& context(size_t device_id) const {
CHECK_LT(device_id, context_.size());
return context_[device_id];
}
private:
VkInstance instance_{nullptr};
// The physical devices, have 1 to 1 mapping to devices
std::vector<VulkanContext> context_;
};
void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = static_cast<int>(index < context_.size());
return;
}
CHECK_LT(index, context_.size()) << "Invalid device id " << index;
const auto& vctx = context(index);
switch (kind) {
case kMaxThreadsPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
*rv = value;
break;
}
case kMaxSharedMemoryPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
*rv = value;
break;
}
case kWarpSize: {
*rv = 1;
break;
}
case kComputeVersion: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
int64_t value = phy_prop.apiVersion;
std::ostringstream os;
os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
<< VK_VERSION_PATCH(value);
*rv = os.str();
break;
}
case kDeviceName:
return;
case kMaxClockRate:
return;
case kMultiProcessorCount:
return;
case kExist:
break;
case kMaxThreadDimensions:
break;
}
}
VulkanDeviceAPI::VulkanDeviceAPI() {
VkApplicationInfo app_info;
app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
app_info.pNext = nullptr;
app_info.pApplicationName = "TVM";
app_info.applicationVersion = 0;
app_info.pEngineName = "";
app_info.engineVersion = 0;
app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0);
VkInstanceCreateInfo inst_info;
inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
inst_info.pNext = nullptr;
inst_info.flags = 0;
const auto layers = []() -> std::vector<const char*> {
uint32_t inst_layer_prop_count;
VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr));
std::vector<VkLayerProperties> inst_layer_prop(inst_layer_prop_count);
VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data()));
std::vector<const char*> l;
for (const auto& lp : inst_layer_prop) {
// TODO(tulloch): add CMAKE options.
(void)lp; // suppress unused variable warning.
#ifdef USE_VULKAN_VALIDATION
if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) {
l.push_back("VK_LAYER_LUNARG_standard_validation");
}
if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) {
l.push_back("VK_LAYER_LUNARG_parameter_validation");
}
if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) {
l.push_back("VK_LAYER_KHRONOS_validation");
}
#endif
}
return l;
}();
const auto instance_extensions = []() -> std::vector<const char*> {
uint32_t inst_extension_prop_count;
VULKAN_CALL(
vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
std::vector<VkExtensionProperties> inst_extension_prop(inst_extension_prop_count);
VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count,
inst_extension_prop.data()));
std::vector<const char*> extensions;
for (const auto& ip : inst_extension_prop) {
if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) {
extensions.push_back("VK_KHR_get_physical_device_properties2");
}
}
return extensions;
}();
inst_info.pApplicationInfo = &app_info;
inst_info.enabledLayerCount = layers.size();
inst_info.ppEnabledLayerNames = layers.data();
inst_info.enabledExtensionCount = instance_extensions.size();
inst_info.ppEnabledExtensionNames = instance_extensions.data();
VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_));
uint32_t phy_dev_count = 0;
VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr));
std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
for (VkPhysicalDevice phy_dev : all_phy_devs) {
uint32_t queue_prop_count = 0;
vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
dmlc::BeginPtr(queue_props));
uint32_t queue_family_index = 0;
std::vector<VkDeviceQueueCreateInfo> queue_create_info;
float priority = 1.0f;
for (uint32_t i = 0; i < queue_props.size(); i++) {
// find queues that support compute
if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
VkDeviceQueueCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.queueFamilyIndex = i;
info.queueCount = 1;
info.pQueuePriorities = &priority;
queue_create_info.push_back(info);
// only use the first available queue for now
if (queue_create_info.size() == 0) {
queue_family_index = i;
}
}
}
if (queue_create_info.size() == 0) continue;
VulkanContext ctx;
// setup context
ctx.phy_device = phy_dev;
vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
const auto extensions = [&]() {
uint32_t device_extension_prop_count;
VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr,
&device_extension_prop_count, nullptr));
std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data()));
std::vector<const char*> extensions;
for (const auto& dp : device_extension_prop) {
if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) {
extensions.push_back("VK_KHR_push_descriptor");
}
if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) &&
dp.specVersion > 0) {
extensions.push_back("VK_KHR_descriptor_update_template");
}
if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) &&
dp.specVersion > 0) {
extensions.push_back("VK_KHR_get_memory_requirements2");
}
if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) &&
dp.specVersion > 0) {
extensions.push_back("VK_KHR_dedicated_allocation");
}
}
return extensions;
}();
VkDeviceCreateInfo device_create_info;
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.pNext = nullptr;
device_create_info.flags = 0;
device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
device_create_info.pQueueCreateInfos = queue_create_info.data();
device_create_info.enabledLayerCount = 0;
device_create_info.ppEnabledLayerNames = nullptr;
device_create_info.enabledExtensionCount = extensions.size();
device_create_info.ppEnabledExtensionNames = extensions.data();
device_create_info.pEnabledFeatures = nullptr;
VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
ctx.queue_mutex.reset(new std::mutex());
vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
ctx.queue_family_index = queue_family_index;
// Find suitable memory type for staging and compute
// Find suitable compute index.
VkBuffer buffer;
VkMemoryRequirements req_staging, req_compute;
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = 1024;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(ctx.queue_family_index);
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
// get staging requirement
info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging);
vkDestroyBuffer(ctx.device, buffer, nullptr);
// get compute requirement
info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute);
vkDestroyBuffer(ctx.device, buffer, nullptr);
// Query phyiscal device property
// find a memory that is host visible, no need to be consistent
int win_rank = -1;
VkPhysicalDeviceMemoryProperties prop;
vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop);
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
// match copy requirment
if (!(req_staging.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
if (rank > win_rank) {
win_rank = rank;
ctx.staging_mtype_index = k;
ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
}
}
CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
win_rank = -1;
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
// match copy requirment
if (!(req_staging.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
// prefer not host visible
rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
if (rank > win_rank) {
win_rank = rank;
ctx.compute_mtype_index = k;
}
}
CHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
auto has_extension = [&extensions](const char* query) {
return std::any_of(extensions.begin(), extensions.end(),
[&](const char* extension) { return std::strcmp(query, extension) == 0; });
};
#ifdef USE_VULKAN_IMMEDIATE_MODE
if (has_extension("VK_KHR_push_descriptor") &&
has_extension("VK_KHR_descriptor_update_template")) {
ctx.descriptor_template_khr_functions =
std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
new VulkanDescriptorTemplateKHRFunctions());
ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR =
CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
ctx.device, "vkDestroyDescriptorUpdateTemplateKHR"));
ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR =
CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
ctx.device, "vkUpdateDescriptorSetWithTemplateKHR"));
ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR =
CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR"));
}
#endif
#ifdef USE_VULKAN_DEDICATED_ALLOCATION
if (has_extension("VK_KHR_get_memory_requirements2") &&
has_extension("VK_KHR_dedicated_allocation")) {
ctx.get_buffer_memory_requirements_2_functions =
std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>(
new VulkanGetBufferMemoryRequirements2Functions());
ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR =
CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr(
ctx.device, "vkGetBufferMemoryRequirements2KHR"));
}
#endif
context_.push_back(std::move(ctx));
}
LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
for (size_t i = 0; i < context_.size(); ++i) {
LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName
<< "\' phy_dev_id=" << context_[i].phy_device
<< " use_immediate=" << context_[i].UseImmediate();
}
} // namespace vulkan
class VulkanModuleNode;
// a wrapped function class to get packed func.
class VulkanWrappedFunc {
public:
void Init(VulkanModuleNode* m, std::shared_ptr<ModuleNode> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
}
void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
private:
// internal module
VulkanModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
// v The name of the function.
std::string func_name_;
// Number of buffer arguments
size_t num_buffer_args_;
// number of packed arguments.
size_t num_pack_args_;
// Device state cache per device.
// mark as mutable, to enable lazy initialization
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
};
// Multi-device enabled module.
class VulkanModuleNode final : public runtime::ModuleNode {
public:
explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: smap_(smap), fmap_(fmap), source_(source) {}
const char* type_key() const final { return "vulkan"; }
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
VulkanWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
return PackFuncNonBufferArg(std::move(f), info.arg_types);
}
~VulkanModuleNode() {
// cleanup vulkan related caches.
for (int device_id = 0; device_id < ecache_.size(); ++device_id) {
for (auto& kv : ecache_[device_id]) {
auto& pe = kv.second;
CHECK(pe);
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
if (pe->descriptor_update_template != VK_NULL_HANDLE) {
vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR(
vctx.device, pe->descriptor_update_template, nullptr);
}
vkDestroyPipeline(vctx.device, pe->pipeline, nullptr);
vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr);
vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
}
}
}
std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
size_t num_pack_args) {
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
std::lock_guard<std::mutex> lock(mutex_);
const auto& cp = ecache_[device_id][func_name];
if (cp) {
return cp;
}
// Create new pipeline
auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
{
// create shader
auto sit = smap_.find(func_name);
CHECK(sit != smap_.end());
const std::vector<uint32_t>& data = sit->second.data;
VkShaderModuleCreateInfo shader_cinfo;
shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
shader_cinfo.pNext = nullptr;
shader_cinfo.flags = 0;
shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
shader_cinfo.pCode = data.data();
VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader)));
}
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
uint32_t num_pod = 0, num_buffer = 0;
{
auto fit = fmap_.find(func_name);
CHECK(fit != fmap_.end());
for (TVMType arg_type : fit->second.arg_types) {
if (arg_type.code == kHandle) {
{
VkDescriptorSetLayoutBinding bd;
bd.binding = num_buffer;
bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
bd.descriptorCount = 1;
bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
bd.pImmutableSamplers = nullptr;
arg_binding.push_back(bd);
}
{
VkDescriptorUpdateTemplateEntryKHR tpl;
tpl.dstBinding = num_buffer;
tpl.dstArrayElement = 0;
tpl.descriptorCount = 1;
tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
tpl.stride = sizeof(VkDescriptorBufferInfo);
arg_template.push_back(tpl);
}
++num_buffer;
} else {
++num_pod;
}
}
}
{
VkDescriptorSetLayoutCreateInfo descrip_cinfo;
descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
descrip_cinfo.pNext = nullptr;
descrip_cinfo.flags = 0;
if (vctx.UseImmediate()) {
descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
}
descrip_cinfo.bindingCount = arg_binding.size();
descrip_cinfo.pBindings = arg_binding.data();
VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr,
&(pe->descriptor_set_layout)));
}
{
VkDescriptorPoolSize pool_size;
pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
pool_size.descriptorCount = arg_binding.size();
VkDescriptorPoolCreateInfo descrip_pool_cinfo;
descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descrip_pool_cinfo.pNext = nullptr;
descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
descrip_pool_cinfo.maxSets = 1;
descrip_pool_cinfo.poolSizeCount = 1;
descrip_pool_cinfo.pPoolSizes = &pool_size;
VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
&(pe->descriptor_pool)));
}
if (!vctx.UseImmediate()) {
VkDescriptorSetAllocateInfo alloc_info;
alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
alloc_info.pNext = nullptr;
alloc_info.descriptorPool = pe->descriptor_pool;
alloc_info.descriptorSetCount = 1;
alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set)));
}
VkPushConstantRange crange;
crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
crange.offset = 0;
crange.size = sizeof(ArgUnion) * num_pack_args;
VkPipelineLayoutCreateInfo playout_cinfo;
playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
playout_cinfo.pNext = nullptr;
playout_cinfo.flags = 0;
playout_cinfo.setLayoutCount = 1;
playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
if (num_pack_args != 0) {
playout_cinfo.pushConstantRangeCount = 1;
playout_cinfo.pPushConstantRanges = &crange;
CHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
} else {
playout_cinfo.pushConstantRangeCount = 0;
playout_cinfo.pPushConstantRanges = nullptr;
}
VULKAN_CALL(
vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));
VkComputePipelineCreateInfo pipeline_cinfo;
pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
pipeline_cinfo.pNext = nullptr;
pipeline_cinfo.flags = 0;
pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
pipeline_cinfo.stage.pNext = nullptr;
pipeline_cinfo.stage.flags = 0;
pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
pipeline_cinfo.stage.module = pe->shader;
pipeline_cinfo.stage.pName = func_name.c_str();
pipeline_cinfo.stage.pSpecializationInfo = nullptr;
pipeline_cinfo.layout = pe->pipeline_layout;
pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
pipeline_cinfo.basePipelineIndex = 0;
VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
&(pe->pipeline)));
if (vctx.UseImmediate()) {
VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
descrip_template_cinfo.pNext = 0;
descrip_template_cinfo.flags = 0;
descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
descrip_template_cinfo.set = 0;
VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template)));
}
ecache_[device_id][func_name] = pe;
return pe;
}
void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan";
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
std::string data_bin;
dmlc::MemoryStringStream fs(&data_bin);
dmlc::Stream* stream = &fs;
uint32_t magic = kVulkanModuleMagic;
stream->Write(magic);
stream->Write(smap_);
SaveBinaryToFile(file_name, data_bin);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(smap_);
}
std::string GetSource(const std::string& format) final {
// can only return source code.
return source_;
}
private:
// the binary data
std::vector<uint32_t> data_;
// function information table.
std::unordered_map<std::string, VulkanShader> smap_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The format
std::string fmt_{"vulkan"};
// The source
std::string source_;
// Guards accesses to `ecache_`
std::mutex mutex_;
std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
ecache_;
};
Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
std::shared_ptr<VulkanModuleNode> n = std::make_shared<VulkanModuleNode>(smap, fmap, source);
return Module(n);
}
VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }
VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
if (!staging_buffers_[device_id]) {
staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
}
auto& buf = *(staging_buffers_[device_id]);
if (buf.device != nullptr && buf.size < size) {
// free previous buffer
if (buf.host_addr != nullptr) {
vkUnmapMemory(buf.device, buf.memory);
}
if (buf.memory != VK_NULL_HANDLE) {
vkFreeMemory(buf.device, buf.memory, nullptr);
}
if (buf.buffer != VK_NULL_HANDLE) {
vkDestroyBuffer(buf.device, buf.buffer, nullptr);
}
buf.host_addr = nullptr;
buf.memory = VK_NULL_HANDLE;
buf.buffer = VK_NULL_HANDLE;
}
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
if (buf.device == nullptr) {
buf.device = vctx.device;
}
if (buf.memory == VK_NULL_HANDLE) {
// allocate the stagging buffer memory if necessary
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = size;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = size;
minfo.memoryTypeIndex = vctx.staging_mtype_index;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
buf.size = size;
}
memset(buf.host_addr, 0, size);
return &buf;
}
VulkanThreadEntry::VulkanThreadEntry()
: pool(static_cast<DLDeviceType>(kDLVulkan), VulkanDeviceAPI::Global()) {
ctx.device_id = 0;
ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
}
VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
if (!streams_[device_id]) {
streams_[device_id] = std::unique_ptr<VulkanStream>(
new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id)));
}
return streams_[device_id].get();
}
void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
const ArgUnion* pack_args) const {
int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
CHECK_LT(device_id, kVulkanMaxNumDevice);
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
if (!scache_[device_id]) {
scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
}
const auto& pipeline = scache_[device_id];
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
std::vector<VkDescriptorBufferInfo> descriptor_buffers;
descriptor_buffers.resize(num_buffer_args_);
for (int i = 0; i < num_buffer_args_; ++i) {
void* buf = args[static_cast<int>(i)];
VkDescriptorBufferInfo binfo;
binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
binfo.offset = 0;
binfo.range = VK_WHOLE_SIZE;
descriptor_buffers[i] = binfo;
}
if (vctx.UseImmediate()) {
// Can safely capture by reference as this lambda is immediately executed on the calling thread.
VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
CHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
descriptor_buffers.data());
if (num_pack_args_ != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
pack_args);
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1, &barrier_info, 0, nullptr, 0, nullptr);
});
return;
}
// Otherwise, the more expensive deferred path.
std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
std::vector<VkWriteDescriptorSet> write_descriptor_sets;
write_descriptor_sets.resize(descriptor_buffers.size());
for (int i = 0; i < write_descriptor_sets.size(); i++) {
write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write_descriptor_sets[i].pNext = 0;
write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
write_descriptor_sets[i].dstBinding = i;
write_descriptor_sets[i].dstArrayElement = 0;
write_descriptor_sets[i].descriptorCount = 1;
write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
write_descriptor_sets[i].pImageInfo = 0;
write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
write_descriptor_sets[i].pTexelBufferView = 0;
}
vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
0, 0);
};
const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
nullptr);
if (pack_args_storage.size() != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1, &barrier_info, 0, nullptr, 0, nullptr);
};
VulkanStreamToken deferred_token;
deferred_token.descriptor_set_ = pipeline->descriptor_set;
deferred_token.buffers_.resize(descriptor_buffers.size());
for (int i = 0; i < descriptor_buffers.size(); ++i) {
deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
}
VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred(
deferred_initializer, deferred_kernel, deferred_token);
}
Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
std::unordered_map<std::string, VulkanShader> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
dmlc::MemoryStringStream fs(&data);
dmlc::Stream* stream = &fs;
uint32_t magic;
stream->Read(&magic);
CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
stream->Read(&smap);
return VulkanModuleCreate(smap, fmap, "");
}
Module VulkanModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::unordered_map<std::string, VulkanShader> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt;
stream->Read(&fmt);
stream->Read(&fmap);
stream->Read(&smap);
return VulkanModuleCreate(smap, fmap, "");
}
TVM_REGISTER_GLOBAL("module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace vulkan
} // namespace runtime
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -16,26 +16,18 @@
* specific language governing permissions and limitations
* under the License.
*/
#pragma once
/*!
* Copyright (c) 2017 by Contributors
* \file vulkan_common.h
* \brief Vulkan common header
*/
#ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
#define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#include <tvm/runtime/packed_func.h>
#include <vulkan/vulkan.h>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <memory>
#include "../workspace_pool.h"
namespace tvm {
namespace runtime {
......@@ -43,25 +35,44 @@ namespace vulkan {
inline const char* VKGetErrorString(VkResult error) {
switch (error) {
case VK_SUCCESS: return "VK_SUCCESS";
case VK_NOT_READY: return "VK_NOT_READY";
case VK_TIMEOUT: return "VK_TIMEOUT";
case VK_EVENT_SET: return "VK_EVENT_SET";
case VK_EVENT_RESET: return "VK_EVENT_RESET";
case VK_INCOMPLETE: return "VK_INCOMPLETE";
case VK_ERROR_OUT_OF_HOST_MEMORY: return "VK_ERROR_OUT_OF_HOST_MEMORY";
case VK_ERROR_OUT_OF_DEVICE_MEMORY: return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
case VK_ERROR_INITIALIZATION_FAILED: return "VK_ERROR_INITIALIZATION_FAILED";
case VK_ERROR_DEVICE_LOST: return "VK_ERROR_DEVICE_LOST";
case VK_ERROR_MEMORY_MAP_FAILED: return "VK_ERROR_MEMORY_MAP_FAILED";
case VK_ERROR_LAYER_NOT_PRESENT: return "VK_ERROR_LAYER_NOT_PRESENT";
case VK_ERROR_EXTENSION_NOT_PRESENT: return "VK_ERROR_EXTENSION_NOT_PRESENT";
case VK_ERROR_FEATURE_NOT_PRESENT: return "VK_ERROR_FEATURE_NOT_PRESENT";
case VK_ERROR_INCOMPATIBLE_DRIVER: return "VK_ERROR_INCOMPATIBLE_DRIVER";
case VK_ERROR_TOO_MANY_OBJECTS: return "VK_ERROR_TOO_MANY_OBJECTS";
case VK_ERROR_FORMAT_NOT_SUPPORTED: return "VK_ERROR_FORMAT_NOT_SUPPORTED";
case VK_ERROR_FRAGMENTED_POOL: return "VK_ERROR_FRAGMENTED_POOL";
default: return "Unknown Vulkan error code";
case VK_SUCCESS:
return "VK_SUCCESS";
case VK_NOT_READY:
return "VK_NOT_READY";
case VK_TIMEOUT:
return "VK_TIMEOUT";
case VK_EVENT_SET:
return "VK_EVENT_SET";
case VK_EVENT_RESET:
return "VK_EVENT_RESET";
case VK_INCOMPLETE:
return "VK_INCOMPLETE";
case VK_ERROR_OUT_OF_HOST_MEMORY:
return "VK_ERROR_OUT_OF_HOST_MEMORY";
case VK_ERROR_OUT_OF_DEVICE_MEMORY:
return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
case VK_ERROR_INITIALIZATION_FAILED:
return "VK_ERROR_INITIALIZATION_FAILED";
case VK_ERROR_DEVICE_LOST:
return "VK_ERROR_DEVICE_LOST";
case VK_ERROR_MEMORY_MAP_FAILED:
return "VK_ERROR_MEMORY_MAP_FAILED";
case VK_ERROR_LAYER_NOT_PRESENT:
return "VK_ERROR_LAYER_NOT_PRESENT";
case VK_ERROR_EXTENSION_NOT_PRESENT:
return "VK_ERROR_EXTENSION_NOT_PRESENT";
case VK_ERROR_FEATURE_NOT_PRESENT:
return "VK_ERROR_FEATURE_NOT_PRESENT";
case VK_ERROR_INCOMPATIBLE_DRIVER:
return "VK_ERROR_INCOMPATIBLE_DRIVER";
case VK_ERROR_TOO_MANY_OBJECTS:
return "VK_ERROR_TOO_MANY_OBJECTS";
case VK_ERROR_FORMAT_NOT_SUPPORTED:
return "VK_ERROR_FORMAT_NOT_SUPPORTED";
case VK_ERROR_FRAGMENTED_POOL:
return "VK_ERROR_FRAGMENTED_POOL";
default:
return "Unknown Vulkan error code";
}
}
......@@ -69,19 +80,37 @@ inline const char* VKGetErrorString(VkResult error) {
* \brief Protected Vulkan call
* \param func Expression to call.
*/
#define VULKAN_CHECK_ERROR(__e) \
{ \
CHECK(__e == VK_SUCCESS) \
<< "Vulan Error, code=" << __e << ": " << vulkan::VKGetErrorString(__e); \
#define VULKAN_CHECK_ERROR(__e) \
{ \
CHECK(__e == VK_SUCCESS) << "Vulan Error, code=" << __e << ": " \
<< vulkan::VKGetErrorString(__e); \
}
#define VULKAN_CALL(func) \
{ \
VkResult __e = (func); \
VULKAN_CHECK_ERROR(__e); \
#define VULKAN_CALL(func) \
{ \
VkResult __e = (func); \
VULKAN_CHECK_ERROR(__e); \
}
/*! \brief Auxiliary context structure for vulkan */
struct VulkanDescriptorTemplateKHRFunctions {
PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr};
PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr};
PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr};
PFN_vkCmdPushDescriptorSetWithTemplateKHR vkCmdPushDescriptorSetWithTemplateKHR{nullptr};
};
struct VulkanGetBufferMemoryRequirements2Functions {
PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr};
};
struct VulkanStagingBuffer {
VkDevice device{nullptr};
VkBuffer buffer{VK_NULL_HANDLE};
VkDeviceMemory memory{VK_NULL_HANDLE};
void* host_addr{nullptr};
size_t size{0};
};
struct VulkanContext {
// phyiscal device
VkPhysicalDevice phy_device{nullptr};
......@@ -91,211 +120,27 @@ struct VulkanContext {
uint32_t staging_mtype_index{0};
// whether staging is coherent
bool coherent_staging{false};
std::unique_ptr<VulkanDescriptorTemplateKHRFunctions> descriptor_template_khr_functions{nullptr};
std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>
get_buffer_memory_requirements_2_functions{nullptr};
// Memory type index for compute
uint32_t compute_mtype_index{0};
// The logical device
VkDevice device{nullptr};
// command queue
std::unique_ptr<std::mutex> queue_mutex;
VkQueue queue{nullptr};
// queue family_index;
uint32_t queue_family_index{0};
// Queue family index.
VkQueueFamilyProperties queue_prop;
};
/*! \brief The buffer object */
struct VulkanBuffer {
/*! \brief underlying buffer */
VkBuffer buffer{VK_NULL_HANDLE};
/*! \brief underlying buffer */
VkDeviceMemory memory{VK_NULL_HANDLE};
bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
};
/*! \brief Buffer only used for stagging */
struct VulkanStagingBuffer {
/*! \brief the corresponding device */
VkDevice device{nullptr};
/*! \brief underlying buffer */
VkBuffer buffer{VK_NULL_HANDLE};
/*! \brief underlying buffer */
VkDeviceMemory memory{VK_NULL_HANDLE};
/*! \brief host address */
void* host_addr{nullptr};
/*! \brief size of the memory */
size_t size{0};
};
/*!
* \brief Process global Vulkan workspace.
*/
class VulkanWorkspace final : public DeviceAPI {
public:
// global mutex
std::mutex mu;
// whether the workspace it initialized.
bool initialized_{false};
// vulkan instance
VkInstance instance_{nullptr};
// The physical devices, have 1 to 1 mapping to devices
std::vector<VulkanContext> context_;
// Destructor
~VulkanWorkspace();
// Initialize workspace
// Return false if already initialized, otherwise return true.
void Init();
// override device API
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from,
size_t from_size,
void* to,
size_t to_size,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace
static const std::shared_ptr<VulkanWorkspace>& Global();
};
/*! \brief Helper command buffer resource */
struct VulkanCommandBuffer {
/*! \brief fence to signal the resource is ready to use */
VkFence fence{VK_NULL_HANDLE};
/*! \brief The internal command buffer */
VkCommandBuffer cmd_buffer{nullptr};
/*! \brief Descriptor set used to bind arguments */
VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
/*! \brief Internal utilities for write command */
VkWriteDescriptorSet write_descriptor_set;
VulkanCommandBuffer() {
write_descriptor_set.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write_descriptor_set.pNext = nullptr;
write_descriptor_set.dstSet = VK_NULL_HANDLE;
write_descriptor_set.dstBinding = 0;
write_descriptor_set.dstArrayElement = 0;
write_descriptor_set.descriptorCount = 1;
write_descriptor_set.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
write_descriptor_set.pImageInfo = nullptr;
write_descriptor_set.pBufferInfo = nullptr;
write_descriptor_set.pTexelBufferView = nullptr;
}
};
/*!
* \brief Command pool backed by a fixed size ring buffer.
*
* Vulkan requires us not to reuse command buffer until
* All its corresponding jobs have finished.
*
* This class to faciliate automatic management
* of the command buffers. A fence is created
* for each launch of command buffer jobs
* and when we try to reuse the same entry
* in the ring, we need to make sure that
* the previous pending job already finishes.
*
*/
class VulkanCommandPool {
public:
/*! \brief Maximum number of pending jobs in the pool */
static constexpr const int kMaxPending = 4;
/*! \brief Maximum number of pending jobs in the pool */
static constexpr const int kMaxNumArgs = 16;
/*!
* \brief constructor
* \param vctx The corresponding vulkan context.
*/
explicit VulkanCommandPool(const VulkanContext& vctx);
/*! \brief destructor */
~VulkanCommandPool();
/*!
* \brief Allocate a new command buffer entry
*
* The caller must only submit the entry once
* with the given fence in the entry,
* before calling next Alloc.
*
* This function may block to wait for a
* previously unfinished command when
* there is more than kMaxPending jobs.
*
* \returns The allocated entry.
*/
VulkanCommandBuffer* Alloc();
/*!
* \brief Allocate a new command buffer entry
* \param dlayout the descriptor layout.
*
* \returns The allocated entry.
*/
VulkanCommandBuffer* Alloc(const VkDescriptorSetLayout* dlayout);
private:
/*! \brief Local ring buffer */
std::vector<VulkanCommandBuffer> ring_;
/*! \brief clock pointer */
size_t clock_ptr_{0};
/*! \brief the corresponding device*/
VkDevice device_{nullptr};
/*! \brief internal command buffer pool */
VkCommandPool cmd_pool_{VK_NULL_HANDLE};
/*! \brief Descriptor pool */
VkDescriptorPool descriptor_pool_{VK_NULL_HANDLE};
};
/*! \brief Thread local workspace */
class VulkanThreadEntry {
public:
/*! \brief The current context */
TVMContext context;
/*! \brief workspace pool */
WorkspacePool pool;
/*! \brief The staging buffers */
std::vector<VulkanStagingBuffer> staging_buffer_;
/*!
* \brief Get the command pool of corresponding device;
* \param device_id The device id
* \return The corresponding command buffer.
*/
VulkanCommandPool* CommandPool(int device_id);
/*!
* \brief Get the stagging buffer.
* \param device_id The device id
* \return The corresponding stagging buffer.
*/
VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
// constructor
VulkanThreadEntry()
: pool(static_cast<DLDeviceType>(kDLVulkan), VulkanWorkspace::Global()) {
context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kDLVulkan);
}
~VulkanThreadEntry();
// get the global workspace
static VulkanThreadEntry* ThreadLocal();
private:
/*! \brief the command pools */
std::vector<std::unique_ptr<VulkanCommandPool> > pool_;
};
// inline implementation
} // namespace vulkan
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2017 by Contributors
* \file vulkan_device_api.cc
*/
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include <cstring>
#include "vulkan_common.h"
namespace tvm {
namespace runtime {
namespace vulkan {
VulkanWorkspace::~VulkanWorkspace() {
for (VulkanContext& ctx : context_) {
vkDestroyDevice(ctx.device, nullptr);
}
if (instance_ != nullptr) {
vkDestroyInstance(instance_, nullptr);
}
}
const std::shared_ptr<VulkanWorkspace>& VulkanWorkspace::Global() {
static std::shared_ptr<VulkanWorkspace> inst = std::make_shared<VulkanWorkspace>();
return inst;
}
void VulkanWorkspace::SetDevice(TVMContext ctx) {
VulkanThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
}
void VulkanWorkspace::GetAttr(
TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = static_cast<int>(index< context_.size());
return;
}
CHECK_LT(index, context_.size())
<< "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
*rv = value;
break;
}
case kMaxSharedMemoryPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
*rv = value;
break;
}
case kWarpSize: {
*rv = 1;
break;
}
case kComputeVersion: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
int64_t value = phy_prop.apiVersion;
std::ostringstream os;
os << VK_VERSION_MAJOR(value)
<< "." << VK_VERSION_MINOR(value)
<< "." << VK_VERSION_PATCH(value);
*rv = os.str();
break;
}
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kExist: break;
case kMaxThreadDimensions: break;
}
}
void* VulkanWorkspace::AllocDataSpace(
TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) {
this->Init();
VulkanContext& vctx = context_[ctx.device_id];
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = size;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
// create buffer
VkBuffer buffer;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
// bind to memory
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = size;
minfo.memoryTypeIndex = vctx.compute_mtype_index;
VkDeviceMemory memory;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
VulkanBuffer* pbuf = new VulkanBuffer();
pbuf->memory = memory;
pbuf->buffer = buffer;
return pbuf;
}
void VulkanWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
VulkanContext& vctx = context_[ctx.device_id];
VulkanBuffer* pbuf = static_cast<VulkanBuffer*>(ptr);
vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
vkFreeMemory(vctx.device, pbuf->memory, nullptr);
delete pbuf;
}
void VulkanWorkspace::CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) {
this->Init();
CHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
VulkanThreadEntry* tls = VulkanThreadEntry::ThreadLocal();
VulkanCommandBuffer* cmd = tls->CommandPool(ctx.device_id)->Alloc();
VkCommandBufferBeginInfo cb_begin;
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
cb_begin.pNext = nullptr;
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
cb_begin.pInheritanceInfo = 0;
VkSubmitInfo cb_submit;
cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
cb_submit.pNext = nullptr;
cb_submit.waitSemaphoreCount = 0;
cb_submit.pWaitSemaphores = nullptr;
cb_submit.pWaitDstStageMask = 0;
cb_submit.commandBufferCount = 1;
cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
cb_submit.signalSemaphoreCount = 0;
cb_submit.pSignalSemaphores = nullptr;
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
<< "Vulkan disallow cross device copy.";
const VulkanContext& vctx = context_[ctx_from.device_id];
const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
VulkanBuffer* to_buf = static_cast<VulkanBuffer*>(to);
// The assumption is that subsequence ops only perform compute/transfer
// 0: begin
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 1: copy
VkBufferCopy copy_info;
copy_info.srcOffset = from_offset;
copy_info.dstOffset = to_offset;
copy_info.size = size;
vkCmdCopyBuffer(cmd->cmd_buffer, from_buf->buffer, to_buf->buffer, 1, &copy_info);
// 2: barrier(transfer-> compute|transfer)
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
barrier_info.dstAccessMask =
(VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(
cmd->cmd_buffer,
VK_PIPELINE_STAGE_TRANSFER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
0, 1, &barrier_info, 0, nullptr, 0, nullptr);
// 3: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
} else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
const VulkanContext& vctx = context_[ctx_from.device_id];
const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_from.device_id, size);
// 0: begin
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 1: copy
VkBufferCopy copy_info;
copy_info.srcOffset = from_offset;
copy_info.dstOffset = 0;
copy_info.size = size;
vkCmdCopyBuffer(cmd->cmd_buffer,
from_buf->buffer,
temp->buffer,
1, &copy_info);
// 2: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
// Block until done, to make sure temp can be reused later.
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
// host side invalidation if access is not coherent.
// so writes from GPU is visible to CPU
if (!vctx.coherent_staging) {
VkMappedMemoryRange mrange;
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
mrange.pNext = nullptr;
mrange.memory = temp->memory;
mrange.offset = 0;
mrange.size = size;
VULKAN_CALL(vkInvalidateMappedMemoryRanges(
vctx.device, 1, &mrange));
}
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>(temp->host_addr),
size);
} else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
const VulkanContext& vctx = context_[ctx_to.device_id];
const VulkanBuffer* to_buf = static_cast<const VulkanBuffer*>(to);
VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_to.device_id, size);
memcpy(temp->host_addr,
static_cast<const char*>(from) + from_offset,
size);
// host side flush if access is not coherent.
// so writes from CPU is visible to GPU
if (!vctx.coherent_staging) {
VkMappedMemoryRange mrange;
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
mrange.pNext = nullptr;
mrange.memory = temp->memory;
mrange.offset = 0;
mrange.size = size;
VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
}
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 0: barrier(host->transfer)
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = 0;
barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
vkCmdPipelineBarrier(cmd->cmd_buffer,
VK_PIPELINE_STAGE_HOST_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT,
0, 1, &barrier_info,
0, nullptr, 0, nullptr);
// 1: copy
VkBufferCopy copy_info;
copy_info.srcOffset = 0;
copy_info.dstOffset = to_offset;
copy_info.size = size;
vkCmdCopyBuffer(cmd->cmd_buffer,
temp->buffer,
to_buf->buffer,
1, &copy_info);
// 2: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
// wait until copy finishes, so we can reuse temp next time.
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type
<< ", to=" << to_dev_type;
}
}
void VulkanWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
CHECK(stream == nullptr);
VulkanContext& vctx = context_[ctx.device_id];
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
}
void* VulkanWorkspace::AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) {
return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void VulkanWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
// VulkanCommandPool
VulkanCommandPool::VulkanCommandPool(const VulkanContext& vctx) {
ring_.resize(kMaxPending, VulkanCommandBuffer());
device_ = vctx.device;
{
// create command pool
VkCommandPoolCreateInfo cmd_pool_cinfo;
cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
cmd_pool_cinfo.pNext = nullptr;
cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
cmd_pool_cinfo.queueFamilyIndex = vctx.queue_family_index;
VULKAN_CALL(vkCreateCommandPool(device_, &cmd_pool_cinfo, nullptr, &cmd_pool_));
}
{
// create descriptor pool
VkDescriptorPoolSize pool_size;
pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
pool_size.descriptorCount = kMaxPending * kMaxNumArgs;
VkDescriptorPoolCreateInfo descrip_pool_cinfo;
descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descrip_pool_cinfo.pNext = nullptr;
descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
descrip_pool_cinfo.maxSets = kMaxPending + 2;
descrip_pool_cinfo.poolSizeCount = 1;
descrip_pool_cinfo.pPoolSizes = &pool_size;
VULKAN_CALL(vkCreateDescriptorPool(
device_, &descrip_pool_cinfo, nullptr, &descriptor_pool_));
}
VkCommandBufferAllocateInfo buffer_alloc_info;
buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
buffer_alloc_info.pNext = nullptr;
buffer_alloc_info.commandPool = cmd_pool_;
buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
buffer_alloc_info.commandBufferCount = 1;
VkFenceCreateInfo fence_cinfo;
fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
fence_cinfo.pNext = nullptr;
fence_cinfo.flags = VK_FENCE_CREATE_SIGNALED_BIT;
for (size_t i = 0; i < ring_.size(); ++i) {
VULKAN_CALL(vkAllocateCommandBuffers(
device_, &buffer_alloc_info, &(ring_[i].cmd_buffer)));
VULKAN_CALL(vkCreateFence(
device_, &fence_cinfo, nullptr, &(ring_[i].fence)));
}
}
VulkanCommandPool::~VulkanCommandPool() {
// wait device to be idle so we know we can recycle buffers
VULKAN_CALL(vkDeviceWaitIdle(device_));
// start recycling.
for (size_t i = 0; i < ring_.size(); ++i) {
if (ring_[i].cmd_buffer != nullptr) {
vkFreeCommandBuffers(device_, cmd_pool_, 1, &(ring_[i].cmd_buffer));
ring_[i].cmd_buffer = nullptr;
}
if (ring_[i].fence != VK_NULL_HANDLE) {
vkDestroyFence(device_, ring_[i].fence, nullptr);
}
}
// delete cmd_pool and descriptor pool
vkDestroyCommandPool(device_, cmd_pool_, nullptr);
vkDestroyDescriptorPool(device_, descriptor_pool_, nullptr);
}
VulkanCommandBuffer* VulkanCommandPool::Alloc() {
return Alloc(nullptr);
}
VulkanCommandBuffer* VulkanCommandPool::Alloc(
const VkDescriptorSetLayout* dlayout) {
// always allocate resource in round robin manner
VulkanCommandBuffer* e = &(ring_[clock_ptr_]);
clock_ptr_ = (clock_ptr_ + 1) % ring_.size();
// Wait until previous usage of commad buffer is finished.
uint64_t timeout = 1UL << 30UL;
VkResult res;
res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
while (res == VK_TIMEOUT) {
res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
}
VULKAN_CHECK_ERROR(res);
vkResetFences(device_, 1, (&e->fence));
if (e->descriptor_set != VK_NULL_HANDLE) {
VULKAN_CALL(vkFreeDescriptorSets(
device_, descriptor_pool_, 1, &(e->descriptor_set)));
e->descriptor_set = VK_NULL_HANDLE;
}
if (dlayout != nullptr) {
VkDescriptorSetAllocateInfo alloc_info;
alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
alloc_info.pNext = nullptr;
alloc_info.descriptorPool = descriptor_pool_;
alloc_info.descriptorSetCount = 1;
alloc_info.pSetLayouts = dlayout;
VULKAN_CALL(vkAllocateDescriptorSets(
device_, &alloc_info, &(e->descriptor_set)));
}
return e;
}
// VulkanThreadEntry
typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() {
return VulkanThreadStore::Get();
}
VulkanCommandPool* VulkanThreadEntry::CommandPool(int device_id) {
while (pool_.size() <= static_cast<size_t>(device_id)) {
pool_.emplace_back(std::unique_ptr<VulkanCommandPool>());
}
if (pool_[device_id] == nullptr) {
const VulkanContext& vctx =
VulkanWorkspace::Global()->context_[device_id];
pool_[device_id].reset(new VulkanCommandPool(vctx));
}
return pool_[device_id].get();
}
VulkanStagingBuffer*
VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
if (staging_buffer_.size() <= static_cast<size_t>(device_id)) {
staging_buffer_.resize(device_id + 1, VulkanStagingBuffer());
}
VulkanStagingBuffer& buf = staging_buffer_[device_id];
if (buf.device != nullptr && buf.size < size) {
// free previous buffer
if (buf.host_addr != nullptr) {
vkUnmapMemory(buf.device, buf.memory);
}
if (buf.memory != VK_NULL_HANDLE) {
vkFreeMemory(buf.device, buf.memory, nullptr);
}
if (buf.buffer != VK_NULL_HANDLE) {
vkDestroyBuffer(buf.device, buf.buffer, nullptr);
}
buf.host_addr = nullptr;
buf.memory = VK_NULL_HANDLE;
buf.buffer = VK_NULL_HANDLE;
}
const VulkanContext& vctx =
VulkanWorkspace::Global()->context_[device_id];
if (buf.device == nullptr) {
buf.device = vctx.device;
}
if (buf.memory == VK_NULL_HANDLE) {
// allocate the stagging buffer memory if necessary
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = size;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = size;
minfo.memoryTypeIndex = vctx.staging_mtype_index;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
buf.size = size;
}
memset(buf.host_addr, 0, size);
return &buf;
}
VulkanThreadEntry::~VulkanThreadEntry() {
// Because the thread entry refers to Device API
// The command buffer always will be destroyed before
// the instance and device get destroyed.
// The destruction need to be manually called
// to ensure the destruction order.
pool_.clear();
for (VulkanStagingBuffer buf : staging_buffer_) {
if (buf.host_addr != nullptr) {
vkUnmapMemory(buf.device, buf.memory);
}
if (buf.memory != VK_NULL_HANDLE) {
vkFreeMemory(buf.device, buf.memory, nullptr);
}
if (buf.buffer != VK_NULL_HANDLE) {
vkDestroyBuffer(buf.device, buf.buffer, nullptr);
}
}
}
VkInstance CreateInstance() {
VkApplicationInfo app_info;
app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
app_info.pNext = nullptr;
app_info.pApplicationName = "TVM";
app_info.applicationVersion = 0;
app_info.pEngineName = "";
app_info.engineVersion = 0;
app_info.apiVersion = VK_MAKE_VERSION(1, 0, 65);
VkInstanceCreateInfo inst_info;
inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
inst_info.pNext = nullptr;
inst_info.flags = 0;
inst_info.pApplicationInfo = &app_info;
inst_info.enabledLayerCount = 0;
inst_info.ppEnabledLayerNames = nullptr;
inst_info.enabledExtensionCount = 0;
inst_info.ppEnabledExtensionNames = nullptr;
VkInstance inst;
VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &inst));
return inst;
}
// find suitable mem_type_index for staging and compute
void FindMemoryTypeIndex(VulkanContext* vctx) {
// Find suitable compute index.
VkBuffer buffer;
VkMemoryRequirements req_staging, req_compute;
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = 1024;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx->queue_family_index);
// get staging requirement
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT;
VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(vctx->device, buffer, &req_staging);
vkDestroyBuffer(vctx->device, buffer, nullptr);
// get compute requirement
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(vctx->device, buffer, &req_compute);
vkDestroyBuffer(vctx->device, buffer, nullptr);
// Query phyiscal device property
// find a memory that is host visible, no need to be consistent
int win_rank = -1;
VkPhysicalDeviceMemoryProperties prop;
vkGetPhysicalDeviceMemoryProperties(vctx->phy_device, &prop);
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
// match copy requirment
if (!(req_staging.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
if (rank > win_rank) {
win_rank = rank;
vctx->staging_mtype_index = k;
vctx->coherent_staging =
ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
}
}
CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
win_rank = -1;
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
// match copy requirment
if (!(req_staging.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
// prefer not host visible
rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
if (rank > win_rank) {
win_rank = rank;
vctx->compute_mtype_index = k;
}
}
CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
}
// Get all logic devices that support compute
std::vector<VulkanContext> GetContext(VkInstance instance) {
std::vector<VulkanContext> result;
uint32_t phy_dev_count = 0;
VULKAN_CALL(vkEnumeratePhysicalDevices(
instance, &phy_dev_count, nullptr));
std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
VULKAN_CALL(vkEnumeratePhysicalDevices(
instance, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
for (VkPhysicalDevice phy_dev : all_phy_devs) {
uint32_t queue_prop_count = 0;
vkGetPhysicalDeviceQueueFamilyProperties(
phy_dev, &queue_prop_count, nullptr);
std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
vkGetPhysicalDeviceQueueFamilyProperties(
phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props));
uint32_t queue_family_index = 0;
std::vector<VkDeviceQueueCreateInfo> queue_create_info;
for (uint32_t i = 0; i < queue_props.size(); i++) {
// find queues that support compute
if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
float priority = 1.0f;
VkDeviceQueueCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.queueFamilyIndex = i;
info.queueCount = 1;
info.pQueuePriorities = &priority;
queue_create_info.push_back(info);
// only use the first available queue for now
if (queue_create_info.size() == 0) {
queue_family_index = i;
}
}
}
if (queue_create_info.size() == 0) continue;
VkDeviceCreateInfo device_create_info;
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.pNext = nullptr;
device_create_info.flags = 0;
device_create_info.queueCreateInfoCount
= static_cast<uint32_t>(queue_create_info.size());
device_create_info.pQueueCreateInfos = queue_create_info.data();
device_create_info.enabledLayerCount = 0;
device_create_info.ppEnabledLayerNames = nullptr;
device_create_info.enabledExtensionCount = 0;
device_create_info.ppEnabledExtensionNames = nullptr;
device_create_info.pEnabledFeatures = nullptr;
VulkanContext ctx;
// setup context
ctx.phy_device = phy_dev;
vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
VULKAN_CALL(vkCreateDevice(
phy_dev, &device_create_info, nullptr, &(ctx.device)));
vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
ctx.queue_family_index = queue_family_index;
FindMemoryTypeIndex(&ctx);
// Find suitable memory type for staging and compute
result.push_back(ctx);
}
return result;
}
void VulkanWorkspace::Init() {
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return;
initialized_ = true;
try {
instance_ = CreateInstance();
context_ = GetContext(instance_);
LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
for (size_t i = 0; i < context_.size(); ++i) {
LOG(INFO) << "vulkan(" << i
<< ")=\'" << context_[i].phy_device_prop.deviceName
<< "\' phy_dev_id=" << context_[i].phy_device;
}
} catch (const dmlc::Error& err) {
LOG(INFO) << "Cannot initialize vulkan: " << err.what() << "\n"
<< "You can still compile vulkan module but cannot run locally";
}
}
bool InitVulkan(TVMArgs args, TVMRetValue* rv) {
vulkan::VulkanWorkspace::Global()->Init();
return true;
}
TVM_REGISTER_GLOBAL("device_api.vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VulkanWorkspace::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace vulkan
} // namespace runtime
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file vulkan_module.cc
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <array>
#include <string>
#include <mutex>
#include "vulkan_common.h"
#include "vulkan_module.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
namespace tvm {
namespace runtime {
void VulkanShader::Save(dmlc::Stream* writer) const {
writer->Write(flag);
writer->Write(data);
}
bool VulkanShader::Load(dmlc::Stream* reader) {
if (!reader->Read(&flag)) return false;
if (!reader->Read(&data)) return false;
return true;
}
// Multi-device enabled module.
class VulkanModuleNode final :public runtime::ModuleNode {
public:
// Pipeline cache states
struct PipelineEntry {
VkShaderModule shader{VK_NULL_HANDLE};
VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
VkDescriptorSetLayout descriptor_layout{VK_NULL_HANDLE};
VkPipeline pipeline{VK_NULL_HANDLE};
};
// constructor
explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source)
: smap_(smap), fmap_(fmap), source_(source) {
}
~VulkanModuleNode() {
// cleanup vulkan related caches.
for (DeviceEntry& e : finfo_) {
if (e.device == nullptr) continue;
for (auto &kv : e.smap) {
PipelineEntry& pe = kv.second;
vkDestroyShaderModule(e.device, pe.shader, nullptr);
vkDestroyDescriptorSetLayout(e.device, pe.descriptor_layout, nullptr);
vkDestroyPipelineLayout(e.device, pe.pipeline_layout, nullptr);
vkDestroyPipeline(e.device, pe.pipeline, nullptr);
}
}
}
const char* type_key() const final {
return "vulkan";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to customized format vulkan";
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
std::string data_bin;
dmlc::MemoryStringStream fs(&data_bin);
dmlc::Stream* stream = &fs;
uint32_t magic = kVulkanModuleMagic;
stream->Write(magic);
stream->Write(smap_);
SaveBinaryToFile(file_name, data_bin);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(smap_);
}
std::string GetSource(const std::string& format) final {
// can only return source code.
return source_;
}
// get a from primary context in device_id
PipelineEntry GetPipeline(size_t device_id,
const std::string& func_name,
size_t num_pack_args) {
vulkan::VulkanWorkspace* w = vulkan::VulkanWorkspace::Global().get();
CHECK_LT(device_id, w->context_.size());
// start lock scope.
std::lock_guard<std::mutex> lock(mutex_);
if (finfo_.size() <= device_id) {
finfo_.resize(device_id + 1, DeviceEntry());
}
DeviceEntry& e = finfo_[device_id];
auto it = e.smap.find(func_name);
if (it != e.smap.end()) return it->second;
PipelineEntry pe;
if (e.device == nullptr) {
e.device = w->context_[device_id].device;
}
{
// create shader
auto sit = smap_.find(func_name);
CHECK(sit != smap_.end());
const std::vector<uint32_t>& data = sit->second.data;
VkShaderModuleCreateInfo shader_cinfo;
shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
shader_cinfo.pNext = nullptr;
shader_cinfo.flags = 0;
shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
shader_cinfo.pCode = data.data();
VULKAN_CALL(vkCreateShaderModule(
e.device, &shader_cinfo, nullptr, &(pe.shader)));
}
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
uint32_t num_pod = 0, num_buffer = 0;
{
auto fit = fmap_.find(func_name);
CHECK(fit != fmap_.end());
for (TVMType arg_type : fit->second.arg_types) {
if (arg_type.code == kHandle) {
VkDescriptorSetLayoutBinding bd;
bd.binding = num_buffer;
bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
bd.descriptorCount = 1;
bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
bd.pImmutableSamplers = nullptr;
arg_binding.push_back(bd);
++num_buffer;
} else {
++num_pod;
}
}
}
VkDescriptorSetLayoutCreateInfo descrip_cinfo;
descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
descrip_cinfo.pNext = nullptr;
descrip_cinfo.flags = 0;
descrip_cinfo.bindingCount = arg_binding.size();
descrip_cinfo.pBindings = arg_binding.data();
VULKAN_CALL(vkCreateDescriptorSetLayout(
e.device, &descrip_cinfo, nullptr, &(pe.descriptor_layout)));
VkPushConstantRange crange;
crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
crange.offset = 0;
crange.size = sizeof(ArgUnion) * num_pack_args;
VkPipelineLayoutCreateInfo playout_cinfo;
playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
playout_cinfo.pNext = nullptr;
playout_cinfo.flags = 0;
playout_cinfo.setLayoutCount = 1;
playout_cinfo.pSetLayouts = &(pe.descriptor_layout);
if (num_pack_args != 0) {
playout_cinfo.pushConstantRangeCount = 1;
playout_cinfo.pPushConstantRanges = &crange;
CHECK_LE(crange.size,
w->context_[device_id].phy_device_prop.limits.maxPushConstantsSize);
} else {
playout_cinfo.pushConstantRangeCount = 0;
playout_cinfo.pPushConstantRanges = nullptr;
}
VULKAN_CALL(vkCreatePipelineLayout(
e.device, &playout_cinfo, nullptr, &(pe.pipeline_layout)));
VkComputePipelineCreateInfo pipeline_cinfo;
pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
pipeline_cinfo.pNext = nullptr;
pipeline_cinfo.flags = 0;
pipeline_cinfo.stage.sType =
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
pipeline_cinfo.stage.pNext = nullptr;
pipeline_cinfo.stage.flags = 0;
pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
pipeline_cinfo.stage.module = pe.shader;
pipeline_cinfo.stage.pName = func_name.c_str();
pipeline_cinfo.stage.pSpecializationInfo = nullptr;
pipeline_cinfo.layout = pe.pipeline_layout;
pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
pipeline_cinfo.basePipelineIndex = 0;
VULKAN_CALL(vkCreateComputePipelines(
e.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe.pipeline)));
e.smap[func_name] = pe;
return pe;
}
private:
// device specific entry
struct DeviceEntry {
VkDevice device{nullptr};
std::unordered_map<std::string, PipelineEntry> smap;
};
// the binary data
std::vector<uint32_t> data_;
// function information table.
std::unordered_map<std::string, VulkanShader> smap_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The format
std::string fmt_{"vulkan"};
// The source
std::string source_;
// device local pipeline information.
std::vector<DeviceEntry> finfo_;
// internal mutex when updating the module
std::mutex mutex_;
};
// a wrapped function class to get packed func.
class VulkanWrappedFunc {
public:
// initialize the VULKAN function.
void Init(VulkanModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
const std::string& func_name,
size_t num_buffer_args,
size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
w_ = vulkan::VulkanWorkspace::Global().get();
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args,
TVMRetValue* rv,
const ArgUnion* pack_args) const {
vulkan::VulkanThreadEntry* tls = vulkan::VulkanThreadEntry::ThreadLocal();
int device_id = tls->context.device_id;
CHECK_LT(device_id, kVulkanMaxNumDevice);
const vulkan::VulkanContext& vctx = w_->context_[device_id];
VulkanModuleNode::PipelineEntry& pe = scache_[device_id];
if (pe.pipeline == VK_NULL_HANDLE) {
pe = m_->GetPipeline(device_id, func_name_, num_pack_args_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
vulkan::VulkanCommandBuffer* cmd = tls->CommandPool(device_id)->Alloc(
&(pe.descriptor_layout));
cmd->write_descriptor_set.dstSet = cmd->descriptor_set;
// setup descriptors
for (uint32_t i = 0; i < num_buffer_args_; ++i) {
void* buf = args[static_cast<int>(i)];
VkDescriptorBufferInfo binfo;
binfo.buffer = static_cast<vulkan::VulkanBuffer*>(buf)->buffer;
binfo.offset = 0;
binfo.range = VK_WHOLE_SIZE;
cmd->write_descriptor_set.dstBinding = i;
cmd->write_descriptor_set.pBufferInfo = &binfo;
vkUpdateDescriptorSets(
vctx.device, 1, &(cmd->write_descriptor_set), 0, nullptr);
}
// dispatch
VkCommandBufferBeginInfo cb_begin;
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
cb_begin.pNext = nullptr;
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
cb_begin.pInheritanceInfo = 0;
VkSubmitInfo cb_submit;
cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
cb_submit.pNext = nullptr;
cb_submit.waitSemaphoreCount = 0;
cb_submit.pWaitSemaphores = nullptr;
cb_submit.pWaitDstStageMask = 0;
cb_submit.commandBufferCount = 1;
cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
cb_submit.signalSemaphoreCount = 0;
cb_submit.pSignalSemaphores = nullptr;
// 0: begin
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 1: dispatch
vkCmdBindPipeline(
cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pe.pipeline);
vkCmdBindDescriptorSets(
cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
pe.pipeline_layout, 0, 1, &(cmd->descriptor_set), 0, nullptr);
// bind push constant if necessary
if (num_pack_args_ != 0) {
vkCmdPushConstants(
cmd->cmd_buffer,
pe.pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT,
0, num_pack_args_ * sizeof(ArgUnion),
pack_args);
}
vkCmdDispatch(
cmd->cmd_buffer, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
// 2: barrier(compute->compute|transfer)
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask =
VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
barrier_info.dstAccessMask =
(VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(
cmd->cmd_buffer,
VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
0, 1, &barrier_info, 0, nullptr, 0, nullptr);
// 3: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
}
private:
// Reference to global workspace.
vulkan::VulkanWorkspace* w_;
// internal module
VulkanModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
// The name of the function.
std::string func_name_;
// Number of buffer arguments
size_t num_buffer_args_;
// number of packed arguments.
size_t num_pack_args_;
// Device state cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<VulkanModuleNode::PipelineEntry, kVulkanMaxNumDevice> scache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
};
PackedFunc VulkanModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
VulkanWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name,
num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
return PackFuncNonBufferArg(f, info.arg_types);
}
Module VulkanModuleCreate(
std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
vulkan::VulkanWorkspace::Global()->Init();
std::shared_ptr<VulkanModuleNode> n =
std::make_shared<VulkanModuleNode>(smap, fmap, source);
return Module(n);
}
// Load module from module.
Module VulkanModuleLoadFile(const std::string& file_name,
const std::string& format) {
std::string data;
std::unordered_map<std::string, VulkanShader> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
dmlc::MemoryStringStream fs(&data);
dmlc::Stream* stream = &fs;
uint32_t magic;
stream->Read(&magic);
CHECK_EQ(magic, kVulkanModuleMagic)
<< "VulkanModule Magic mismatch";
stream->Read(&smap);
return VulkanModuleCreate(smap, fmap, "");
}
Module VulkanModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::unordered_map<std::string, VulkanShader> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt;
stream->Read(&fmt);
stream->Read(&fmap);
stream->Read(&smap);
return VulkanModuleCreate(smap, fmap, "");
}
TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
.set_body_typed(VulkanModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
.set_body_typed(VulkanModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -16,67 +16,22 @@
* specific language governing permissions and limitations
* under the License.
*/
#pragma once
/*!
* Copyright (c) 2017 by Contributors
* \file metal_module.h
* \brief Execution handling of Metal kernels
*/
#ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#include <tvm/runtime/packed_func.h>
#include <dmlc/type_traits.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "../meta_data.h"
#include "vulkan_shader.h"
namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPU supported in VulkanModule. */
static constexpr const int kVulkanMaxNumDevice = 8;
/*! \brief TVM Vulkan binary pack magic number */
static constexpr const int kVulkanModuleMagic = 0x02700027;
/*!
* \brief A single VK shader program
*
* Due to the global resource declaration.
* Current SPIRV only allows one entry program per shader,
* making it less useful for a Module like system.
*
* Instead we pass in map of str->VulkanShader until
* there is a native solution available.
*/
struct VulkanShader {
/*! \brief header flag */
uint32_t flag{0};
/*! \brief Data segment */
std::vector<uint32_t> data;
namespace vulkan {
Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
void Save(dmlc::Stream *writer) const;
bool Load(dmlc::Stream *reader);
};
} // namespace vulkan
/*!
* \brief create a metal module from data.
*
* \param pmap The program map.
* \param fmap The function information map.
* \param source Optional, source code.
*/
Module VulkanModuleCreate(
std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source);
using vulkan::VulkanModuleCreate;
} // namespace runtime
} // namespace tvm
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::VulkanShader, true);
} // namespace dmlc
#endif // TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#pragma once
#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/packed_func.h>
#include <vector>
namespace tvm {
namespace runtime {
namespace vulkan {
struct VulkanShader {
/*! \brief header flag */
uint32_t flag{0};
/*! \brief Data segment */
std::vector<uint32_t> data;
void Save(dmlc::Stream* writer) const {
writer->Write(flag);
writer->Write(data);
}
bool Load(dmlc::Stream* reader) {
if (!reader->Read(&flag)) return false;
if (!reader->Read(&data)) return false;
return true;
}
};
} // namespace vulkan
using vulkan::VulkanShader;
} // namespace runtime
} // namespace tvm
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::VulkanShader, true);
} // namespace dmlc
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#pragma once
#include <functional>
#include <memory>
#include <vector>
#include "vulkan_common.h"
namespace tvm {
namespace runtime {
namespace vulkan {
class VulkanStreamState {
public:
VkCommandBuffer cmd_buffer_;
VkFence fence_;
};
// Used to identify state that should only be used once-per-stream.
struct VulkanStreamToken {
VkDescriptorSet descriptor_set_{VK_NULL_HANDLE};
std::vector<VkBuffer> buffers_;
};
class VulkanStream {
public:
explicit VulkanStream(const VulkanContext* vctx)
: vctx_(vctx), state_(new VulkanStreamState()) {
// create command pool
VkCommandPoolCreateInfo cmd_pool_cinfo;
cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
cmd_pool_cinfo.pNext = nullptr;
cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index;
VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_));
VkCommandBufferAllocateInfo buffer_alloc_info;
buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
buffer_alloc_info.pNext = nullptr;
buffer_alloc_info.commandPool = cmd_pool_;
buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
buffer_alloc_info.commandBufferCount = 1;
VULKAN_CALL(
vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_)));
VkFenceCreateInfo fence_cinfo;
fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
fence_cinfo.pNext = nullptr;
fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT;
VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_)));
VkCommandBufferBeginInfo cb_begin;
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
cb_begin.pNext = nullptr;
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
cb_begin.pInheritanceInfo = 0;
VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));
}
~VulkanStream() {
vkDestroyFence(vctx_->device, state_->fence_, nullptr);
vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr);
}
// Launch the kernel on the current stream.
void Launch(const std::function<void(VulkanStreamState*)>& kernel) {
if (vctx_->UseImmediate()) {
kernel(state_.get());
} else {
deferred_kernels_.push_back(kernel);
}
}
// Launch the kernel on the current stream,
void LaunchDeferred(const std::function<void()>& deferred_initializer,
const std::function<void(VulkanStreamState*)>& deferred_kernel,
const VulkanStreamToken& deferred_token) {
CHECK(!vctx_->UseImmediate());
// It is invalid to schedule this instance on the current stream if we already
// have a matching descriptor set and a non-matching buffer set.
if (std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(),
[&](const VulkanStreamToken& token) {
return token.descriptor_set_ == deferred_token.descriptor_set_ &&
token.buffers_ != deferred_token.buffers_;
})) {
Synchronize();
}
// It is unnecessary to invoke our initializer if we have a matching token.
if (!std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(),
[&](const VulkanStreamToken& token) {
// If we have a matching descriptor set
return token.descriptor_set_ == deferred_token.descriptor_set_ &&
token.buffers_ == deferred_token.buffers_;
})) {
deferred_initializer();
}
deferred_kernels_.push_back(deferred_kernel);
deferred_tokens_.push_back(deferred_token);
}
// Synchronize the current stream `state_` with respect to the host.
void Synchronize() {
if (!vctx_->UseImmediate()) {
for (const auto& deferred_kernel : deferred_kernels_) {
deferred_kernel(state_.get());
}
deferred_kernels_.clear();
deferred_tokens_.clear();
} else {
DCHECK_EQ(deferred_kernels_.size(), 0);
DCHECK_EQ(deferred_tokens_.size(), 0);
}
VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_));
VkSubmitInfo cb_submit;
cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
cb_submit.pNext = nullptr;
cb_submit.waitSemaphoreCount = 0;
cb_submit.pWaitSemaphores = nullptr;
cb_submit.pWaitDstStageMask = 0;
cb_submit.commandBufferCount = 1;
cb_submit.pCommandBuffers = &(state_->cmd_buffer_);
cb_submit.signalSemaphoreCount = 0;
cb_submit.pSignalSemaphores = nullptr;
{
// Multiple streams (on different threads) use the same VulkanContext
// instance, so we need to externally synchronize accesses.
std::lock_guard<std::mutex> g(*(vctx_->queue_mutex));
VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_));
}
uint64_t timeout = 1UL << 30UL;
VkResult res;
do {
res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout);
} while (res == VK_TIMEOUT);
VULKAN_CHECK_ERROR(res);
VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0));
VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_)));
// Re-initialize the command buffer
VkCommandBufferBeginInfo cb_begin;
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
cb_begin.pNext = nullptr;
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
cb_begin.pInheritanceInfo = 0;
VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));
}
private:
const VulkanContext* vctx_;
std::unique_ptr<VulkanStreamState> state_;
std::vector<VulkanStreamToken> deferred_tokens_;
std::vector<std::function<void(VulkanStreamState*)>> deferred_kernels_;
VkCommandPool cmd_pool_;
};
} // namespace vulkan
} // namespace runtime
} // namespace tvm
......@@ -16,6 +16,7 @@
# under the License.
import tvm
import re
import numpy as np
def test_vector_comparison():
......@@ -54,5 +55,119 @@ def test_vector_comparison():
check_correct_assembly('float16')
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
def test_vulkan_copy():
def check_vulkan(dtype, n):
if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
print("skip because vulkan is not enabled..")
return
A = tvm.placeholder((n,), name='A', dtype=dtype)
ctx = tvm.vulkan(0)
a_np = np.random.uniform(size=(n,)).astype(A.dtype)
a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(a_np)
b_np = a.asnumpy()
tvm.testing.assert_allclose(a_np, b_np)
tvm.testing.assert_allclose(a_np, a.asnumpy())
for _ in range(100):
dtype = np.random.choice(["float32", "float16", "int8", "int32"])
logN = np.random.randint(1, 15)
peturb = np.random.uniform(low=0.5, high=1.5)
check_vulkan(dtype, int(peturb * (2 ** logN)))
def test_vulkan_vectorize_add():
num_thread = 8
def check_vulkan(dtype, n, lanes):
if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
print("skip because vulkan is not enabled..")
return
A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, bx)
s[B].bind(xi, tx)
fun = tvm.build(s, [A, B], "vulkan")
ctx = tvm.vulkan(0)
a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
np.random.uniform(size=(n, lanes)))
c = tvm.nd.empty((n,), B.dtype, ctx)
fun(a, c)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
check_vulkan("float32", 64, 2)
check_vulkan("float16", 64, 2)
def test_vulkan_stress():
"""
Launch a randomized test with multiple kernels per stream, multiple uses of
kernels per stream, over multiple threads.
"""
import random
import threading
n = 1024
num_thread = 64
def run_stress():
def worker():
if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
print("skip because vulkan is not enabled..")
return
A = tvm.placeholder((n,), name='A', dtype="float32")
B = tvm.placeholder((n,), name='B', dtype="float32")
functions = [
(lambda: tvm.compute((n,), lambda i: 2 * A[i] + 3 * B[i]),
lambda a, b: 2 * a + 3 * b),
(lambda: tvm.compute((n,), lambda i: A[i]+B[i]),
lambda a, b: a + b),
(lambda: tvm.compute((n,), lambda i: A[i]+2 * B[i]),
lambda a, b: a + 2 * b),
]
def build_f(f_ref):
(C_f, ref) = f_ref
C = C_f()
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
s[C].bind(xo, bx)
s[C].bind(xi, tx)
fun = tvm.build(s, [A, B, C], "vulkan")
return (fun, ref)
fs = [build_f(random.choice(functions))
for _ in range(np.random.randint(low=1, high=10))]
ctx = tvm.vulkan(0)
a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
np.random.uniform(size=(n,)))
b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(
np.random.uniform(size=(n,)))
cs = [tvm.nd.empty((n,), A.dtype, ctx) for _ in fs]
for ((f, _), c) in zip(fs, cs):
f(a, b, c)
for ((_, ref), c) in zip(fs, cs):
tvm.testing.assert_allclose(
c.asnumpy(), ref(a.asnumpy(), b.asnumpy()))
ts = [threading.Thread(target=worker)
for _ in range(np.random.randint(1, 10))]
for t in ts:
t.start()
for t in ts:
t.join()
run_stress()
if __name__ == "__main__":
test_vector_comparison()
test_vulkan_copy()
test_vulkan_vectorize_add()
test_vulkan_stress()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment