/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#pragma once

#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

#include "vulkan_common.h"


namespace tvm {
namespace runtime {
namespace vulkan {

// Per-stream Vulkan objects that are handed to every kernel launched on a
// VulkanStream (see VulkanStream::Launch / VulkanStream::Synchronize).
//
// Both handles are value-initialised, i.e. they start out as the null handle
// instead of holding an indeterminate value when an instance is
// default-initialised.
class VulkanStreamState {
 public:
  // Command buffer that kernels record their commands into.
  VkCommandBuffer cmd_buffer_{};
  // Fence passed to vkQueueSubmit and waited on when the stream is
  // synchronized.
  VkFence fence_{};
};

// Identifies state that should only be used once per stream.
//
// VulkanStream::LaunchDeferred keys its bookkeeping on `descriptor_set_` and
// compares `buffers_` to decide whether a previously scheduled token matches.
struct VulkanStreamToken {
  // Descriptor set this token refers to; the null handle until assigned.
  VkDescriptorSet descriptor_set_ = VK_NULL_HANDLE;
  // Buffers that accompany the descriptor set; compared with operator==.
  std::vector<VkBuffer> buffers_;
};

// A stream of Vulkan work bound to a single VulkanContext.
//
// The stream owns a command pool, one primary command buffer allocated from
// that pool, and a fence. The command buffer is put into the recording state
// by the constructor and again at the end of every Synchronize(), so kernels
// can always record into it.
//
// When the context reports UseImmediate(), kernels are invoked as soon as
// they are launched; otherwise they are queued and invoked by Synchronize().
class VulkanStream {
 public:
  // Creates the command pool, command buffer and fence on the device of
  // `vctx` and starts recording into the command buffer.
  //
  // `vctx` is not owned. It is dereferenced again by the other members and by
  // the destructor, so it has to outlive the stream.
  explicit VulkanStream(const VulkanContext* vctx)
      : vctx_(vctx), state_(new VulkanStreamState()) {
    // create command pool
    VkCommandPoolCreateInfo cmd_pool_cinfo;
    cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
    cmd_pool_cinfo.pNext = nullptr;
    // Needed so that Synchronize() may reset the command buffer individually.
    cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
    cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index;
    VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_));

    // allocate the single primary command buffer used by this stream
    VkCommandBufferAllocateInfo buffer_alloc_info;
    buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
    buffer_alloc_info.pNext = nullptr;
    buffer_alloc_info.commandPool = cmd_pool_;
    buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
    buffer_alloc_info.commandBufferCount = 1;
    VULKAN_CALL(
        vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_)));

    // create the fence in the unsignaled state
    VkFenceCreateInfo fence_cinfo;
    fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
    fence_cinfo.pNext = nullptr;
    fence_cinfo.flags = 0;  // VK_FENCE_CREATE_SIGNALED_BIT;
    VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_)));

    BeginCommandBuffer();
  }

  // Destroys the fence and the command pool. The command buffer is not freed
  // explicitly; it was allocated from the pool that is destroyed here.
  ~VulkanStream() {
    vkDestroyFence(vctx_->device, state_->fence_, nullptr);
    vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr);
  }

  // The stream holds raw Vulkan handles that the destructor releases, so a
  // copy would lead to a double destroy. Copying is disallowed explicitly.
  VulkanStream(const VulkanStream&) = delete;
  VulkanStream& operator=(const VulkanStream&) = delete;

  // Launch the kernel on the current stream.
  //
  // In immediate mode the kernel is invoked right away with the stream state;
  // otherwise a copy of it is queued and invoked by the next Synchronize().
  void Launch(const std::function<void(VulkanStreamState*)>& kernel) {
    if (vctx_->UseImmediate()) {
      kernel(state_.get());
    } else {
      deferred_kernels_.push_back(kernel);
    }
  }

  // Launch the kernel on the current stream in deferred mode.
  //
  // deferred_initializer: invoked immediately unless a token with the same
  //     descriptor set and the same buffers is already pending on the stream.
  // deferred_kernel: queued and invoked by the next Synchronize().
  // deferred_token: identifies the descriptor set / buffers used by the kernel.
  //
  // Must not be called when the context is in immediate mode.
  void LaunchDeferred(const std::function<void()>& deferred_initializer,
                      const std::function<void(VulkanStreamState*)>& deferred_kernel,
                      const VulkanStreamToken& deferred_token) {
    CHECK(!vctx_->UseImmediate());

    // Look the descriptor set up once instead of once per iterator.
    const auto pending = deferred_tokens_.find(deferred_token.descriptor_set_);
    const bool has_pending = pending != deferred_tokens_.end() && !pending->second.empty();
    const bool all_match =
        has_pending &&
        std::all_of(pending->second.begin(), pending->second.end(),
                    [&](const VulkanStreamToken& token) {
                      DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_);
                      return token.buffers_ == deferred_token.buffers_;
                    });

    // It is invalid to schedule this instance on the current stream if we already
    // have a matching descriptor set and a non-matching buffer set.
    if (has_pending && !all_match) {
      // Flushes everything queued so far. This clears deferred_tokens_, so
      // `pending` must not be used past this point.
      Synchronize();
    }

    // It is unnecessary to invoke our initializer if we have a matching token.
    // After a flush nothing is pending any more, so the initializer runs.
    if (!all_match) {
      deferred_initializer();
    }

    deferred_kernels_.push_back(deferred_kernel);
    deferred_tokens_[deferred_token.descriptor_set_].push_back(deferred_token);
  }

  // Synchronize the current stream `state_` with respect to the host.
  //
  // Invokes all deferred kernels (deferred mode only), ends and submits the
  // command buffer, blocks until the fence signals, and finally resets the
  // command buffer and fence and starts recording again.
  void Synchronize() {
    if (!vctx_->UseImmediate()) {
      for (const auto& deferred_kernel : deferred_kernels_) {
        deferred_kernel(state_.get());
      }
      deferred_kernels_.clear();
      deferred_tokens_.clear();
    } else {
      DCHECK_EQ(deferred_kernels_.size(), 0);
      DCHECK_EQ(deferred_tokens_.size(), 0);
    }

    VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_));
    VkSubmitInfo cb_submit;
    cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
    cb_submit.pNext = nullptr;
    cb_submit.waitSemaphoreCount = 0;
    cb_submit.pWaitSemaphores = nullptr;
    cb_submit.pWaitDstStageMask = nullptr;
    cb_submit.commandBufferCount = 1;
    cb_submit.pCommandBuffers = &(state_->cmd_buffer_);
    cb_submit.signalSemaphoreCount = 0;
    cb_submit.pSignalSemaphores = nullptr;

    {
      // Multiple streams (on different threads) use the same VulkanContext
      // instance, so we need to externally synchronize accesses.
      std::lock_guard<std::mutex> g(*(vctx_->queue_mutex));
      VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_));
    }

    // Wait in bounded slices and retry on timeout, i.e. block until the
    // fence signals or the wait fails with an error.
    uint64_t timeout = 1UL << 30UL;
    VkResult res;
    do {
      res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout);
    } while (res == VK_TIMEOUT);
    VULKAN_CHECK_ERROR(res);
    VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0));
    VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_)));

    // Re-initialize the command buffer
    BeginCommandBuffer();
  }

 private:
  // Puts the stream's command buffer into the recording state for a single
  // submission. Shared by the constructor and Synchronize().
  void BeginCommandBuffer() {
    VkCommandBufferBeginInfo cb_begin;
    cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
    cb_begin.pNext = nullptr;
    cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
    cb_begin.pInheritanceInfo = nullptr;
    VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));
  }

  // Context this stream submits to; not owned.
  const VulkanContext* vctx_;
  // Command buffer and fence handed to kernels.
  std::unique_ptr<VulkanStreamState> state_;
  // An index of deferred tokens, allowing us to efficiently detect duplicated
  // deferred_initializer blocks.
  std::unordered_map<VkDescriptorSet, std::vector<VulkanStreamToken>> deferred_tokens_;
  // Kernels queued by Launch / LaunchDeferred until the next Synchronize().
  std::vector<std::function<void(VulkanStreamState*)>> deferred_kernels_;
  // Pool the command buffer in `state_` is allocated from.
  VkCommandPool cmd_pool_;
};

}  // namespace vulkan
}  // namespace runtime
}  // namespace tvm