build_vulkan.cc 2.5 KB
Newer Older
1 2 3 4 5 6
/*!
 *  Copyright (c) 2018 by Contributors
 * \file build_vulkan.cc
 * \brief Build SPIRV block
 */
// Use libspirv for parsing and validating code.
7
#include <libspirv.h>
8 9 10
#include <dmlc/memory_io.h>
#include <tvm/ir_pass.h>

11
#include "codegen_spirv.h"
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
#include "../build_common.h"
#include "../../runtime/vulkan/vulkan_module.h"

namespace tvm {
namespace codegen {

class SPIRVTools {
 public:
  SPIRVTools() {
    ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0);
  }
  ~SPIRVTools() {
    spvContextDestroy(ctx_);
  }
  std::string BinaryToText(const std::vector<uint32_t>& bin) {
    spv_text text = nullptr;
    spv_diagnostic diagnostic;
    spv_const_binary_t spv_bin{bin.data(), bin.size()};
    spv_result_t res;

    res = spvBinaryToText(
       ctx_, spv_bin.code, spv_bin.wordCount,
      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
           SPV_BINARY_TO_TEXT_OPTION_INDENT,
        &text, &diagnostic);

    CHECK_EQ(res, SPV_SUCCESS)
        << " line=" << diagnostic->position.line
        << " column=" << diagnostic->position.column
        << " index=" << diagnostic->position.index
        << " error:" << diagnostic->error;

    std::string ret(text->str);
    spvTextDestroy(text);
    return ret;
  }

 private:
  spv_context ctx_;
};

/*!
 * \brief Generate SPIR-V for every function in funcs and wrap the result in
 *        a Vulkan runtime module.
 *
 * Each function is passed through PointerValueTypeRewrite and then
 * CodeGenSPIRV::BuildFunction. If a global function is registered under
 * "tvm_callback_vulkan_postproc", it is called with the shader bytes and the
 * bytes it returns replace the shader contents. The disassembly of every
 * shader is appended to one text buffer that is handed to the module.
 *
 * \param funcs The lowered functions to compile.
 * \return The module returned by runtime::VulkanModuleCreate.
 */
runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
  using tvm::runtime::Registry;
  using tvm::runtime::VulkanShader;

  std::ostringstream code_data;
  static SPIRVTools spirv_tools;
  std::unordered_map<std::string, VulkanShader> smap;

  // Looked up once; a null result means no post-processing hook is installed.
  const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
  const bool run_postproc = (postproc != nullptr);

  CodeGenSPIRV cg;
  for (const LoweredFunc& lowered : funcs) {
    LoweredFunc f = PointerValueTypeRewrite(lowered);

    VulkanShader shader;
    shader.data = cg.BuildFunction(f);
    auto& words = shader.data;

    if (run_postproc) {
      // Present the generated words to the hook as a raw byte array.
      TVMByteArray arr;
      arr.data = reinterpret_cast<const char*>(dmlc::BeginPtr(words));
      arr.size = words.size() * sizeof(uint32_t);

      std::string transformed = (*postproc)(arr);
      // The returned byte string must be a whole number of 32-bit words.
      CHECK_EQ(transformed.length() % 4U, 0U);

      words.resize(transformed.size() / 4U);
      char* dst = reinterpret_cast<char*>(dmlc::BeginPtr(words));
      std::copy(transformed.begin(), transformed.end(), dst);
    }

    code_data << spirv_tools.BinaryToText(words);
    smap[f->name] = std::move(shader);
  }

  return runtime::VulkanModuleCreate(
      smap, ExtractFuncInfo(funcs), code_data.str());
}

// Register BuildSPIRV under the API name "codegen.build_vulkan". The packed
// call's first argument is forwarded to BuildSPIRV and the module it returns
// is stored in the return value.
TVM_REGISTER_API("codegen.build_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildSPIRV(args[0]);
  });

}  // namespace codegen
}  // namespace tvm