/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file build_vulkan.cc
 * \brief Build SPIRV block
 */
// Use libspirv for parsing and validating code.
#include <libspirv.h>
#include <dmlc/memory_io.h>
#include <tvm/ir_pass.h>

#include "codegen_spirv.h"
#include "../build_common.h"
#include "../../runtime/vulkan/vulkan_module.h"

namespace tvm {
namespace codegen {

class SPIRVTools {
 public:
  SPIRVTools() {
    ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0);
  }
  ~SPIRVTools() {
    spvContextDestroy(ctx_);
  }
  std::string BinaryToText(const std::vector<uint32_t>& bin) {
    spv_text text = nullptr;
    spv_diagnostic diagnostic;
    spv_const_binary_t spv_bin{bin.data(), bin.size()};
    spv_result_t res;

    res = spvBinaryToText(
       ctx_, spv_bin.code, spv_bin.wordCount,
      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
           SPV_BINARY_TO_TEXT_OPTION_INDENT,
        &text, &diagnostic);

    CHECK_EQ(res, SPV_SUCCESS)
        << " line=" << diagnostic->position.line
        << " column=" << diagnostic->position.column
        << " index=" << diagnostic->position.index
        << " error:" << diagnostic->error;

    std::string ret(text->str);
    spvTextDestroy(text);
    return ret;
  }

 private:
  spv_context ctx_;
};

/*!
 * \brief Compile a set of lowered functions into a Vulkan runtime module.
 *
 * Each function is passed through PointerValueTypeRewrite, compiled to a
 * SPIR-V word stream by CodeGenSPIRV, optionally handed to the global
 * function "tvm_callback_vulkan_postproc" (when one is registered), and
 * then disassembled so that the text of every shader is accumulated as
 * the source string given to the module.
 *
 * \param funcs The lowered functions to build.
 * \return The module returned by runtime::VulkanModuleCreate.
 */
runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
  using tvm::runtime::Registry;
  using tvm::runtime::VulkanShader;

  // Disassembler is created once and reused across calls.
  static SPIRVTools spirv_tools;

  std::ostringstream code_data;
  std::unordered_map<std::string, VulkanShader> smap;

  // Optional hook; stays null when nothing is registered under this name.
  const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");

  CodeGenSPIRV cg;
  for (LoweredFunc original : funcs) {
    LoweredFunc f = PointerValueTypeRewrite(original);

    VulkanShader shader;
    shader.data = cg.BuildFunction(f);

    if (postproc != nullptr) {
      // Present the SPIR-V words to the hook as a raw byte array.
      TVMByteArray blob;
      blob.data = reinterpret_cast<const char*>(dmlc::BeginPtr(shader.data));
      blob.size = shader.data.size() * sizeof(uint32_t);

      std::string transformed = (*postproc)(blob);
      // The returned bytes must form whole 32-bit words.
      CHECK_EQ(transformed.length() % 4U, 0U);

      // Replace the shader contents with the bytes returned by the hook.
      shader.data.resize(transformed.size() / 4U);
      char* dst = reinterpret_cast<char*>(dmlc::BeginPtr(shader.data));
      std::copy(transformed.begin(), transformed.end(), dst);
    }

    code_data << spirv_tools.BinaryToText(shader.data);
    smap[f->name] = std::move(shader);
  }

  return runtime::VulkanModuleCreate(
      smap, ExtractFuncInfo(funcs), code_data.str());
}

// Register BuildSPIRV with the TVM API registry under the name
// "codegen.build_vulkan".
TVM_REGISTER_API("codegen.build_vulkan")
.set_body_typed(BuildSPIRV);

}  // namespace codegen
}  // namespace tvm