Commit 79d503fd by Tianqi Chen, committed by GitHub

[BACKEND] Vulkan Runtime and SPIRV Codegen (#861)

* [BACKEND] Vulkan Runtime and SPIRV Codegen

* fix doc
parent 108e9f3f
cmake_minimum_required(VERSION 3.5)
cmake_minimum_required(VERSION 3.7)
project(tvm C CXX)
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
......@@ -22,6 +22,7 @@ endif()
tvm_option(USE_CUDA "Build with CUDA" OFF)
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_VULKAN "Build with Vulkan" OFF)
tvm_option(USE_OPENGL "Build with OpenGL" OFF)
tvm_option(USE_METAL "Build with Metal" OFF)
tvm_option(USE_RPC "Build with RPC" ON)
......@@ -88,9 +89,11 @@ file(GLOB_RECURSE HALIDEIR_SRCS HalideIR/src/*.cpp)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS src/runtime/*.cc)
file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc)
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
file(GLOB RUNTIME_OPENGL_SRCS src/runtime/opengl/*.cc)
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
file(GLOB RUNTIME_METAL_SRCS src/runtime/metal/*.mm)
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
file(GLOB RUNTIME_GRAPH_SRCS src/runtime/graph/*.cc)
......@@ -151,6 +154,22 @@ else(USE_OPENGL)
add_definitions(-DTVM_OPENGL_RUNTIME=0)
endif(USE_OPENGL)
if(USE_VULKAN)
find_package(Vulkan REQUIRED)
message(STATUS "Build with VULKAN support")
include_directories(${Vulkan_INCLUDE_DIRS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARIES})
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
get_filename_component(VULKAN_LIB_PATH ${Vulkan_LIBRARY} DIRECTORY)
find_library(SPIRV_TOOLS_LIB SPIRV-Tools
${VULKAN_LIB_PATH}/spirv-tools)
list(APPEND TVM_LINKER_LIBS ${SPIRV_TOOLS_LIB})
add_definitions(-DTVM_VULKAN_RUNTIME=1)
else(USE_VULKAN)
add_definitions(-DTVM_VULKAN_RUNTIME=0)
endif(USE_VULKAN)
if(USE_METAL)
find_package(OpenCL QUIET REQUIRED)
message(STATUS "Build with Metal support")
......@@ -174,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
endif(USE_GRAPH_RUNTIME)
if(USE_LLVM)
find_package(LLVM CONFIG REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
......
Subproject commit 87b089a0ba20f2e8257038ee9211d6816088ce95
Subproject commit aadbf02d6bd7a545edbf6652494a7b07a97a06c1
......@@ -56,6 +56,7 @@ CUDA_SRC = $(wildcard src/runtime/cuda/*.cc)
ROCM_SRC = $(wildcard src/runtime/rocm/*.cc)
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
OPENGL_SRC = $(wildcard src/runtime/opengl/*.cc)
VULKAN_SRC = $(wildcard src/runtime/vulkan/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
GRAPH_SRC = $(wildcard src/runtime/graph/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
......@@ -69,6 +70,7 @@ CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC))
ROCM_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCM_SRC))
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
OPENGL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENGL_SRC))
VULKAN_OBJ = $(patsubst src/%.cc, build/%.o, $(VULKAN_SRC))
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ)
......@@ -129,6 +131,20 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif
ifdef VULKAN_SDK
CFLAGS += -I$(VULKAN_SDK)/include
LDFLAGS += -L$(VULKAN_SDK)/lib
LDFLAGS += -L$(VULKAN_SDK)/lib/spirv-tools
endif
ifeq ($(USE_VULKAN), 1)
CFLAGS += -DTVM_VULKAN_RUNTIME=1
LDFLAGS += -lvulkan -lSPIRV-Tools
RUNTIME_DEP += $(VULKAN_OBJ)
else
CFLAGS += -DTVM_VULKAN_RUNTIME=0
endif
ifeq ($(USE_OPENGL), 1)
CFLAGS += -DTVM_OPENGL_RUNTIME=1
EMCC_FLAGS += -DTVM_OPENGL_RUNTIME=1
......
......@@ -422,6 +422,18 @@ LoweredFunc LowerTVMBuiltin(LoweredFunc f);
LoweredFunc CombineContextCall(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as the Allocs internal to the function, to use
* the most frequently accessed type for load/store,
* so as to avoid pointer casting in the backend when possible.
*
* \note Implemented in storage_rewrite.cc
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
......
......@@ -55,8 +55,8 @@ typedef int64_t tvm_index_t;
/*! \brief Extension device types in TVM */
typedef enum {
kDLVulkan = 7,
kOpenGL = 11,
// Extension DRAM type, used to quickly test extension devices
// The device api can differ depending on the xpu driver registered.
kExtDev = 12,
......
......@@ -17,7 +17,8 @@ from . import ir_builder
from . import target
from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, opengl, ext_dev
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev
from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function
......
......@@ -94,6 +94,7 @@ class TVMContext(ctypes.Structure):
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
7 : 'vulkan',
8 : 'metal',
9 : 'vpi',
10: 'rocm',
......@@ -109,6 +110,7 @@ class TVMContext(ctypes.Structure):
'nvptx': 2,
'cl': 4,
'opencl': 4,
'vulkan': 7,
'metal': 8,
'vpi': 9,
'rocm': 10,
......
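A quick sketch of what the two tables above provide (assuming tvm is importable): the string form and the numeric device type resolve to the same context.
import tvm
ctx = tvm.context("vulkan", 0)  # resolved through the string-to-mask table above
assert ctx.device_type == 7     # matches the new MASK2STR entry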
"""Utility for Interacting with SPIRV Tools"""
import subprocess
import os
from . import util
def optimize(spv_bin):
"""Optimize SPIRV using spirv-opt via CLI
Note that spirv-opt is still experimental.
Parameters
----------
spv_bin : bytearray
The spirv file
Returns
-------
opt_bin : bytearray
The optimized SPIRV binary
"""
tmp_dir = util.tempdir()
tmp_in = tmp_dir.relpath("input.spv")
tmp_out = tmp_dir.relpath("output.spv")
with open(tmp_in, "wb") as out_file:
out_file.write(bytes(spv_bin))
sdk = os.environ.get("VULKAN_SDK", None)
cmd = os.path.join(sdk, "bin/spirv-opt") if sdk else "spirv-opt"
args = [cmd, "-O", tmp_in, "-o", tmp_out]
proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Opitmizationerror using spirv-opt:\n"
msg += str(out)
raise RuntimeError(msg)
return bytearray(open(tmp_out, "rb").read())
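A minimal usage sketch for the helper above, assuming it is exposed as tvm.contrib.spirv and that spirv-opt is reachable (on PATH or under $VULKAN_SDK/bin); the file names are illustrative.
from tvm.contrib import spirv
with open("kernel.spv", "rb") as f:     # hypothetical input SPIR-V binary
    spv_bin = bytearray(f.read())
opt_bin = spirv.optimize(spv_bin)       # runs spirv-opt -O on a temporary copy
with open("kernel_opt.spv", "wb") as f:
    f.write(opt_bin)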
......@@ -120,6 +120,23 @@ def vpi(dev_id=0):
"""
return TVMContext(9, dev_id)
def vulkan(dev_id=0):
"""Construct a Vulkan device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(7, dev_id)
def opengl(dev_id=0):
"""Construct a OpenGL device
......@@ -135,6 +152,7 @@ def opengl(dev_id=0):
"""
return TVMContext(11, dev_id)
def ext_dev(dev_id=0):
"""Construct a extension device
......
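A small sketch of allocating data on the new Vulkan context defined above, assuming libtvm was built with USE_VULKAN and a Vulkan device is present.
import numpy as np
import tvm
ctx = tvm.vulkan(0)  # equivalent to tvm.context("vulkan", 0)
a = tvm.nd.array(np.zeros((16,), dtype="float32"), ctx)
print(a.asnumpy())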
......@@ -116,7 +116,7 @@ class Target(object):
# For now assume rocm schedule for opencl
self.keys += ("rocm", "gpu")
self.max_num_threads = 256
elif target_name in ("metal",):
elif target_name in ("metal", "vulkan"):
self.keys += ("gpu",)
self.max_num_threads = 256
elif target_name in ("opengl",):
......
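With "vulkan" now recognized as a GPU target, a schedule can be compiled and run end to end. A sketch assuming a Vulkan-enabled build; the kernel and names are illustrative only.
import numpy as np
import tvm
n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], "vulkan")  # dispatches to codegen.build_vulkan
ctx = tvm.vulkan(0)
a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx)
b = tvm.nd.array(np.zeros(n, dtype="float32"), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)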
......@@ -666,6 +666,8 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
}
void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
// constraint of current logic
CHECK_EQ(op->base.type(), Int(32));
os << "((int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
......
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_common.h
* \brief Common utility for codegen.
*/
#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_
#define TVM_CODEGEN_CODEGEN_COMMON_H_
#include <tvm/arithmetic.h>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
/*!
* \brief Visit AssertStmt recursively, update align_map from condition.
* \param op The AssertStmt
* \param align_map The alignment map to be updated.
* \param fvisit The recursive visitor
* \tparam FVisit The type of the recursive visitor.
*/
template<typename FVisit>
inline void VisitAssert(
const ir::AssertStmt* op,
std::unordered_map<const Variable*, arith::ModularEntry>* align_map,
FVisit fvisit) {
using namespace ir;
auto& align_map_ = *align_map;
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) merge these patterns into a generic scope info visitor.
if (const EQ* eq = op->condition.as<EQ>()) {
const Mod* mod = eq->a.as<Mod>();
int64_t factor = 0, offset = 0;
if (mod && arith::GetConst(eq->b, &offset)) {
const Variable *var = mod->a.as<Variable>();
if (var && arith::GetConst(mod->b, &factor)) {
arith::ModularEntry old = align_map_[var];
if (factor > old.coeff) {
arith::ModularEntry e;
e.coeff = static_cast<int>(factor);
e.base = static_cast<int>(offset);
// new alignment info,
align_map_[var] = e;
fvisit(op->body);
// restore old info
align_map_[var] = old;
return;
}
}
}
}
fvisit(op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_COMMON_H_
......@@ -9,6 +9,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include "./codegen_llvm.h"
#include "./codegen_cpu.h"
#include "../codegen_common.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"
......@@ -341,7 +342,7 @@ void CodeGenLLVM::GetAlignment(Type t,
int align_bits = t.bits();
while (align_bits < max_align_bits &&
me.base % 2 == 0 &&
me.coeff %2 == 0) {
me.coeff % 2 == 0) {
me.base = me.base / 2;
me.coeff = me.coeff / 2;
align_bits *= 2;
......@@ -1026,31 +1027,9 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
}
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
// Detect useful invariant pattern and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) move these pattern to a generic scope info visitor.
if (const EQ* eq = op->condition.as<EQ>()) {
const Mod* mod = eq->a.as<Mod>();
int64_t factor = 0, offset = 0;
if (mod && arith::GetConst(eq->b, &offset)) {
const Variable *var = mod->a.as<Variable>();
if (var && arith::GetConst(mod->b, &factor)) {
arith::ModularEntry old = align_map_[var];
if (factor > old.coeff) {
arith::ModularEntry e;
e.coeff = static_cast<int>(factor);
e.base = static_cast<int>(offset);
// new alignment info,
align_map_[var] = e;
this->VisitStmt(op->body);
// restore old info
align_map_[var] = old;
return;
}
}
}
}
this->VisitStmt(op->body);
VisitAssert(op, &align_map_, [this](const Stmt& body) {
this->VisitStmt(body);
});
}
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
......
/*!
* Copyright (c) 2018 by Contributors
* \file build_vulkan.cc
* \brief Build SPIRV block
*/
#if TVM_VULKAN_RUNTIME
// Use libspirv for parsing and validating code.
#include <vulkan/libspirv.h>
#include <dmlc/memory_io.h>
#include <tvm/ir_pass.h>
#include "./codegen_spirv.h"
#include "../build_common.h"
#include "../../runtime/vulkan/vulkan_module.h"
namespace tvm {
namespace codegen {
class SPIRVTools {
public:
SPIRVTools() {
ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0);
}
~SPIRVTools() {
spvContextDestroy(ctx_);
}
std::string BinaryToText(const std::vector<uint32_t>& bin) {
spv_text text = nullptr;
spv_diagnostic diagnostic;
spv_const_binary_t spv_bin{bin.data(), bin.size()};
spv_result_t res;
res = spvBinaryToText(
ctx_, spv_bin.code, spv_bin.wordCount,
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
SPV_BINARY_TO_TEXT_OPTION_INDENT,
&text, &diagnostic);
CHECK_EQ(res, SPV_SUCCESS)
<< " line=" << diagnostic->position.line
<< " column=" << diagnostic->position.column
<< " index=" << diagnostic->position.index
<< " error:" << diagnostic->error;
std::string ret(text->str);
spvTextDestroy(text);
return ret;
}
private:
spv_context ctx_;
};
runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;
std::ostringstream code_data;
static SPIRVTools spirv_tools;
std::unordered_map<std::string, VulkanShader> smap;
const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
CodeGenSPIRV cg;
for (LoweredFunc f : funcs) {
f = PointerValueTypeRewrite(f);
VulkanShader shader;
shader.data = cg.BuildFunction(f);
if (postproc != nullptr) {
TVMByteArray arr;
arr.data = reinterpret_cast<const char*>(dmlc::BeginPtr(shader.data));
arr.size = shader.data.size() * sizeof(uint32_t);
std::string transformed = (*postproc)(arr);
CHECK_EQ(transformed.length() % 4U, 0U);
shader.data.resize(transformed.size() / 4U);
std::copy(transformed.begin(), transformed.end(),
reinterpret_cast<char*>(dmlc::BeginPtr(shader.data)));
}
code_data << spirv_tools.BinaryToText(shader.data);
smap[f->name] = std::move(shader);
}
return runtime::VulkanModuleCreate(
smap, ExtractFuncInfo(funcs), code_data.str());
}
TVM_REGISTER_API("codegen.build_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildSPIRV(args[0]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
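BuildSPIRV above consults the registry for an optional "tvm_callback_vulkan_postproc" hook; a sketch of registering one from Python follows (running spirv-opt here is just one possible use of the hook).
import tvm
from tvm.contrib import spirv
@tvm.register_func("tvm_callback_vulkan_postproc")
def postproc(spirv_bin):
    # Receives the generated SPIR-V words as a byte buffer; must return data
    # whose length is a multiple of 4 (checked by BuildSPIRV).
    return spirv.optimize(bytearray(spirv_bin))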
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_spirv.cc
* \brief Generate SPIRV block
*/
#if TVM_VULKAN_RUNTIME
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "../codegen_common.h"
#include "./codegen_spirv.h"
namespace tvm {
namespace codegen {
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
this->InitFuncState();
CHECK(f->is_restricted)
<< "SPIRV only takes restricted memory model";
std::vector<Var> pod_args;
uint32_t num_buffer = 0;
for (Var arg : f->args) {
Type t = arg.type();
if (t.is_handle()) {
auto it = f->handle_data_type.find(arg);
if (it != f->handle_data_type.end()) {
Type value_type = (*it).second.type();
spirv::Value arg_value = builder_->BufferArgument(
builder_->GetSType(value_type), 0, num_buffer);
storage_info_[arg.get()].UpdateContentType(value_type);
var_map_[arg.get()] = arg_value;
} else {
LOG(FATAL) << "require all handles to be typed";
}
++num_buffer;
} else {
pod_args.push_back(arg);
}
}
spirv::Value func_ptr = builder_->DeclareKernelFunction(f->name);
builder_->StartFunction(func_ptr);
// All the POD arguments are passed in through PushConstant
if (pod_args.size() != 0) {
std::vector<spirv::SType> value_types;
for (size_t i = 0; i < pod_args.size(); ++i) {
value_types.push_back(builder_->GetSType(pod_args[i].type()));
}
spirv::Value ptr = builder_->DeclarePushConstant(value_types);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value = builder_->GetPushConstant(
ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
}
}
this->VisitStmt(f->body);
builder_->SetLocalSize(func_ptr, workgroup_size_);
builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd);
return builder_->Finalize();
}
void CodeGenSPIRV::InitFuncState() {
std::fill(workgroup_size_, workgroup_size_ + 3, 1);
var_map_.clear();
storage_info_.clear();
align_map_.clear();
builder_.reset(new spirv::IRBuilder());
builder_->InitHeader();
}
spirv::Value CodeGenSPIRV::GetThreadIndex(
const IterVar& iv, const Expr& extent) {
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
spirv::Value v;
if (ts.rank == 1) {
v = builder_->GetLocalID(ts.dim_index);
int size;
CHECK(arith::GetConstInt(extent, &size))
<< "SPIRV only allows constant thread group size " << " get " << extent;
CHECK_LT(ts.dim_index, 3);
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
} else {
v = builder_->GetWorkgroupID(ts.dim_index);
}
return builder_->Cast(builder_->GetSType(iv->var.type()), v);
}
spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
spirv::Value value;
if (sync == "warp") {
return value;
} else if (sync == "shared") {
builder_->MakeInst(
spv::OpControlBarrier,
spv::ScopeWorkgroup,
spv::ScopeWorkgroup,
spv::MemorySemanticsSequentiallyConsistentMask |
spv::MemorySemanticsWorkgroupMemoryMask);
} else {
LOG(FATAL) << "Do not support sync " << sync;
}
return value;
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Variable* op) {
auto it = var_map_.find(op);
CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint;
return it->second;
}
spirv::Value CodeGenSPIRV::VisitExpr_(const IntImm* op) {
return builder_->IntImm(builder_->GetSType(op->type), op->value);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImm* op) {
return builder_->UIntImm(builder_->GetSType(op->type), op->value);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImm* op) {
return builder_->FloatImm(builder_->GetSType(op->type), op->value);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) {
LOG(FATAL) << "StringImm is not supported in Device code";
return spirv::Value();
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Cast* op) {
return builder_->Cast(builder_->GetSType(op->type), MakeValue(op->value));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Add* op) {
return builder_->Add(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Sub* op) {
return builder_->Sub(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Mul* op) {
return builder_->Mul(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Div* op) {
return builder_->Div(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Mod* op) {
return builder_->Mod(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Min* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->Select(builder_->LT(a, b), a, b);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Max* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->Select(builder_->GT(a, b), a, b);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const LT* op) {
return builder_->LT(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const LE* op) {
return builder_->LE(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const GT* op) {
return builder_->GT(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const GE* op) {
return builder_->GE(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const EQ* op) {
return builder_->EQ(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const NE* op) {
return builder_->NE(MakeValue(op->a), MakeValue(op->b));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const And* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->MakeValue(spv::OpLogicalAnd, a.stype, a, b);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Or* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->MakeValue(spv::OpLogicalOr, a.stype, a, b);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Not* op) {
spirv::Value a = MakeValue(op->a);
return builder_->MakeValue(spv::OpLogicalNot, a.stype, a);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) {
return builder_->Select(MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) {
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
return MakeValue(op->body);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
if (op->is_intrinsic("spirv_glsl450")) {
CHECK_GE(op->args.size(), 2U);
uint32_t inst_id = op->args[0].as<UIntImm>()->value;
std::vector<spirv::Value> values;
for (size_t i = 1; i < op->args.size(); ++i) {
values.push_back(MakeValue(op->args[i]));
}
return builder_->CallGLSL450(
builder_->GetSType(op->type), inst_id, values);
} else if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
} else if (op->is_intrinsic(Call::bitwise_xor)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
} else if (op->is_intrinsic(Call::bitwise_or)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
} else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
spirv::Value a = MakeValue(op->args[0]);
return builder_->MakeValue(spv::OpNot, a.stype, a);
} else if (op->is_intrinsic(Call::shift_left)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
} else if (op->is_intrinsic(Call::shift_right)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
if (op->args[0].type().is_int()) {
return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b);
} else {
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return this->CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
CHECK_EQ(op->args.size(), 3U);
spirv::Value cond = MakeValue(op->args[0]);
spirv::Label then_label = builder_->NewLabel();
spirv::Label else_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
builder_->MakeInst(
spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
builder_->MakeInst(
spv::OpBranchConditional, cond, then_label, else_label);
// then block, must get label after we see the value
builder_->StartLabel(then_label);
spirv::Value then_value = MakeValue(op->args[1]);
spirv::Label then_value_label = builder_->CurrentLabel();
builder_->MakeInst(spv::OpBranch, merge_label);
// else block
builder_->StartLabel(else_label);
spirv::Value else_value = MakeValue(op->args[2]);
spirv::Label else_value_label = builder_->CurrentLabel();
builder_->MakeInst(spv::OpBranch, merge_label);
// merge block
builder_->StartLabel(merge_label);
spirv::PhiValue phi = builder_->MakePhi(then_value.stype, 2);
phi.SetIncoming(0, then_value, then_value_label);
phi.SetIncoming(1, else_value, else_value_label);
return phi;
} else if (op->is_intrinsic("popcount")) {
return builder_->MakeValue(
spv::OpBitCount,
builder_->GetSType(op->type),
MakeValue(op->args[0]));
} else {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
LOG(FATAL) << "Unresolved intrinsic " << op->name
<< " with return type " << op->type;
} else if (op->call_type == Call::Extern ||
op->call_type == Call::PureExtern) {
LOG(FATAL) << "Unresolved extern " << op->name
<< " with return type " << op->type;
} else {
LOG(FATAL) << "Unresolved call type " << op->call_type;
}
return spirv::Value();
}
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
std::vector<spirv::Value> values;
spirv::Value base = MakeValue(op->base);
for (int i = 0; i < op->lanes; ++i) {
spirv::Value v = base;
if (i != 0) {
spirv::Value offset = MakeValue(
arith::ComputeExpr<Mul>(make_const(op->stride.type(), i), op->stride));
v = builder_->Add(v, offset);
}
values.push_back(v);
}
return builder_->Concat(values);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Broadcast* op) {
std::vector<spirv::Value> values;
spirv::Value v = MakeValue(op->value);
for (int i = 0; i < op->lanes; i++) {
values.push_back(v);
}
return builder_->Concat(values);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
CHECK(is_one(op->predicate));
auto it = storage_info_.find(op->buffer_var.get());
CHECK(it != storage_info_.end());
StorageInfo& info = it->second;
if (!info.content_fixed) {
info.UpdateContentType(op->type);
}
spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::Value buffer = MakeValue(op->buffer_var);
spirv::SType ptr_type = builder_->GetPointerType(
content_type, buffer.stype.storage_class);
uint32_t mask = spv::MemoryAccessMaskNone;
if (info.is_volatile) {
mask |= spv::MemoryAccessVolatileMask;
}
if (op->type.lanes() == 1) {
CHECK_EQ(info.content_type, op->type)
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, index);
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
} else {
if (op->type.element_of() == info.content_type) {
// because the content type is the element type, we can only do a scalarized load.
std::vector<spirv::Value> values;
auto f = [&](int i, spirv::Value index) {
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, index);
values.emplace_back(
builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
};
this->Scalarize(op->index, f);
return builder_->Concat(values);
} else {
if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->type.lanes());
arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_);
CHECK((me.coeff % ramp->lanes) == 0 &&
(me.base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
Expr vec_index = ir::Simplify(
ramp->base / make_const(ramp->base.type(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
}
}
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
return spirv::Value();
}
void CodeGenSPIRV::Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
for (int i = 0; i < ramp->type.lanes(); ++i) {
Expr offset = arith::ComputeExpr<Add>(
ramp->base,
arith::ComputeExpr<Mul>(ramp->stride, i));
f(i, MakeValue(offset));
}
} else {
spirv::SType etype = builder_->GetSType(e.type().element_of());
spirv::Value value = MakeValue(e);
for (int i = 0; i < e.type().lanes(); ++i) {
f(i, builder_->MakeValue(
spv::OpCompositeExtract, etype, value, i));
}
}
}
void CodeGenSPIRV::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate));
auto it = storage_info_.find(op->buffer_var.get());
CHECK(it != storage_info_.end());
StorageInfo& info = it->second;
if (!info.content_fixed) {
info.UpdateContentType(op->value.type());
}
spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::Value buffer = MakeValue(op->buffer_var);
spirv::Value value = MakeValue(op->value);
spirv::SType ptr_type = builder_->GetPointerType(
content_type, buffer.stype.storage_class);
uint32_t mask = spv::MemoryAccessMaskNone;
if (info.is_volatile) {
mask |= spv::MemoryAccessVolatileMask;
}
if (op->value.type().lanes() == 1) {
CHECK_EQ(info.content_type, op->value.type())
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, value, mask);
} else {
if (op->value.type().element_of() == info.content_type) {
// because the content type is the element type, we can only do a scalarized store.
auto f = [&](int i, spirv::Value index) {
spirv::Value elem = builder_->MakeValue(
spv::OpCompositeExtract, content_type, value, i);
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, elem, mask);
};
this->Scalarize(op->index, f);
} else {
if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->value.type().lanes());
arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_);
CHECK((me.coeff % ramp->lanes) == 0 &&
(me.base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
Expr vec_index = ir::Simplify(
ramp->base / make_const(ramp->base.type(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
builder_->MakeInst(spv::OpStore, ptr, value, mask);
return;
}
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
}
}
}
void CodeGenSPIRV::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
spirv::Value init_value = MakeValue(op->min);
spirv::Value extent_value = MakeValue(op->extent);
// Must get init label after making the values (to make sure they are correct)
spirv::Label init_label = builder_->CurrentLabel();
spirv::Label head_label = builder_->NewLabel();
spirv::Label body_label = builder_->NewLabel();
spirv::Label continue_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
builder_->MakeInst(spv::OpBranch, head_label);
// Loop head
builder_->StartLabel(head_label);
spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
loop_var.SetIncoming(0, init_value, init_label);
spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
uint32_t control = (
op->for_type == ForType::Unrolled ?
spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
builder_->MakeInst(
spv::OpLoopMerge, merge_label, continue_label, control);
builder_->MakeInst(
spv::OpBranchConditional, loop_cond, body_label, merge_label,
weight_likely_branch_, 1);
// loop body
builder_->StartLabel(body_label);
var_map_[op->loop_var.get()] = spirv::Value(loop_var);
this->VisitStmt(op->body);
builder_->MakeInst(spv::OpBranch, continue_label);
// loop continue
builder_->StartLabel(continue_label);
spirv::Value one =
op->loop_var.type().is_int() ?
builder_->IntImm(loop_var.stype, 1) :
builder_->UIntImm(loop_var.stype, 1);
spirv::Value next_value = builder_->Add(loop_var, one);
loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
builder_->MakeInst(spv::OpBranch, head_label);
// loop merge
builder_->StartLabel(merge_label);
}
void CodeGenSPIRV::VisitStmt_(const IfThenElse* op) {
spirv::Value cond = MakeValue(op->condition);
spirv::Label then_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
if (op->else_case.defined()) {
spirv::Label else_label = builder_->NewLabel();
builder_->MakeInst(
spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
builder_->MakeInst(
spv::OpBranchConditional, cond, then_label, else_label);
// then block
builder_->StartLabel(then_label);
this->VisitStmt(op->then_case);
builder_->MakeInst(spv::OpBranch, merge_label);
// else block
builder_->StartLabel(else_label);
this->VisitStmt(op->else_case);
builder_->MakeInst(spv::OpBranch, merge_label);
} else {
builder_->MakeInst(
spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
builder_->MakeInst(
spv::OpBranchConditional, cond, then_label, merge_label,
weight_likely_branch_, 1);
// then block
builder_->StartLabel(then_label);
this->VisitStmt(op->then_case);
builder_->MakeInst(spv::OpBranch, merge_label);
}
// start merge label;
builder_->StartLabel(merge_label);
}
void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
CHECK(!op->new_expr.defined());
CHECK(!op->type.is_handle());
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
spirv::Value buf;
StorageInfo& info = storage_info_[op->buffer_var.get()];
spirv::SType etype = builder_->GetSType(op->type);
if (info.scope.rank == 2) {
buf = builder_->Allocate(
etype, static_cast<uint32_t>(constant_size),
spv::StorageClassFunction);
} else {
// shared memory
CHECK_EQ(info.scope.rank, 1)
<< "Can only allocate shared or local memory inside kernel";
// Shared memory
buf = builder_->Allocate(
etype, static_cast<uint32_t>(constant_size),
spv::StorageClassWorkgroup);
}
CHECK(!info.content_fixed);
info.UpdateContentType(op->type);
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
}
void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) {
var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
}
}
} else if (op->attr_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
storage_info_[v].scope =
runtime::StorageScope::make(op->value.as<StringImm>()->value);
} else if (op->attr_key == ir::attr::volatile_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
storage_info_[v].is_volatile = true;
}
this->VisitStmt(op->body);
}
void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
VisitAssert(op, &align_map_, [this](const Stmt& body) {
this->VisitStmt(body);
});
}
void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
CHECK(!var_map_.count(op->var.get()));
CHECK(!align_map_.count(op->var.get()));
CHECK(!op->var.type().is_handle());
var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
this->VisitStmt(op->body);
}
void CodeGenSPIRV::VisitStmt_(const Block* op) {
VisitStmt(op->first);
if (op->rest.defined()) {
this->VisitStmt(op->rest);
}
}
void CodeGenSPIRV::VisitStmt_(const Evaluate* op) {
MakeValue(op->value);
}
void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {
this->VisitStmt(op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_spirv.h
* \brief SPIRV code generator for lowered TVM functions
*/
#ifndef TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h>
#include <vector>
#include "./ir_builder.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief Code generator into SPIRV
*/
class CodeGenSPIRV:
public ExprFunctor<spirv::Value(const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
* \return The final spirv module.
*/
virtual std::vector<uint32_t> BuildFunction(const LoweredFunc& f);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
* \return created value.
*/
spirv::Value MakeValue(const Expr& e) {
return VisitExpr(e);
}
// override codegen
spirv::Value VisitExpr_(const Variable* op) override;
spirv::Value VisitExpr_(const Cast* op) override;
spirv::Value VisitExpr_(const IntImm* op) override;
spirv::Value VisitExpr_(const UIntImm* op) override;
spirv::Value VisitExpr_(const FloatImm* op) override;
spirv::Value VisitExpr_(const StringImm* op) override;
spirv::Value VisitExpr_(const Add* op) override;
spirv::Value VisitExpr_(const Sub* op) override;
spirv::Value VisitExpr_(const Mul* op) override;
spirv::Value VisitExpr_(const Div* op) override;
spirv::Value VisitExpr_(const Mod* op) override;
spirv::Value VisitExpr_(const Min* op) override;
spirv::Value VisitExpr_(const Max* op) override;
spirv::Value VisitExpr_(const LT* op) override;
spirv::Value VisitExpr_(const LE* op) override;
spirv::Value VisitExpr_(const GT* op) override;
spirv::Value VisitExpr_(const GE* op) override;
spirv::Value VisitExpr_(const EQ* op) override;
spirv::Value VisitExpr_(const NE* op) override;
spirv::Value VisitExpr_(const And* op) override;
spirv::Value VisitExpr_(const Or* op) override;
spirv::Value VisitExpr_(const Not* op) override;
spirv::Value VisitExpr_(const Select* op) override;
spirv::Value VisitExpr_(const Let* op) override;
spirv::Value VisitExpr_(const Call* op) override;
spirv::Value VisitExpr_(const Ramp* op) override;
spirv::Value VisitExpr_(const Broadcast* op) override;
spirv::Value VisitExpr_(const Load* op) override;
// stmt
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
protected:
/*! \brief The storage information */
struct StorageInfo {
/*! \brief The storage scope */
runtime::StorageScope scope;
/*! \brief Whether it is volatile */
bool is_volatile{false};
/*! \brief Whether the content type has been fixed */
bool content_fixed{false};
/*! \brief Current content type */
Type content_type{Handle()};
// Update content type if it hasn't been updated yet.
void UpdateContentType(Type type) {
if (content_fixed) {
CHECK_EQ(type, content_type)
<< "Cannot use two different content types in GLSL model";
} else {
this->content_type = type;
content_fixed = true;
}
}
};
// Reset the state so it works for a new function.
void InitFuncState();
// Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
spirv::Value CreateStorageSync(const Call* op);
void Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f);
// The builder
std::unique_ptr<spirv::IRBuilder> builder_;
// Work group size in three dimensions
uint32_t workgroup_size_[3];
// Likely branch
uint32_t weight_likely_branch_{128};
// the storage scope of allocation
std::unordered_map<const Variable*, StorageInfo> storage_info_;
// The definition of local variable.
std::unordered_map<const Variable*, spirv::Value> var_map_;
// The alignment information
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc
*/
#if TVM_VULKAN_RUNTIME
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <vulkan/GLSL.std.450.h>
namespace tvm {
namespace codegen {
namespace spirv {
using namespace runtime;
// Dispatch a TVM intrinsic call to the GLSL.std.450 instruction given by id.
template<unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(UInt(32), id));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
} // namespace spirv
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
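The rules above map TVM intrinsics onto GLSL.std.450 instructions through the generic spirv_glsl450 call. A sketch that would exercise the exp rule on the vulkan target (illustrative, assuming a Vulkan-enabled build):
import tvm
n = 64
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: tvm.exp(A[i]), name="B")
s = tvm.create_schedule(B.op)
s[B].bind(B.op.axis[0], tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], "vulkan")  # exp lowers via tvm.intrin.rule.vulkan.exp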
/*!
* Copyright (c) 2018 by Contributors
* \file ir_builder.cc
* \brief IRBuilder for SPIRV block
*/
#if TVM_VULKAN_RUNTIME
#include "./ir_builder.h"
namespace tvm {
namespace codegen {
namespace spirv {
// implementations
void IRBuilder::InitHeader() {
CHECK_EQ(header_.size(), 0U);
header_.push_back(spv::MagicNumber);
header_.push_back(spv::Version);
// generator: set to 0, unknown
header_.push_back(0U);
// Bound: set during Finalize
header_.push_back(0U);
// Schema: reserved
header_.push_back(0U);
// shader
ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_);
// memory model
ib_.Begin(spv::OpMemoryModel).AddSeq(
spv::AddressingModelLogical,
spv::MemoryModelGLSL450).Commit(&entry_);
this->InitPreDefs();
}
void IRBuilder::InitPreDefs() {
ext_glsl450_ = ExtInstImport("GLSL.std.450");
t_int32_ = DeclareType(Int(32));
t_uint32_ = DeclareType(UInt(32));
t_bool_ = DeclareType(UInt(1));
t_fp32_ = DeclareType(Float(32));
const_i32_zero_ = IntImm(t_int32_, 0);
// declare void, and void functions
t_void_.id = id_counter_++;
ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_);
t_void_func_.id = id_counter_++;
ib_.Begin(spv::OpTypeFunction)
.AddSeq(t_void_func_, t_void_).Commit(&global_);
}
SType IRBuilder::GetSType(const Type& dtype) {
if (dtype == Int(32)) {
return t_int32_;
} else if (dtype == UInt(1)) {
return t_bool_;
} else if (dtype == Float(32)) {
return t_fp32_;
} else if (dtype == UInt(32)) {
return t_uint32_;
}
uint32_t type_key;
type_key = static_cast<uint32_t>(dtype.code());
type_key |= static_cast<uint32_t>(dtype.bits()) << 8U;
type_key |= static_cast<uint32_t>(dtype.lanes()) << 16U;
auto it = pod_type_tbl_.find(type_key);
if (it != pod_type_tbl_.end()) {
return it->second;
}
SType t = DeclareType(dtype);
pod_type_tbl_[type_key] = t;
return t;
}
SType IRBuilder::GetPointerType(const SType& value_type,
spv::StorageClass storage_class) {
CHECK_NE(storage_class, spv::StorageClassMax);
auto key = std::make_pair(value_type.id, storage_class);
auto it = pointer_type_tbl_.find(key);
if (it != pointer_type_tbl_.end()) {
return it->second;
}
SType t;
t.id = id_counter_++;
t.type = Handle();
t.element_type_id = value_type.id;
t.storage_class = storage_class;
ib_.Begin(spv::OpTypePointer)
.AddSeq(t, storage_class, value_type).Commit(&global_);
pointer_type_tbl_[key] = t;
return t;
}
SType IRBuilder::GetStructArrayType(const SType& value_type,
uint32_t num_elems) {
auto key = std::make_pair(value_type.id, num_elems);
auto it = struct_array_type_tbl_.find(key);
if (it != struct_array_type_tbl_.end()) {
return it->second;
}
SType arr_type;
arr_type.id = id_counter_++;
arr_type.type = Handle();
arr_type.element_type_id = value_type.id;
if (num_elems != 0) {
Value length = UIntImm(GetSType(UInt(32)), num_elems);
ib_.Begin(spv::OpTypeArray)
.AddSeq(arr_type, value_type, length).Commit(&global_);
} else {
ib_.Begin(spv::OpTypeRuntimeArray)
.AddSeq(arr_type, value_type).Commit(&global_);
}
int nbits = value_type.type.bits() * value_type.type.lanes();
CHECK_EQ(nbits % 8, 0);
uint32_t nbytes = static_cast<uint32_t>(nbits) / 8;
// decorate the array type.
this->Decorate(spv::OpDecorate,
arr_type, spv::DecorationArrayStride, nbytes);
// declare struct of array
SType struct_type;
struct_type.id = id_counter_++;
struct_type.type = Handle();
struct_type.element_type_id = value_type.id;
ib_.Begin(spv::OpTypeStruct)
.AddSeq(struct_type, arr_type).Commit(&global_);
// decorate the offset of the struct member.
ib_.Begin(spv::OpMemberDecorate)
.AddSeq(struct_type, 0, spv::DecorationOffset, 0)
.Commit(&decorate_);
// runtime arrays are always decorated as BufferBlock (shader storage buffer)
if (num_elems == 0) {
this->Decorate(spv::OpDecorate,
struct_type, spv::DecorationBufferBlock);
}
struct_array_type_tbl_[key] = struct_type;
return struct_type;
}
Value IRBuilder::StructArrayAccess(const SType& res_type,
Value buffer,
Value index) {
CHECK(buffer.flag == kStructArrayPtr);
return MakeValue(spv::OpInBoundsAccessChain,
res_type, buffer,
const_i32_zero_, index);
}
Value IRBuilder::IntImm(const SType& dtype, int64_t value) {
return GetConst_(dtype, reinterpret_cast<uint64_t*>(&value));
}
Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) {
return GetConst_(dtype, &value);
}
Value IRBuilder::FloatImm(const SType& dtype, double value) {
if (dtype.type.bits() == 64) {
return GetConst_(dtype, reinterpret_cast<uint64_t*>(&value));
} else if (dtype.type.bits() == 32) {
float fvalue = static_cast<float>(value);
uint64_t data = reinterpret_cast<uint32_t*>(&fvalue)[0];
return GetConst_(dtype, &data);
} else {
CHECK_EQ(dtype.type.bits(), 16);
return Cast(dtype,
FloatImm(GetSType(Float(32)), value));
}
}
Value IRBuilder::BufferArgument(const SType& value_type,
uint32_t descriptor_set,
uint32_t binding) {
SType sarr_type = GetStructArrayType(value_type, 0);
SType ptr_type = GetPointerType(sarr_type, spv::StorageClassUniform);
Value val = NewValue(ptr_type, kStructArrayPtr);
ib_.Begin(spv::OpVariable)
.AddSeq(ptr_type, val, spv::StorageClassUniform).Commit(&global_);
this->Decorate(spv::OpDecorate,
val, spv::DecorationDescriptorSet, descriptor_set);
this->Decorate(spv::OpDecorate,
val, spv::DecorationBinding, binding);
return val;
}
Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
CHECK_EQ(push_const_.id, 0);
SType struct_type;
struct_type.id = id_counter_++;
struct_type.type = Handle();
ib_.Begin(spv::OpTypeStruct).Add(struct_type);
for (const SType& vtype : value_types) {
ib_.Add(vtype);
}
ib_.Commit(&global_);
uint32_t offset = 0;
for (uint32_t i = 0; i < value_types.size(); ++i) {
ib_.Begin(spv::OpMemberDecorate)
.AddSeq(struct_type, i, spv::DecorationOffset, offset)
.Commit(&decorate_);
Type t = value_types[i].type;
uint32_t nbits = t.bits() * t.lanes();
CHECK_EQ(nbits % 8, 0);
offset += nbits / 8;
}
// Decorate push constants as UBO
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
SType ptr_type = GetPointerType(
struct_type, spv::StorageClassPushConstant);
Value val = NewValue(ptr_type, kPushConstantPtr);
ib_.Begin(spv::OpVariable)
.AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
return val;
}
Value IRBuilder::GetPushConstant(
Value ptr_push_const, const SType& v_type, uint32_t index) {
SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant);
Value ptr = this->MakeValue(
spv::OpAccessChain, ptr_vtype, ptr_push_const,
IntImm(t_int32_, static_cast<int64_t>(index)));
return this->MakeValue(spv::OpLoad, v_type, ptr);
}
Value IRBuilder::DeclareKernelFunction(const std::string& name) {
Value val = NewValue(t_void_func_, kFunction);
ib_.Begin(spv::OpEntryPoint)
.AddSeq(spv::ExecutionModelGLCompute, val, name)
.Commit(&entry_);
return val;
}
void IRBuilder::StartFunction(const Value& func) {
CHECK_EQ(func.flag, kFunction);
this->MakeInst(
spv::OpFunction, t_void_, func, 0, t_void_func_);
spirv::Label start_label = this->NewLabel();
this->StartLabel(start_label);
}
void IRBuilder::SetLocalSize(const Value& func,
uint32_t local_size[3]) {
CHECK_EQ(func.flag, kFunction);
ib_.Begin(spv::OpExecutionMode)
.AddSeq(func, spv::ExecutionModeLocalSize,
local_size[0], local_size[1], local_size[2])
.Commit(&exec_mode_);
}
Value IRBuilder::Allocate(const SType& value_type,
uint32_t num_elems,
spv::StorageClass storage_class) {
CHECK_NE(num_elems, 0U);
SType sarr_type = GetStructArrayType(value_type, num_elems);
SType ptr_type = GetPointerType(sarr_type, storage_class);
Value val = NewValue(ptr_type, kStructArrayPtr);
if (storage_class == spv::StorageClassFunction) {
ib_.Begin(spv::OpVariable)
.AddSeq(ptr_type, val, storage_class).Commit(&function_);
} else {
ib_.Begin(spv::OpVariable)
.AddSeq(ptr_type, val, storage_class).Commit(&global_);
}
return val;
}
Value IRBuilder::GetWorkgroupID(uint32_t dim_index) {
if (workgroup_id_.id == 0) {
SType vec3_type = this->GetSType(Int(32).with_lanes(3));
SType ptr_type = this->GetPointerType(
vec3_type, spv::StorageClassInput);
workgroup_id_ = NewValue(ptr_type, kVectorPtr);
ib_.Begin(spv::OpVariable)
.AddSeq(ptr_type, workgroup_id_, spv::StorageClassInput)
.Commit(&global_);
this->Decorate(spv::OpDecorate, workgroup_id_,
spv::DecorationBuiltIn, spv::BuiltInWorkgroupId);
}
SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput);
Value ptr = this->MakeValue(
spv::OpAccessChain, pint_type, workgroup_id_,
IntImm(t_int32_, static_cast<int64_t>(dim_index)));
return this->MakeValue(spv::OpLoad, t_int32_, ptr);
}
Value IRBuilder::GetLocalID(uint32_t dim_index) {
if (local_id_.id == 0) {
SType vec3_type = this->GetSType(Int(32).with_lanes(3));
SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput);
local_id_ = NewValue(ptr_type, kVectorPtr);
ib_.Begin(spv::OpVariable)
.AddSeq(ptr_type, local_id_, spv::StorageClassInput)
.Commit(&global_);
this->Decorate(spv::OpDecorate, local_id_,
spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId);
}
SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput);
Value ptr = this->MakeValue(
spv::OpAccessChain, pint_type, local_id_,
UIntImm(t_int32_, static_cast<int64_t>(dim_index)));
return this->MakeValue(spv::OpLoad, t_int32_, ptr);
}
Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
auto key = std::make_pair(dtype.id, pvalue[0]);
auto it = const_tbl_.find(key);
if (it != const_tbl_.end()) {
return it->second;
}
CHECK_LE(dtype.type.bits(), 64);
Value ret = NewValue(dtype, kConstant);
ib_.Begin(spv::OpConstant).AddSeq(dtype, ret);
uint64_t mask = 0xFFFFFFFFUL;
ib_.Add(static_cast<uint32_t>(pvalue[0] & mask));
if (dtype.type.bits() > 32) {
if (dtype.type.is_int()) {
int64_t sign_mask = 0xFFFFFFFFL;
const int64_t* sign_ptr =
reinterpret_cast<const int64_t*>(pvalue);
ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
} else {
ib_.Add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask));
}
}
ib_.Commit(&global_);
const_tbl_[key] = ret;
return ret;
}
SType IRBuilder::DeclareType(const Type& dtype) {
if (dtype.lanes() == 1) {
SType t;
t.id = id_counter_++;
t.type = dtype;
if (dtype.bits() == 1) {
CHECK(dtype.is_uint());
ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_);
} else if (dtype.is_int()) {
ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_);
} else if (dtype.is_uint()) {
ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 0).Commit(&global_);
} else if (dtype.is_float()) {
ib_.Begin(spv::OpTypeFloat).AddSeq(t, dtype.bits()).Commit(&global_);
} else {
LOG(FATAL) << "declare type do not support handle";
}
return t;
} else {
SType t;
t.id = id_counter_++;
t.type = dtype;
SType base_type = GetSType(dtype.element_of());
ib_.Begin(spv::OpTypeVector).AddSeq(
t, base_type, dtype.lanes()).Commit(&global_);
return t;
}
}
PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) {
Value val = NewValue(out_type, kNormal);
ib_.Begin(spv::OpPhi).AddSeq(out_type, val);
for (uint32_t i = 0; i < 2 * num_incoming; ++i) {
ib_.Add(0);
}
PhiValue phi;
phi.id = val.id;
phi.stype = out_type;
phi.flag = kNormal;
phi.instr = ib_.Commit(&function_);
CHECK_EQ(phi.instr.WordCount(), 2 * num_incoming + 3);
return phi;
}
Value IRBuilder::CallGLSL450(const SType& ret_type,
uint32_t inst_id,
const std::vector<Value>& args) {
Value val = NewValue(ret_type, kNormal);
ib_.Begin(spv::OpExtInst)
.AddSeq(ret_type, val, ext_glsl450_, inst_id);
for (const Value& v : args) {
ib_.Add(v);
}
ib_.Commit(&function_);
return val;
}
Value IRBuilder::Concat(const std::vector<Value>& vec) {
bool is_const = vec[0].flag == kConstant;
Type etype = vec[0].stype.type;
int lanes = etype.lanes();
for (size_t i = 1; i < vec.size(); ++i) {
CHECK_EQ(etype, vec[i].stype.type.element_of())
<< "Cannot concat vector of different element type";
lanes += vec[i].stype.type.lanes();
is_const = is_const && (vec[i].flag == kConstant);
}
Value ret = NewValue(GetSType(etype.with_lanes(lanes)), kNormal);
if (is_const && vec.size() == static_cast<size_t>(lanes)) {
ib_.Begin(spv::OpConstantComposite);
ib_.AddSeq(ret.stype, ret);
for (const Value& v : vec) {
ib_.Add(v);
}
ib_.Commit(&global_);
} else {
ib_.Begin(spv::OpCompositeConstruct);
ib_.AddSeq(ret.stype, ret);
for (const Value& v : vec) {
ib_.Add(v);
}
ib_.Commit(&function_);
}
return ret;
}
Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
CHECK_NE(value.stype.id, 0U);
if (value.stype.id == dst_type.id) return value;
const tvm::Type& from = value.stype.type;
const tvm::Type& to = dst_type.type;
CHECK_EQ(from.lanes(), to.lanes());
if (from.is_int() && to.is_int()) {
return MakeValue(spv::OpSConvert, dst_type, value);
} else if (from.is_uint() && to.is_uint()) {
return MakeValue(spv::OpUConvert, dst_type, value);
} else if (from.is_uint() && to.is_int()) {
if (from.bits() != to.bits()) {
value = MakeValue(
spv::OpUConvert, GetSType(from.with_bits(to.bits())), value);
}
return MakeValue(spv::OpBitcast, dst_type, value);
} else if (from.is_int() && to.is_uint()) {
if (from.bits() != to.bits()) {
value = MakeValue(
spv::OpSConvert, GetSType(from.with_bits(to.bits())), value);
}
return MakeValue(spv::OpBitcast, dst_type, value);
} else if (from.is_float() && to.is_int()) {
return MakeValue(spv::OpConvertFToS, dst_type, value);
} else if (from.is_float() && to.is_uint()) {
return MakeValue(spv::OpConvertFToU, dst_type, value);
} else if (from.is_int() && to.is_float()) {
return MakeValue(spv::OpConvertSToF, dst_type, value);
} else if (from.is_uint() && to.is_float()) {
return MakeValue(spv::OpConvertUToF, dst_type, value);
} else if (from.is_float() && to.is_float()) {
return MakeValue(spv::OpFConvert, dst_type, value);
} else {
LOG(FATAL) << "do not support type cast from "
<< from << " to " << to;
return Value();
}
}
#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI ## _Op, a.stype, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
} \
}
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, a.stype, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, a.stype, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
} \
}
DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add);
DEFINE_BUILDER_BINARY_USIGN_OP(Sub, Sub);
DEFINE_BUILDER_BINARY_USIGN_OP(Mul, Mul);
DEFINE_BUILDER_BINARY_SIGN_OP(Div, Div);
Value IRBuilder::Mod(Value a, Value b) {
CHECK_EQ(a.stype.id, b.stype.id);
if (a.stype.type.is_int()) {
return MakeValue(spv::OpSRem, a.stype, a, b);
} else if (a.stype.type.is_uint()) {
return MakeValue(spv::OpUMod, a.stype, a, b);
} else {
CHECK(a.stype.type.is_float());
return MakeValue(spv::OpFRem, a.stype, a, b);
}
}
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \
t_bool_ = DeclareType(UInt(1)); \
} \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS ## _Op, t_bool_, a, b); \
} else if (a.stype.type.is_uint()) { \
return MakeValue(spv::OpU ## _Op, t_bool_, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
} \
}
DEFINE_BUILDER_CMP_OP(LT, LessThan);
DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
Value IRBuilder:: _OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
if (t_bool_.id == 0) { \
t_bool_ = DeclareType(UInt(1)); \
} \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI ## _Op, t_bool_, a, b); \
} else { \
CHECK(a.stype.type.is_float()); \
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
} \
}
DEFINE_BUILDER_CMP_UOP(EQ, Equal);
DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
Value IRBuilder::Select(Value cond, Value a, Value b) {
CHECK_EQ(a.stype.id, b.stype.id);
CHECK_EQ(cond.stype.type, UInt(1));
return MakeValue(spv::OpSelect, a.stype, cond, a, b);
}
} // namespace spirv
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
/*!
* Copyright (c) 2018 by Contributors
* \file ir_builder.h
* \brief Utility for building SPIRV code block
*/
#ifndef TVM_CODEGEN_SPIRV_IR_BUILDER_H_
#define TVM_CODEGEN_SPIRV_IR_BUILDER_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/ir.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <string>
#include <map>
#include <vulkan/spirv.hpp>
namespace tvm {
namespace codegen {
namespace spirv {
/*! \brief Represent the SPIRV Type */
struct SType {
/*! \brief The Id to represent type */
uint32_t id{0};
/*! \brief corresponding TVM type */
tvm::Type type;
/*! \brief content type id if it is a pointer/struct-array class */
uint32_t element_type_id{0};
/*! \brief The storage class, if it is a pointer */
spv::StorageClass storage_class{spv::StorageClassMax};
};
enum ValueKind {
kNormal,
kConstant,
kVectorPtr,
kStructArrayPtr,
kPushConstantPtr,
kFunction,
kExtInst
};
/*! \brief Represent the SPIRV Value */
struct Value {
/*! \brief The Id to represent value */
uint32_t id{0};
/*! \brief The data type */
SType stype;
/*! \brief additional flags about the value */
ValueKind flag{kNormal};
};
/*! \brief Represent the SPIRV Label */
struct Label {
/*! \brief The Id to represent label */
uint32_t id{0};
};
/*!
* \brief A SPIRV instruction,
* can be used as handle to modify its content later
*/
class Instr {
public:
/*! \return the word count */
uint32_t WordCount() const {
return word_count_;
}
/*!
* \brief Access idx-th word of instruction
* \param idx The index
* \return reference to idx-th word.
*/
uint32_t& operator[](uint32_t idx) {
CHECK_LT(idx, word_count_);
return (*data_)[begin_ + idx];
}
private:
friend class InstrBuilder;
/*!
   * \brief The data that backs this instruction.
   *  We keep a pointer to the backing vector because its
   *  storage can be reallocated as more code is appended.
*/
std::vector<uint32_t>* data_{nullptr};
/*! \brief begin location of instruction */
uint32_t begin_{0};
  /*! \brief word count */
uint32_t word_count_{0};
};
/*! \brief Representation of phi value */
struct PhiValue : public Value {
/*! \brief The corresponding instr */
Instr instr;
/*!
* \brief Add incoming information of a PhiValue
* \param index The location of Phi
   * \param value The incoming value.
* \param parent The parent label.
*/
void SetIncoming(uint32_t index,
const Value& value,
const Label& parent) {
CHECK_EQ(this->stype.id, value.stype.id);
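    // OpPhi word layout: [opcode|wordcount, result type, result id,
    //                     (value id, parent label id) pairs ...],
    // so the index-th incoming pair starts at word 3 + 2 * index.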
instr[3 + index * 2] = value.id;
instr[3 + index * 2 + 1] = parent.id;
}
};
/*!
* \brief Helper class to build SPIRV instruction.
*
* \code
*
* std::vector<uint32_t> func_seg_vec_;
* InstrBuilder ib;
*
* // construct and append to the end of func_seg_vec_;
* ib.Begin(spv::OpIAdd)
* .Add(result).Add(v1).Add(v2)
* .Commit(&func_seg_vec_);
*
* \endcode
*/
class InstrBuilder {
public:
/*!
* \brief Begin construction of instruction.
* \param op The op code
* \return reference to self.
*/
InstrBuilder& Begin(spv::Op op) { // NOLINT(*);
// finish previous build
CHECK_EQ(data_.size(), 0U);
op_ = op;
data_.push_back(0);
return *this;
}
/*!
* \brief Add v to end of instruction.
* \param v The value to be appended to the instruction.
* \return reference to self.
*/
InstrBuilder& Add(const Value& v) {
data_.push_back(v.id);
return *this;
}
/*!
* \brief Add v to end of instruction.
* \param v The type to be appended to the instruction.
* \return reference to self.
*/
InstrBuilder& Add(const SType& v) {
data_.push_back(v.id);
return *this;
}
/*!
* \brief Add v to end of instruction.
* \param v The label to be appended to the instruction.
* \return reference to self.
*/
InstrBuilder& Add(const Label& v) {
data_.push_back(v.id);
return *this;
}
/*!
* \brief Add a word to end of instruction.
* \param v The value to be added.
* \return reference to self.
*/
InstrBuilder& Add(const uint32_t& v) {
data_.push_back(v);
return *this;
}
/*!
   * \brief Add a string literal to the end of the instruction.
* \param v The string literal to be appended.
* \return reference to self.
*/
InstrBuilder& Add(const std::string& v) {
const uint32_t kWordSize = sizeof(uint32_t);
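    // SPIR-V string literals are nul-terminated and padded to a multiple of
    // four bytes; e.g. "main" (4 chars) needs 2 words so the terminator fits.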
uint32_t nwords =
(static_cast<uint32_t>(v.length()) + kWordSize) / kWordSize;
size_t begin = data_.size();
data_.resize(begin + nwords, 0U);
std::copy(v.begin(), v.end(),
reinterpret_cast<char*>(&data_[begin]));
return *this;
}
/*!
* \brief add sequence of values to instruction
* \param args The instruction sequence
* \return reference to self.
   * \tparam Args The positional arguments
*/
template<typename... Args>
InstrBuilder& AddSeq(Args&& ...args) {
AddSeqHelper helper;
helper.builder = this;
runtime::detail::for_each(helper, std::forward<Args>(args)...);
return *this;
}
/*!
* \brief Finish build, commit the current
* instruction to the end of seg.
*
* \param seg The code segment to commit to
* \return The result instruction.
*/
Instr Commit(std::vector<uint32_t>* seg) {
Instr ret;
ret.data_ = seg;
ret.begin_ = seg->size();
ret.word_count_ = static_cast<uint32_t>(data_.size());
data_[0] = op_ | (ret.word_count_ << spv::WordCountShift);
seg->insert(seg->end(), data_.begin(), data_.end());
data_.clear();
return ret;
}
private:
// current op code.
spv::Op op_;
// The internal data to store code
std::vector<uint32_t> data_;
// helper class to support variadic arguments
struct AddSeqHelper {
// The reference to builder
InstrBuilder* builder;
// invoke function
template<typename T>
void operator()(size_t, const T& v) const {
builder->Add(v);
}
};
};
/*!
* \brief Builder to build up a single SPIR-V module
*
 * This is a thin wrapper for building a SPIRV binary.
 * SPIRV adopts structured control-flow, so we can build the code by
 * always appending to the end of each binary code segment and, when
 * necessary, revisiting previously emitted instructions through Instr handles.
 *
 * This IRBuilder does not introduce the concept of a BasicBlock;
 * instead, instructions are appended to the end of each segment.
 * See the usage sketch after the class definition.
*/
class IRBuilder {
public:
/*! \brief Initialize header */
void InitHeader();
/*! \brief Initialize the predefined contents */
void InitPreDefs();
/*!
* \brief Import additional extension libraries.
* \param name The name of the library.
   * \return The value of the imported extended instruction set.
*/
Value ExtInstImport(const std::string& name) {
Value val = NewValue(SType(), kExtInst);
ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&header_);
return val;
}
/*!
* \brief Get the final binary built from the builder
   * \return The finalized SPIR-V binary.
*/
std::vector<uint32_t> Finalize() {
std::vector<uint32_t> data;
// set bound
const int kBoundLoc = 3;
header_[kBoundLoc] = id_counter_;
data.insert(data.end(), header_.begin(), header_.end());
data.insert(data.end(), entry_.begin(), entry_.end());
data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
data.insert(data.end(), debug_.begin(), debug_.end());
data.insert(data.end(), decorate_.begin(), decorate_.end());
data.insert(data.end(), global_.begin(), global_.end());
data.insert(data.end(), function_.begin(), function_.end());
return data;
}
/*!
* \brief Create new label
* \return The created new label
*/
Label NewLabel() {
Label label;
label.id = id_counter_++;
return label;
}
/*!
* \brief Start a new block with given label
* \param label The label we use.
*/
void StartLabel(Label label) {
MakeInst(spv::OpLabel, label);
curr_label_ = label;
}
/*! \return The current label */
Label CurrentLabel() const {
return curr_label_;
}
/*!
* \brief Add code to debug segment.
* \param op The operator
* \param args The instruction sequence
   * \tparam Args The positional arguments
*/
template<typename... Args>
void Debug(spv::Op op, Args&& ...args) {
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&debug_);
}
/*!
* \brief Add Execution mode to a function.
* \param func The function value
* \param args The instruction sequence
   * \tparam Args The positional arguments
*/
template<typename... Args>
void ExecutionMode(Value func, Args&& ...args) {
ib_.Begin(spv::OpExecutionMode).AddSeq(
func, std::forward<Args>(args)...).Commit(&exec_mode_);
}
/*!
* \brief Add code to decorate segment.
* \param op The operator
* \param args The instruction sequence
   * \tparam Args The positional arguments
*/
template<typename... Args>
void Decorate(spv::Op op, Args&& ...args) {
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&decorate_);
}
/*!
* \brief Add code to global segment.
* \param op The operator
* \param args The instruction sequence
   * \tparam Args The positional arguments
*/
template<typename... Args>
  void DeclareGlobal(spv::Op op, Args&& ...args) {
    ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&global_);
}
/*!
* \brief Make a new instruction and append it to end of function segment.
*
* \param op The operator
* \param args The instruction sequence
   * \return The resulting instruction.
   * \tparam Args The positional arguments
*/
template<typename... Args>
Instr MakeInst(spv::Op op, Args&& ...args) {
return ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&function_);
}
/*!
* \brief Make a new SSA value,
*
* \param op The operator.
* \param out_type The result type.
* \param args The instruction sequence
* \return The result SSA value.
   * \tparam Args The positional arguments
*/
template<typename... Args>
Value MakeValue(spv::Op op, const SType& out_type, Args&& ...args) {
Value val = NewValue(out_type, kNormal);
MakeInst(op, out_type, val, std::forward<Args>(args)...);
return val;
}
/*!
* \brief Make a phi value.
*
* \param out_type The output data type.
* \param num_incoming number of incoming blocks.
* \return The result Phi value.
*/
PhiValue MakePhi(const SType& out_type, uint32_t num_incoming);
/*!
* \brief Create a GLSL450 call
*
* \param ret_type The result type.
* \param inst_id The instance id of the function.
* \param args The arguments
* \return The result value.
*/
Value CallGLSL450(const SType& ret_type,
uint32_t inst_id,
const std::vector<Value>& args);
/*!
* \brief Build vector by concatenating components
*
   * \param vec The vector components to concatenate.
   * \return The concatenated vector value.
*/
Value Concat(const std::vector<Value>& vec);
/*!
* \brief Get the spirv type for a given tvm data type.
* \param dtype The data type.
* \return The corresponding spirv type.
*/
SType GetSType(const tvm::Type& dtype);
/*!
* \brief Get the pointer type that points to value_type
   * \param value_type The value type being pointed to.
* \param storage_class The storage class
* \return The corresponding spirv type.
*/
SType GetPointerType(const SType& value_type,
spv::StorageClass storage_class);
/*!
* \brief Get a struct{ value_type[num_elems] } type.
* \param value_type the content value type.
* \param num_elems number of elements in array
* num_elems = 0 means runtime array with BufferBlock Decoration
*
* \return The corresponding spirv type.
*/
SType GetStructArrayType(const SType& value_type,
uint32_t num_elems);
/*!
* \brief Get a struct array access with a given index.
* \param ptr_type The pointer type.
* \param buffer The buffer ptr to struct array
* \param index The array index.
*/
Value StructArrayAccess(const SType& ptr_type,
Value buffer,
Value index);
/*!
* \brief Create a cast that cast value to dst_type
* \param dst_type The target type.
* \param value the source value.
* \return The result value
*/
Value Cast(const SType& dst_type, Value value);
/*
* \brief Create a const integer.
* \param dtype The content data type.
* \param value The data value.
*/
Value IntImm(const SType& dtype, int64_t value);
/*
* \brief Create a const unsigned integer.
* \param dtype The content data type.
* \param value The data value.
*/
Value UIntImm(const SType& dtype, uint64_t value);
/*
* \brief Create a const float.
* \param dtype The content data type.
* \param value The data value.
*/
Value FloatImm(const SType& dtype, double value);
  /*!
   * \brief Declare a buffer argument of the function.
   *
   * \param value_type The value type of the buffer content.
   * \param descriptor_set The descriptor set we want to use.
   * \param binding The binding location in the descriptor set.
   * \return The declared buffer argument value.
   */
Value BufferArgument(const SType& value_type,
uint32_t descriptor_set,
uint32_t binding);
/*!
* \brief Declare POD arguments through push constants.
*
* \note Only call this function once!
* \param value_types The values in the push constant
   * \return The pointer value to the push constant block.
*/
Value DeclarePushConstant(const std::vector<SType>& value_types);
/*!
   * \brief Get the i-th push constant.
   * \param ptr_push_const The pointer value to the push constant block.
   * \param v_type The value type of the constant.
   * \param index The push constant index.
   * \return The value of the push constant.
*/
Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
/*!
* \brief Declare a kernel function
* \param name Name of the entry point.
* \return The created function ID.
*/
Value DeclareKenrelFunction(const std::string& name);
/*!
* \brief Start function scope.
* \param func function to be started.
*/
void StartFunction(const Value& func);
/*!
* \brief Set the local size of the function
* \param func function of interest
* \param local_size The local workgroup_size
*/
void SetLocalSize(const Value& func, uint32_t local_size[3]);
/*
* \brief Allocate space
* \param value_type The content value type
* \param num_elems Number of elements to allocate.
* \param storage_class The storage class we want to store to.
*/
Value Allocate(const SType& value_type,
uint32_t num_elems,
spv::StorageClass storage_class);
/*
* \brief Get the i-th workgroup id.
* \return The value representing the workgroup id.
*/
Value GetWorkgroupID(uint32_t dim_index);
/*
* \brief Get the i-th local id.
* \return The value representing the local id.
*/
Value GetLocalID(uint32_t dim_index);
// Expressions
Value Add(Value a, Value b);
Value Sub(Value a, Value b);
Value Mul(Value a, Value b);
Value Div(Value a, Value b);
Value Mod(Value a, Value b);
Value EQ(Value a, Value b);
Value NE(Value a, Value b);
Value LT(Value a, Value b);
Value LE(Value a, Value b);
Value GT(Value a, Value b);
Value GE(Value a, Value b);
Value Select(Value cond, Value a, Value b);
private:
/*!
* \brief Create new value
   * \return The created new value.
*/
Value NewValue(const SType& stype, ValueKind flag) {
Value val;
val.id = id_counter_++;
val.stype = stype;
val.flag = flag;
return val;
}
// get constant given value encoded in uint64_t
Value GetConst_(const SType& dtype, const uint64_t* pvalue);
// declare type
SType DeclareType(const Type& dtype);
/*! \brief internal instruction builder */
InstrBuilder ib_;
/*! \brief Current label */
Label curr_label_;
/*! \brief The current maximum id */
uint32_t id_counter_{1};
/*! \brief glsl 450 extension */
Value ext_glsl450_;
  /*! \brief special type cache: bool, int32, uint32, fp32, void, void function */
  SType t_bool_, t_int32_, t_uint32_, t_fp32_, t_void_, t_void_func_;
  /*! \brief quick cache for const zero i32 */
  Value const_i32_zero_;
/*! \brief cache value for workgroup_id, local_id */
Value workgroup_id_, local_id_;
/*! \brief whether push constant is defined */
Value push_const_;
/*! \brief map from type code to the type */
std::unordered_map<uint32_t, SType> pod_type_tbl_;
/*! \brief map from value to array type */
std::map<std::pair<uint32_t, uint32_t>, SType> struct_array_type_tbl_;
/*! \brief map from value to its pointer type */
std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
/*! \brief map from constant int to its value */
std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_;
  /*! \brief Header segment, includes imports */
  std::vector<uint32_t> header_;
  /*! \brief entry point segment */
  std::vector<uint32_t> entry_;
  /*! \brief Execution mode segment */
  std::vector<uint32_t> exec_mode_;
  /*! \brief Debug segment */
  std::vector<uint32_t> debug_;
  /*! \brief Annotation segment */
  std::vector<uint32_t> decorate_;
  /*! \brief Global segment: types and global variables */
std::vector<uint32_t> global_;
/*! \brief Function segment */
std::vector<uint32_t> function_;
};
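/*
 * Usage sketch of IRBuilder (illustrative documentation only; the exact call
 * sequence used by the SPIR-V codegen may differ in detail, and the local
 * names below are placeholders):
 *
 * \code
 *
 * IRBuilder builder;
 * builder.InitHeader();
 * builder.InitPreDefs();
 *
 * // a float32 storage buffer bound at descriptor set 0, binding 0
 * SType f32 = builder.GetSType(Float(32));
 * Value arg0 = builder.BufferArgument(f32, 0, 0);
 *
 * // declare the kernel entry and emit a trivial body
 * Value func = builder.DeclareKenrelFunction("main");
 * builder.StartFunction(func);
 * Value one = builder.FloatImm(f32, 1.0);
 * Value two = builder.Add(one, one);
 * builder.MakeInst(spv::OpReturn);
 *
 * std::vector<uint32_t> spirv = builder.Finalize();
 *
 * \endcode
 */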
} // namespace spirv
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_SPIRV_IR_BUILDER_H_
......@@ -59,7 +59,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
return VisitExpr(op->a);
}
bool VisitExpr_(const Let* op) final {
return VisitExpr(op->body) && VisitExpr(op->value);
return VisitExpr(op->body) || VisitExpr(op->value);
}
bool VisitExpr_(const Cast* op) final {
return VisitExpr(op->value);
......@@ -84,7 +84,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
private:
template<typename T>
bool BinaryOp(const T* op) {
return VisitExpr(op->a) && VisitExpr(op->b);
return VisitExpr(op->a) || VisitExpr(op->b);
}
};
......
......@@ -903,20 +903,42 @@ class VectorAllocRewriter : public IRMutator {
return stmt;
}
private:
void UpdateTypeMap(const Variable* buffer, Type t) {
auto& tvec = acc_map_[buffer];
if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
tvec.push_back(t);
}
}
// Internal access map
std::unordered_map<const Variable*,
std::vector<Type> > acc_map_;
std::unordered_map<const Variable*, std::vector<Type> > acc_map_;
};
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
std::shared_ptr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->());
VectorAllocRewriter rewriter;
n->body = rewriter.Mutate(n->body);
for (Var arg : f->args) {
if (arg.type().is_handle()) {
const auto& tvec = rewriter.acc_map_[arg.get()];
if (tvec.size() == 1) {
Expr dtype = make_const(tvec[0], 0);
n->handle_data_type.Set(arg, dtype);
} else {
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
Expr dtype = make_const(tvec[0].with_lanes(1), 0);
n->handle_data_type.Set(arg, dtype);
}
}
}
}
return LoweredFunc(n);
}
Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(stmt, true);
return VectorAllocRewriter().Mutate(stmt);
......
......@@ -28,6 +28,7 @@ inline std::string DeviceName(int type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
......
......@@ -119,6 +119,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "device_api.opengl";
} else if (target == "mtl" || target == "metal") {
f_name = "device_api.metal";
} else if (target == "vulkan") {
f_name = "device_api.vulkan";
} else if (target == "stackvm") {
f_name = "codegen.build_stackvm";
} else if (target == "rpc") {
......
......@@ -44,12 +44,13 @@ class ROCMDeviceAPI final : public DeviceAPI {
value = 64;
break;
}
case kComputeVersion:
case kComputeVersion: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
*rv = prop.gcnArch;
return;
}
}
*rv = value;
}
void* AllocDataSpace(TVMContext ctx,
......
/*!
* Copyright (c) 2017 by Contributors
* \file vulkan_common.h
* \brief Vulkan common header
*/
#ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
#define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#if TVM_VULKAN_RUNTIME
#include <vulkan/vulkan.h>
#include <mutex>
#include <string>
#include <vector>
#include "../workspace_pool.h"
namespace tvm {
namespace runtime {
namespace vulkan {
inline const char* VKGetErrorString(VkResult error) {
switch (error) {
case VK_SUCCESS: return "VK_SUCCESS";
case VK_NOT_READY: return "VK_NOT_READY";
case VK_TIMEOUT: return "VK_TIMEOUT";
case VK_EVENT_SET: return "VK_EVENT_SET";
case VK_EVENT_RESET: return "VK_EVENT_RESET";
case VK_INCOMPLETE: return "VK_INCOMPLETE";
case VK_ERROR_OUT_OF_HOST_MEMORY: return "VK_ERROR_OUT_OF_HOST_MEMORY";
case VK_ERROR_OUT_OF_DEVICE_MEMORY: return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
case VK_ERROR_INITIALIZATION_FAILED: return "VK_ERROR_INITIALIZATION_FAILED";
case VK_ERROR_DEVICE_LOST: return "VK_ERROR_DEVICE_LOST";
case VK_ERROR_MEMORY_MAP_FAILED: return "VK_ERROR_MEMORY_MAP_FAILED";
case VK_ERROR_LAYER_NOT_PRESENT: return "VK_ERROR_LAYER_NOT_PRESENT";
case VK_ERROR_EXTENSION_NOT_PRESENT: return "VK_ERROR_EXTENSION_NOT_PRESENT";
case VK_ERROR_FEATURE_NOT_PRESENT: return "VK_ERROR_FEATURE_NOT_PRESENT";
case VK_ERROR_INCOMPATIBLE_DRIVER: return "VK_ERROR_INCOMPATIBLE_DRIVER";
case VK_ERROR_TOO_MANY_OBJECTS: return "VK_ERROR_TOO_MANY_OBJECTS";
case VK_ERROR_FORMAT_NOT_SUPPORTED: return "VK_ERROR_FORMAT_NOT_SUPPORTED";
case VK_ERROR_FRAGMENTED_POOL: return "VK_ERROR_FRAGMENTED_POOL";
default: return "Unknown Vulkan error code";
}
}
/*!
* \brief Protected Vulkan call
* \param func Expression to call.
*/
#define VULKAN_CHECK_ERROR(__e) \
{ \
CHECK(__e == VK_SUCCESS) \
<< "Vulan Error, code=" << __e << ": " << vulkan::VKGetErrorString(__e); \
}
#define VULKAN_CALL(func) \
{ \
VkResult __e = (func); \
VULKAN_CHECK_ERROR(__e); \
}
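// Example usage (illustrative; mirrors the calls later in this runtime):
//   VULKAN_CALL(vkQueueWaitIdle(queue));
// VULKAN_CHECK_ERROR can be used directly when the VkResult has already
// been captured, e.g. after a vkWaitForFences loop.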
/*! \brief Auxiliary context structure for vulkan */
struct VulkanContext {
  // physical device
  VkPhysicalDevice phy_device{nullptr};
  // Physical device property
VkPhysicalDeviceProperties phy_device_prop;
// Memory type index for staging.
uint32_t staging_mtype_index{0};
// whether staging is coherent
bool coherent_staging{false};
// Memory type index for compute
uint32_t compute_mtype_index{0};
// The logical device
VkDevice device{nullptr};
// command queue
VkQueue queue{nullptr};
  // queue family index
  uint32_t queue_family_index{0};
  // Queue family properties.
VkQueueFamilyProperties queue_prop;
};
/*! \brief The buffer object */
struct VulkanBuffer {
/*! \brief underlying buffer */
VkBuffer buffer{nullptr};
  /*! \brief underlying device memory */
  VkDeviceMemory memory{nullptr};
};
/*! \brief Buffer only used for staging */
struct VulkanStagingBuffer {
  /*! \brief the corresponding device */
  VkDevice device{nullptr};
  /*! \brief underlying buffer */
  VkBuffer buffer{nullptr};
  /*! \brief underlying device memory */
VkDeviceMemory memory{nullptr};
/*! \brief host address */
void* host_addr{nullptr};
/*! \brief size of the memory */
size_t size{0};
};
/*!
* \brief Process global Vulkan workspace.
*/
class VulkanWorkspace final : public DeviceAPI {
public:
// global mutex
std::mutex mu;
  // whether the workspace is initialized.
  bool initialized_{false};
  // vulkan instance
  VkInstance instance_{nullptr};
  // Per-device Vulkan contexts, one-to-one mapping to devices
  std::vector<VulkanContext> context_;
  // Destructor
  ~VulkanWorkspace();
  // Initialize the workspace.
  // Does nothing if it is already initialized.
void Init();
// override device API
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace
static const std::shared_ptr<VulkanWorkspace>& Global();
};
/*! \brief Helper command buffer resource */
struct VulkanCommandBuffer {
/*! \brief fence to signal the resource is ready to use */
VkFence fence{nullptr};
/*! \brief The internal command buffer */
VkCommandBuffer cmd_buffer{nullptr};
/*! \brief Descriptor set used to bind arguments */
VkDescriptorSet descriptor_set{nullptr};
/*! \brief Internal utilities for write command */
VkWriteDescriptorSet write_descriptor_set;
VulkanCommandBuffer() {
write_descriptor_set.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write_descriptor_set.pNext = nullptr;
write_descriptor_set.dstSet = nullptr;
write_descriptor_set.dstBinding = 0;
write_descriptor_set.dstArrayElement = 0;
write_descriptor_set.descriptorCount = 1;
write_descriptor_set.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
write_descriptor_set.pImageInfo = nullptr;
write_descriptor_set.pBufferInfo = nullptr;
write_descriptor_set.pTexelBufferView = nullptr;
}
};
/*!
 * \brief Command pool backed by a fixed-size ring buffer.
 *
 * Vulkan requires that a command buffer is not reused until
 * all of its corresponding jobs have finished.
 *
 * This class facilitates automatic management of the command
 * buffers. A fence is created for each launch of command buffer
 * jobs, and before we reuse the same entry in the ring we make
 * sure that the previously pending job has finished.
 * See the usage sketch after the class definition.
 *
*/
class VulkanCommandPool {
public:
/*! \brief Maximum number of pending jobs in the pool */
static constexpr const int kMaxPending = 4;
  /*! \brief Maximum number of buffer arguments per kernel */
static constexpr const int kMaxNumArgs = 16;
/*!
* \brief constructor
* \param vctx The corresponding vulkan context.
*/
explicit VulkanCommandPool(const VulkanContext& vctx);
/*! \brief destructor */
~VulkanCommandPool();
/*!
* \brief Allocate a new command buffer entry
*
   * The caller must submit the entry exactly once,
   * using the fence stored in the entry,
   * before calling the next Alloc.
   *
   * This function may block to wait for a
   * previously unfinished command when there
   * are more than kMaxPending pending jobs.
*
* \returns The allocated entry.
*/
VulkanCommandBuffer* Alloc();
/*!
* \brief Allocate a new command buffer entry
* \param dlayout the descriptor layout.
*
* \returns The allocated entry.
*/
VulkanCommandBuffer* Alloc(const VkDescriptorSetLayout* dlayout);
private:
/*! \brief Local ring buffer */
std::vector<VulkanCommandBuffer> ring_;
/*! \brief clock pointer */
size_t clock_ptr_{0};
/*! \brief the corresponding device*/
VkDevice device_{nullptr};
/*! \brief internal command buffer pool */
VkCommandPool cmd_pool_{nullptr};
/*! \brief Descriptor pool */
VkDescriptorPool descriptor_pool_{nullptr};
};
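/*
 * Usage sketch (illustrative only; the real submission paths are in
 * VulkanWorkspace::CopyDataFromTo and VulkanWrappedFunc::operator(), and
 * names such as pool, queue, begin_info and submit_info are placeholders):
 *
 * \code
 *
 * VulkanCommandBuffer* cmd = pool->Alloc();
 * VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &begin_info));
 * // ... record copy or dispatch commands ...
 * VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
 * // submit exactly once with cmd->fence so the pool can tell when this
 * // ring slot becomes safe to reuse.
 * VULKAN_CALL(vkQueueSubmit(queue, 1, &submit_info, cmd->fence));
 *
 * \endcode
 */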
/*! \brief Thread local workspace */
class VulkanThreadEntry {
public:
/*! \brief The current context */
TVMContext context;
/*! \brief workspace pool */
WorkspacePool pool;
/*! \brief The staging buffers */
std::vector<VulkanStagingBuffer> staging_buffer_;
/*!
   * \brief Get the command pool of the corresponding device.
   * \param device_id The device id
   * \return The corresponding command pool.
   */
  VulkanCommandPool* CommandPool(int device_id);
  /*!
   * \brief Get the staging buffer.
   * \param device_id The device id
   * \param size The minimum size of the staging buffer.
   * \return The corresponding staging buffer.
*/
VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
// constructor
VulkanThreadEntry()
: pool(static_cast<DLDeviceType>(kDLVulkan), VulkanWorkspace::Global()) {
context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kDLVulkan);
}
~VulkanThreadEntry();
// get the global workspace
static VulkanThreadEntry* ThreadLocal();
private:
/*! \brief the command pools */
std::vector<std::unique_ptr<VulkanCommandPool> > pool_;
};
// inline implementation
} // namespace vulkan
} // namespace runtime
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
#endif // TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
/*!
* Copyright (c) 2017 by Contributors
* \file vulkan_device_api.cc
*/
#include "./vulkan_common.h"
#if TVM_VULKAN_RUNTIME
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include <cstring>
namespace tvm {
namespace runtime {
namespace vulkan {
VulkanWorkspace::~VulkanWorkspace() {
for (VulkanContext& ctx : context_) {
vkDestroyDevice(ctx.device, nullptr);
}
if (instance_ != nullptr) {
vkDestroyInstance(instance_, nullptr);
}
}
const std::shared_ptr<VulkanWorkspace>& VulkanWorkspace::Global() {
static std::shared_ptr<VulkanWorkspace> inst = std::make_shared<VulkanWorkspace>();
return inst;
}
void VulkanWorkspace::SetDevice(TVMContext ctx) {
VulkanThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
}
void VulkanWorkspace::GetAttr(
TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
    *rv = static_cast<int>(index < context_.size());
return;
}
CHECK_LT(index, context_.size())
<< "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
*rv = value;
break;
}
case kWarpSize: {
*rv = 1;
break;
}
case kComputeVersion: {
VkPhysicalDeviceProperties phy_prop;
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
int64_t value = phy_prop.apiVersion;
std::ostringstream os;
os << VK_VERSION_MAJOR(value)
<< "." << VK_VERSION_MINOR(value)
<< "." << VK_VERSION_PATCH(value);
*rv = os.str();
break;
}
case kExist: break;
}
}
void* VulkanWorkspace::AllocDataSpace(
TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) {
this->Init();
VulkanContext& vctx = context_[ctx.device_id];
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = size;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
// create buffer
VkBuffer buffer;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
// bind to memory
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = size;
minfo.memoryTypeIndex = vctx.compute_mtype_index;
VkDeviceMemory memory;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
VulkanBuffer* pbuf = new VulkanBuffer();
pbuf->memory = memory;
pbuf->buffer = buffer;
return pbuf;
}
void VulkanWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
VulkanContext& vctx = context_[ctx.device_id];
VulkanBuffer* pbuf = static_cast<VulkanBuffer*>(ptr);
vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
vkFreeMemory(vctx.device, pbuf->memory, nullptr);
delete pbuf;
}
void VulkanWorkspace::CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) {
this->Init();
CHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
VulkanThreadEntry* tls = VulkanThreadEntry::ThreadLocal();
VulkanCommandBuffer* cmd = tls->CommandPool(ctx.device_id)->Alloc();
VkCommandBufferBeginInfo cb_begin;
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
cb_begin.pNext = nullptr;
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
cb_begin.pInheritanceInfo = 0;
VkSubmitInfo cb_submit;
cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
cb_submit.pNext = nullptr;
cb_submit.waitSemaphoreCount = 0;
cb_submit.pWaitSemaphores = nullptr;
cb_submit.pWaitDstStageMask = 0;
cb_submit.commandBufferCount = 1;
cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
cb_submit.signalSemaphoreCount = 0;
cb_submit.pSignalSemaphores = nullptr;
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
<< "Vulkan disallow cross device copy.";
const VulkanContext& vctx = context_[ctx_from.device_id];
const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
VulkanBuffer* to_buf = static_cast<VulkanBuffer*>(to);
    // The assumption is that subsequent ops only perform compute/transfer
// 0: begin
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 1: copy
VkBufferCopy copy_info;
copy_info.srcOffset = from_offset;
copy_info.dstOffset = to_offset;
copy_info.size = size;
vkCmdCopyBuffer(cmd->cmd_buffer, from_buf->buffer, to_buf->buffer, 1, &copy_info);
// 2: barrier(transfer-> compute|transfer)
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
barrier_info.dstAccessMask =
(VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(
cmd->cmd_buffer,
VK_PIPELINE_STAGE_TRANSFER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
0, 1, &barrier_info, 0, nullptr, 0, nullptr);
// 3: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
} else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
const VulkanContext& vctx = context_[ctx_from.device_id];
const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_from.device_id, size);
// 0: begin
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 1: copy
VkBufferCopy copy_info;
copy_info.srcOffset = from_offset;
copy_info.dstOffset = 0;
copy_info.size = size;
vkCmdCopyBuffer(cmd->cmd_buffer,
from_buf->buffer,
temp->buffer,
1, &copy_info);
// 2: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
// Block until done, to make sure temp can be reused later.
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
    // host side invalidation if access is not coherent,
    // so that writes from the GPU are visible to the CPU.
if (!vctx.coherent_staging) {
VkMappedMemoryRange mrange;
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
mrange.pNext = nullptr;
mrange.memory = temp->memory;
mrange.offset = 0;
mrange.size = size;
VULKAN_CALL(vkInvalidateMappedMemoryRanges(
vctx.device, 1, &mrange));
}
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>(temp->host_addr),
size);
} else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
const VulkanContext& vctx = context_[ctx_to.device_id];
const VulkanBuffer* to_buf = static_cast<const VulkanBuffer*>(to);
VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_to.device_id, size);
memcpy(temp->host_addr,
static_cast<const char*>(from) + from_offset,
size);
    // host side flush if access is not coherent,
    // so that writes from the CPU are visible to the GPU.
if (!vctx.coherent_staging) {
VkMappedMemoryRange mrange;
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
mrange.pNext = nullptr;
mrange.memory = temp->memory;
mrange.offset = 0;
mrange.size = size;
VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
}
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 0: barrier(host->transfer)
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = 0;
barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
vkCmdPipelineBarrier(cmd->cmd_buffer,
VK_PIPELINE_STAGE_HOST_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT,
0, 1, &barrier_info,
0, nullptr, 0, nullptr);
// 1: copy
VkBufferCopy copy_info;
copy_info.srcOffset = 0;
copy_info.dstOffset = to_offset;
copy_info.size = size;
vkCmdCopyBuffer(cmd->cmd_buffer,
temp->buffer,
to_buf->buffer,
1, &copy_info);
// 2: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
// wait until copy finishes, so we can reuse temp next time.
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type
<< ", to=" << to_dev_type;
}
}
void VulkanWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
CHECK(stream == nullptr);
VulkanContext& vctx = context_[ctx.device_id];
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
}
void* VulkanWorkspace::AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) {
return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
void VulkanWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}
// VulkanCommandPool
VulkanCommandPool::VulkanCommandPool(const VulkanContext& vctx) {
ring_.resize(kMaxPending, VulkanCommandBuffer());
device_ = vctx.device;
{
// create command pool
VkCommandPoolCreateInfo cmd_pool_cinfo;
cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
cmd_pool_cinfo.pNext = nullptr;
cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
cmd_pool_cinfo.queueFamilyIndex = vctx.queue_family_index;
VULKAN_CALL(vkCreateCommandPool(device_, &cmd_pool_cinfo, nullptr, &cmd_pool_));
}
{
// create descriptor pool
VkDescriptorPoolSize pool_size;
pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
pool_size.descriptorCount = kMaxPending * kMaxNumArgs;
VkDescriptorPoolCreateInfo descrip_pool_cinfo;
descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descrip_pool_cinfo.pNext = nullptr;
descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
descrip_pool_cinfo.maxSets = kMaxPending + 2;
descrip_pool_cinfo.poolSizeCount = 1;
descrip_pool_cinfo.pPoolSizes = &pool_size;
VULKAN_CALL(vkCreateDescriptorPool(
device_, &descrip_pool_cinfo, nullptr, &descriptor_pool_));
}
VkCommandBufferAllocateInfo buffer_alloc_info;
buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
buffer_alloc_info.pNext = nullptr;
buffer_alloc_info.commandPool = cmd_pool_;
buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
buffer_alloc_info.commandBufferCount = 1;
VkFenceCreateInfo fence_cinfo;
fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
fence_cinfo.pNext = nullptr;
fence_cinfo.flags = VK_FENCE_CREATE_SIGNALED_BIT;
for (size_t i = 0; i < ring_.size(); ++i) {
VULKAN_CALL(vkAllocateCommandBuffers(
device_, &buffer_alloc_info, &(ring_[i].cmd_buffer)));
VULKAN_CALL(vkCreateFence(
device_, &fence_cinfo, nullptr, &(ring_[i].fence)));
}
}
VulkanCommandPool::~VulkanCommandPool() {
  // wait for the device to be idle so we know we can recycle the buffers
VULKAN_CALL(vkDeviceWaitIdle(device_));
// start recycling.
for (size_t i = 0; i < ring_.size(); ++i) {
if (ring_[i].cmd_buffer != nullptr) {
vkFreeCommandBuffers(device_, cmd_pool_, 1, &(ring_[i].cmd_buffer));
ring_[i].cmd_buffer = nullptr;
}
if (ring_[i].fence != nullptr) {
vkDestroyFence(device_, ring_[i].fence, nullptr);
}
}
// delete cmd_pool and descriptor pool
vkDestroyCommandPool(device_, cmd_pool_, nullptr);
vkDestroyDescriptorPool(device_, descriptor_pool_, nullptr);
}
VulkanCommandBuffer* VulkanCommandPool::Alloc() {
return Alloc(nullptr);
}
VulkanCommandBuffer* VulkanCommandPool::Alloc(
const VkDescriptorSetLayout* dlayout) {
// always allocate resource in round robin manner
VulkanCommandBuffer* e = &(ring_[clock_ptr_]);
clock_ptr_ = (clock_ptr_ + 1) % ring_.size();
  // Wait until the previous usage of the command buffer is finished.
uint64_t timeout = 1UL << 30UL;
VkResult res;
res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
while (res == VK_TIMEOUT) {
res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
}
VULKAN_CHECK_ERROR(res);
vkResetFences(device_, 1, (&e->fence));
if (e->descriptor_set != nullptr) {
VULKAN_CALL(vkFreeDescriptorSets(
device_, descriptor_pool_, 1, &(e->descriptor_set)));
e->descriptor_set = nullptr;
}
if (dlayout != nullptr) {
VkDescriptorSetAllocateInfo alloc_info;
alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
alloc_info.pNext = nullptr;
alloc_info.descriptorPool = descriptor_pool_;
alloc_info.descriptorSetCount = 1;
alloc_info.pSetLayouts = dlayout;
VULKAN_CALL(vkAllocateDescriptorSets(
device_, &alloc_info, &(e->descriptor_set)));
}
return e;
}
// VulkanThreadEntry
typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() {
return VulkanThreadStore::Get();
}
VulkanCommandPool* VulkanThreadEntry::CommandPool(int device_id) {
while (pool_.size() <= static_cast<size_t>(device_id)) {
pool_.emplace_back(std::unique_ptr<VulkanCommandPool>());
}
if (pool_[device_id] == nullptr) {
const VulkanContext& vctx =
VulkanWorkspace::Global()->context_[device_id];
pool_[device_id].reset(new VulkanCommandPool(vctx));
}
return pool_[device_id].get();
}
VulkanStagingBuffer*
VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
if (staging_buffer_.size() <= static_cast<size_t>(device_id)) {
staging_buffer_.resize(device_id + 1, VulkanStagingBuffer());
}
VulkanStagingBuffer& buf = staging_buffer_[device_id];
if (buf.device != nullptr && buf.size < size) {
// free previous buffer
if (buf.host_addr != nullptr) {
vkUnmapMemory(buf.device, buf.memory);
}
if (buf.memory != nullptr) {
vkFreeMemory(buf.device, buf.memory, nullptr);
}
if (buf.buffer != nullptr) {
vkDestroyBuffer(buf.device, buf.buffer, nullptr);
}
buf.host_addr = nullptr;
buf.memory = nullptr;
buf.buffer = nullptr;
}
const VulkanContext& vctx =
VulkanWorkspace::Global()->context_[device_id];
if (buf.device == nullptr) {
buf.device = vctx.device;
}
if (buf.memory == nullptr) {
    // allocate the staging buffer memory if necessary
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = size;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx.queue_family_index);
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT;
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
VkMemoryAllocateInfo minfo;
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
minfo.pNext = nullptr;
minfo.allocationSize = size;
minfo.memoryTypeIndex = vctx.staging_mtype_index;
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
buf.size = size;
}
memset(buf.host_addr, 0, size);
return &buf;
}
VulkanThreadEntry::~VulkanThreadEntry() {
  // Because the thread entry refers to the device API,
  // the command buffers will always be destroyed before
  // the instance and device get destroyed.
  // The destruction needs to be invoked manually
  // to ensure the destruction order.
pool_.clear();
for (VulkanStagingBuffer buf : staging_buffer_) {
if (buf.host_addr != nullptr) {
vkUnmapMemory(buf.device, buf.memory);
}
if (buf.memory != nullptr) {
vkFreeMemory(buf.device, buf.memory, nullptr);
}
if (buf.buffer != nullptr) {
vkDestroyBuffer(buf.device, buf.buffer, nullptr);
}
}
}
VkInstance CreateInstance() {
VkApplicationInfo app_info;
app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
app_info.pNext = nullptr;
app_info.pApplicationName = "TVM";
app_info.applicationVersion = 0;
app_info.pEngineName = "";
app_info.engineVersion = 0;
app_info.apiVersion = VK_MAKE_VERSION(1, 0, 65);
VkInstanceCreateInfo inst_info;
inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
inst_info.pNext = nullptr;
inst_info.flags = 0;
inst_info.pApplicationInfo = &app_info;
inst_info.enabledLayerCount = 0;
inst_info.ppEnabledLayerNames = nullptr;
inst_info.enabledExtensionCount = 0;
inst_info.ppEnabledExtensionNames = nullptr;
VkInstance inst;
VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &inst));
return inst;
}
// find suitable mem_type_index for staging and compute
void FindMemoryTypeIndex(VulkanContext* vctx) {
// Find suitable compute index.
VkBuffer buffer;
VkMemoryRequirements req_staging, req_compute;
VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.size = 1024;
info.queueFamilyIndexCount = 1;
info.pQueueFamilyIndices = &(vctx->queue_family_index);
// get staging requirement
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT;
VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(vctx->device, buffer, &req_staging);
vkDestroyBuffer(vctx->device, buffer, nullptr);
// get compute requirement
info.usage =
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
vkGetBufferMemoryRequirements(vctx->device, buffer, &req_compute);
vkDestroyBuffer(vctx->device, buffer, nullptr);
  // Query physical device memory properties and
  // find a memory type that is host visible; it does not need to be coherent.
int win_rank = -1;
VkPhysicalDeviceMemoryProperties prop;
vkGetPhysicalDeviceMemoryProperties(vctx->phy_device, &prop);
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
// host visible
if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
    // match staging copy requirement
if (!(req_staging.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
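    // prefer a host-cached type so CPU reads from the staging buffer are fast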
rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
if (rank > win_rank) {
win_rank = rank;
vctx->staging_mtype_index = k;
vctx->coherent_staging =
ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
}
}
CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
win_rank = -1;
for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
VkMemoryType ty = prop.memoryTypes[k];
size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
    // device local
    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
    // match compute buffer requirement
    if (!(req_compute.memoryTypeBits & (1 << k))) continue;
if (heap_size < 1024) continue;
int rank = 0;
// prefer not host visible
rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
if (rank > win_rank) {
win_rank = rank;
vctx->compute_mtype_index = k;
}
}
CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
}
// Get all logical devices that support compute
std::vector<VulkanContext> GetContext(VkInstance instance) {
std::vector<VulkanContext> result;
uint32_t phy_dev_count = 0;
VULKAN_CALL(vkEnumeratePhysicalDevices(
instance, &phy_dev_count, nullptr));
std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
VULKAN_CALL(vkEnumeratePhysicalDevices(
instance, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
for (VkPhysicalDevice phy_dev : all_phy_devs) {
uint32_t queue_prop_count = 0;
vkGetPhysicalDeviceQueueFamilyProperties(
phy_dev, &queue_prop_count, nullptr);
std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
vkGetPhysicalDeviceQueueFamilyProperties(
phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props));
uint32_t queue_family_index = 0;
std::vector<VkDeviceQueueCreateInfo> queue_create_info;
for (uint32_t i = 0; i < queue_props.size(); i++) {
// find queues that support compute
if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
float priority = 1.0f;
VkDeviceQueueCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
info.pNext = nullptr;
info.flags = 0;
info.queueFamilyIndex = i;
info.queueCount = 1;
info.pQueuePriorities = &priority;
queue_create_info.push_back(info);
        // only use the first available compute queue family for now
        if (queue_create_info.size() == 1) {
          queue_family_index = i;
        }
}
}
if (queue_create_info.size() == 0) continue;
VkDeviceCreateInfo device_create_info;
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.pNext = nullptr;
device_create_info.flags = 0;
device_create_info.queueCreateInfoCount
= static_cast<uint32_t>(queue_create_info.size());
device_create_info.pQueueCreateInfos = queue_create_info.data();
device_create_info.enabledLayerCount = 0;
device_create_info.ppEnabledLayerNames = nullptr;
device_create_info.enabledExtensionCount = 0;
device_create_info.ppEnabledExtensionNames = nullptr;
device_create_info.pEnabledFeatures = nullptr;
VulkanContext ctx;
// setup context
ctx.phy_device = phy_dev;
vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
VULKAN_CALL(vkCreateDevice(
phy_dev, &device_create_info, nullptr, &(ctx.device)));
vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
ctx.queue_family_index = queue_family_index;
    // Find suitable memory types for staging and compute
    FindMemoryTypeIndex(&ctx);
result.push_back(ctx);
}
return result;
}
void VulkanWorkspace::Init() {
if (initialized_) return;
  std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return;
initialized_ = true;
instance_ = CreateInstance();
context_ = GetContext(instance_);
LOG(INFO) << "Initialzie Vulkan with " << context_.size() << " devices..";
for (size_t i = 0; i < context_.size(); ++i) {
LOG(INFO) << "vulkan(" << i
<< ")=\'" << context_[i].phy_device_prop.deviceName
<< "\' phy_dev_id=" << context_[i].phy_device;
}
}
bool InitVulkan(TVMArgs args, TVMRetValue* rv) {
vulkan::VulkanWorkspace::Global()->Init();
return true;
}
TVM_REGISTER_GLOBAL("device_api.vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VulkanWorkspace::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace vulkan
} // namespace runtime
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
/*!
* Copyright (c) 2018 by Contributors
* \file vulkan_module.cc
*/
#include "./vulkan_module.h"
#if TVM_VULKAN_RUNTIME
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <array>
#include <string>
#include <mutex>
#include "./vulkan_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
namespace tvm {
namespace runtime {
void VulkanShader::Save(dmlc::Stream* writer) const {
writer->Write(flag);
writer->Write(data);
}
bool VulkanShader::Load(dmlc::Stream* reader) {
if (!reader->Read(&flag)) return false;
if (!reader->Read(&data)) return false;
return true;
}
// Multi-device enabled module.
class VulkanModuleNode final : public runtime::ModuleNode {
public:
// Pipeline cache states
struct PipelineEntry {
VkShaderModule shader{nullptr};
VkPipelineLayout pipeline_layout{nullptr};
VkDescriptorSetLayout descriptor_layout{nullptr};
VkPipeline pipeline{nullptr};
};
// constructor
explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source)
: smap_(smap), fmap_(fmap), source_(source) {
}
~VulkanModuleNode() {
// cleanup vulkan related caches.
for (DeviceEntry& e : finfo_) {
if (e.device == nullptr) continue;
for (auto &kv : e.smap) {
PipelineEntry& pe = kv.second;
vkDestroyShaderModule(e.device, pe.shader, nullptr);
vkDestroyDescriptorSetLayout(e.device, pe.descriptor_layout, nullptr);
vkDestroyPipelineLayout(e.device, pe.pipeline_layout, nullptr);
vkDestroyPipeline(e.device, pe.pipeline, nullptr);
}
}
}
const char* type_key() const final {
return "vulkan";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to customized format vulkan";
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
std::string data_bin;
dmlc::MemoryStringStream fs(&data_bin);
dmlc::Stream* stream = &fs;
uint32_t magic = kVulkanModuleMagic;
stream->Write(magic);
stream->Write(smap_);
SaveBinaryToFile(file_name, data_bin);
}
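  // The binary layout written below (fmt_, fmap_, smap_) must stay in sync
  // with the read order in VulkanModuleLoadBinary later in this file.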
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(smap_);
}
std::string GetSource(const std::string& format) final {
// can only return source code.
return source_;
}
  // get a PipelineEntry from the primary context of device_id
PipelineEntry GetPipeline(size_t device_id,
const std::string& func_name,
size_t num_pack_args) {
vulkan::VulkanWorkspace* w = vulkan::VulkanWorkspace::Global().get();
CHECK_LT(device_id, w->context_.size());
// start lock scope.
std::lock_guard<std::mutex> lock(mutex_);
if (finfo_.size() <= device_id) {
finfo_.resize(device_id + 1, DeviceEntry());
}
DeviceEntry& e = finfo_[device_id];
auto it = e.smap.find(func_name);
if (it != e.smap.end()) return it->second;
PipelineEntry pe;
if (e.device == nullptr) {
e.device = w->context_[device_id].device;
}
{
// create shader
auto sit = smap_.find(func_name);
CHECK(sit != smap_.end());
const std::vector<uint32_t>& data = sit->second.data;
VkShaderModuleCreateInfo shader_cinfo;
shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
shader_cinfo.pNext = nullptr;
shader_cinfo.flags = 0;
shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
shader_cinfo.pCode = data.data();
VULKAN_CALL(vkCreateShaderModule(
e.device, &shader_cinfo, nullptr, &(pe.shader)));
}
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
uint32_t num_pod = 0, num_buffer = 0;
{
auto fit = fmap_.find(func_name);
CHECK(fit != fmap_.end());
for (TVMType arg_type : fit->second.arg_types) {
if (arg_type.code == kHandle) {
VkDescriptorSetLayoutBinding bd;
bd.binding = num_buffer;
bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
bd.descriptorCount = 1;
bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
bd.pImmutableSamplers = nullptr;
arg_binding.push_back(bd);
++num_buffer;
} else {
++num_pod;
}
}
}
VkDescriptorSetLayoutCreateInfo descrip_cinfo;
descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
descrip_cinfo.pNext = nullptr;
descrip_cinfo.flags = 0;
descrip_cinfo.bindingCount = arg_binding.size();
descrip_cinfo.pBindings = arg_binding.data();
VULKAN_CALL(vkCreateDescriptorSetLayout(
e.device, &descrip_cinfo, nullptr, &(pe.descriptor_layout)));
VkPushConstantRange crange;
crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
crange.offset = 0;
crange.size = sizeof(ArgUnion) * num_pack_args;
VkPipelineLayoutCreateInfo playout_cinfo;
playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
playout_cinfo.pNext = nullptr;
playout_cinfo.flags = 0;
playout_cinfo.setLayoutCount = 1;
playout_cinfo.pSetLayouts = &(pe.descriptor_layout);
if (num_pack_args != 0) {
playout_cinfo.pushConstantRangeCount = 1;
playout_cinfo.pPushConstantRanges = &crange;
CHECK_LE(crange.size,
w->context_[device_id].phy_device_prop.limits.maxPushConstantsSize);
} else {
playout_cinfo.pushConstantRangeCount = 0;
playout_cinfo.pPushConstantRanges = nullptr;
}
VULKAN_CALL(vkCreatePipelineLayout(
e.device, &playout_cinfo, nullptr, &(pe.pipeline_layout)));
VkComputePipelineCreateInfo pipeline_cinfo;
pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
pipeline_cinfo.pNext = nullptr;
pipeline_cinfo.flags = 0;
pipeline_cinfo.stage.sType =
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
pipeline_cinfo.stage.pNext = nullptr;
pipeline_cinfo.stage.flags = 0;
pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
pipeline_cinfo.stage.module = pe.shader;
pipeline_cinfo.stage.pName = func_name.c_str();
pipeline_cinfo.stage.pSpecializationInfo = nullptr;
pipeline_cinfo.layout = pe.pipeline_layout;
pipeline_cinfo.basePipelineHandle = nullptr;
pipeline_cinfo.basePipelineIndex = 0;
VULKAN_CALL(vkCreateComputePipelines(
e.device, nullptr, 1, &pipeline_cinfo, nullptr, &(pe.pipeline)));
e.smap[func_name] = pe;
return pe;
}
private:
// device specific entry
struct DeviceEntry {
VkDevice device{nullptr};
std::unordered_map<std::string, PipelineEntry> smap;
};
// the binary data
std::vector<uint32_t> data_;
  // shader map: function name -> SPIR-V shader binary.
std::unordered_map<std::string, VulkanShader> smap_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The format
std::string fmt_{"vulkan"};
// The source
std::string source_;
// device local pipeline information.
std::vector<DeviceEntry> finfo_;
// internal mutex when updating the module
std::mutex mutex_;
};
// a wrapped function class to get packed func.
class VulkanWrappedFunc {
public:
// initialize the VULKAN function.
void Init(VulkanModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
const std::string& func_name,
size_t num_buffer_args,
size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
w_ = vulkan::VulkanWorkspace::Global().get();
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args,
TVMRetValue* rv,
const ArgUnion* pack_args) const {
vulkan::VulkanThreadEntry* tls = vulkan::VulkanThreadEntry::ThreadLocal();
int device_id = tls->context.device_id;
CHECK_LT(device_id, kVulkanMaxNumDevice);
const vulkan::VulkanContext& vctx = w_->context_[device_id];
VulkanModuleNode::PipelineEntry& pe = scache_[device_id];
if (pe.pipeline == nullptr) {
pe = m_->GetPipeline(device_id, func_name_, num_pack_args_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
vulkan::VulkanCommandBuffer* cmd = tls->CommandPool(device_id)->Alloc(
&(pe.descriptor_layout));
cmd->write_descriptor_set.dstSet = cmd->descriptor_set;
// setup descriptors
for (uint32_t i = 0; i < num_buffer_args_; ++i) {
void* buf = args[static_cast<int>(i)];
VkDescriptorBufferInfo binfo;
binfo.buffer = static_cast<vulkan::VulkanBuffer*>(buf)->buffer;
binfo.offset = 0;
binfo.range = VK_WHOLE_SIZE;
cmd->write_descriptor_set.dstBinding = i;
cmd->write_descriptor_set.pBufferInfo = &binfo;
vkUpdateDescriptorSets(
vctx.device, 1, &(cmd->write_descriptor_set), 0, nullptr);
}
// dispatch
VkCommandBufferBeginInfo cb_begin;
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
cb_begin.pNext = nullptr;
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
cb_begin.pInheritanceInfo = 0;
VkSubmitInfo cb_submit;
cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
cb_submit.pNext = nullptr;
cb_submit.waitSemaphoreCount = 0;
cb_submit.pWaitSemaphores = nullptr;
cb_submit.pWaitDstStageMask = 0;
cb_submit.commandBufferCount = 1;
cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
cb_submit.signalSemaphoreCount = 0;
cb_submit.pSignalSemaphores = nullptr;
// 0: begin
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
// 1: dispatch
vkCmdBindPipeline(
cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pe.pipeline);
vkCmdBindDescriptorSets(
cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
pe.pipeline_layout, 0, 1, &(cmd->descriptor_set), 0, nullptr);
// bind push constant if necessary
if (num_pack_args_ != 0) {
vkCmdPushConstants(
cmd->cmd_buffer,
pe.pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT,
0, num_pack_args_ * sizeof(ArgUnion),
pack_args);
}
vkCmdDispatch(
cmd->cmd_buffer, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
// 2: barrier(compute->compute|transfer)
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask =
VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
barrier_info.dstAccessMask =
(VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(
cmd->cmd_buffer,
VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
0, 1, &barrier_info, 0, nullptr, 0, nullptr);
// 3: end
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
// 4: submit with cmd->fence
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
}
private:
// Reference to global workspace.
vulkan::VulkanWorkspace* w_;
// internal module
VulkanModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
// The name of the function.
std::string func_name_;
// Number of buffer arguments
size_t num_buffer_args_;
// number of packed arguments.
size_t num_pack_args_;
// Device state cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<VulkanModuleNode::PipelineEntry, kVulkanMaxNumDevice> scache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
};
PackedFunc VulkanModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
VulkanWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name,
num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
return PackFuncNonBufferArg(f, info.arg_types);
}
Module VulkanModuleCreate(
std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
vulkan::VulkanWorkspace::Global()->Init();
std::shared_ptr<VulkanModuleNode> n =
std::make_shared<VulkanModuleNode>(smap, fmap, source);
return Module(n);
}
// Load module from a file.
Module VulkanModuleLoadFile(const std::string& file_name,
const std::string& format) {
std::string data;
std::unordered_map<std::string, VulkanShader> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
dmlc::MemoryStringStream fs(&data);
dmlc::Stream* stream = &fs;
uint32_t magic;
stream->Read(&magic);
CHECK_EQ(magic, kVulkanModuleMagic)
<< "VulkanModule Magic mismatch";
stream->Read(&smap);
return VulkanModuleCreate(smap, fmap, "");
}
Module VulkanModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::unordered_map<std::string, VulkanShader> smap;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt;
stream->Read(&fmt);
stream->Read(&fmap);
stream->Read(&smap);
return VulkanModuleCreate(smap, fmap, "");
}
TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = VulkanModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = VulkanModuleLoadBinary(args[0]);
});
} // namespace runtime
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
/*!
* Copyright (c) 2017 by Contributors
* \file vulkan_module.h
* \brief Execution handling of Vulkan kernels
*/
#ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/type_traits.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "../meta_data.h"
namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPUs supported by VulkanModule. */
static constexpr const int kVulkanMaxNumDevice = 8;
/*! \brief TVM Vulkan binary pack magic number */
static constexpr const int kVulkanModuleMagic = 0x02700027;
/*!
* \brief A single VK shader program
*
* Due to its global resource declarations, current SPIR-V only allows
* one entry function per shader program, which makes a single shader
* less useful for a Module-like system.
*
* Instead we pass in a map of str->VulkanShader until
* a native solution is available.
*/
struct VulkanShader {
/*! \brief header flag */
uint32_t flag{0};
/*! \brief Data segment */
std::vector<uint32_t> data;
void Save(dmlc::Stream *writer) const;
bool Load(dmlc::Stream *reader);
};
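// An illustrative sketch of the Save/Load pair (assumption: they serialize
// the header flag followed by the SPIR-V words through dmlc::Stream; the
// real definitions are not shown in this header):
//
//   void VulkanShader::Save(dmlc::Stream* writer) const {
//     writer->Write(flag);
//     writer->Write(data);
//   }
//   bool VulkanShader::Load(dmlc::Stream* reader) {
//     if (!reader->Read(&flag)) return false;
//     return reader->Read(&data);
//   }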
/*!
* \brief Create a Vulkan module from data.
*
* \param smap The compiled shader map (function name -> VulkanShader).
* \param fmap The function information map.
* \param source Optional, source code.
*/
Module VulkanModuleCreate(
std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source);
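// Example usage (an illustrative sketch only; `words` and `finfo` are
// hypothetical placeholders, not part of TVM's API):
//
//   std::unordered_map<std::string, VulkanShader> smap;
//   std::unordered_map<std::string, FunctionInfo> fmap;
//   VulkanShader shader;
//   shader.data = words;             // SPIR-V words emitted by the SPIRV codegen
//   smap["myadd_kernel0"] = shader;
//   fmap["myadd_kernel0"] = finfo;   // arg types and thread axis tags
//   Module m = VulkanModuleCreate(smap, fmap, /*source=*/"");
//   PackedFunc f = m.GetFunction("myadd_kernel0");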
} // namespace runtime
} // namespace tvm
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::VulkanShader, true);
} // namespace dmlc
#endif // TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
......@@ -18,7 +18,8 @@ def test_exp():
def check_device(device, host="stackvm"):
if not tvm.module.enabled(host):
return
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
fexp = tvm.build(s, [A, B],
device, host,
......@@ -33,6 +34,7 @@ def test_exp():
b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
check_device("cuda", "llvm")
check_device("vulkan")
check_device("opencl")
......@@ -75,11 +77,12 @@ def test_popcount():
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
if str(ctx).startswith('gpu'):
target = tvm.target.create(device)
if "cpu" not in target.keys:
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
func = tvm.build(s, [A, B], device)
......@@ -95,6 +98,8 @@ def test_popcount():
check_device("cuda")
check_device("opencl")
check_device("metal")
if dtype == "uint32":
check_device("vulkan")
run('uint32')
run('uint64')
......@@ -121,14 +126,14 @@ def test_add():
# one line to build the function.
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
fadd = tvm.build(s, [A, B, C],
device,
name="myadd")
print(fadd.imported_modules[0].get_source())
ctx = tvm.context(device, 0)
# launch the kernel.
n = 1024
a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx)
......@@ -142,6 +147,8 @@ def test_add():
check_device("opencl")
check_device("metal")
check_device("cuda")
check_device("vulkan")
run("float32")
run("int32")
run("int64")
......@@ -149,7 +156,7 @@ def test_add():
if __name__ == "__main__":
test_add()
test_log_pow_llvm()
test_exp()
test_add()
test_popcount()
......@@ -2,6 +2,7 @@ import tvm
import numpy as np
import time
def test_gemm():
# graph
nn = 1024
......@@ -64,13 +65,14 @@ def test_gemm():
# one line to build the function.
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
with tvm.target.create(device):
f = tvm.build(s, [A, B, C])
ctx = tvm.context(device, 0)
# launch the kernel.
n = nn
m = n
......@@ -86,12 +88,12 @@ def test_gemm():
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("vulkan")
check_device("nvptx -mcpu=sm_20")
check_device("rocm")
check_device("metal")
check_device("opencl")
check_device("cuda")
#check_device("nvptx -mcpu=sm_20")
if __name__ == "__main__":
test_gemm()
import tvm
import numpy as np
def test_reduce_prims():
def test_prim(reducer, np_reducer):
# graph
......@@ -21,12 +22,12 @@ def test_reduce_prims():
# one line to build the function.
def check_device(device, host="stackvm"):
ctx = tvm.context(device, 0)
if not tvm.module.enabled(host):
return
if not tvm.module.enabled(device):
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
freduce = tvm.build(s,
args=[A, B],
target=device, target_host=host,
......@@ -44,6 +45,7 @@ def test_reduce_prims():
np.testing.assert_allclose(npy, res, rtol=1e-4)
check_device("metal")
check_device("vulkan")
check_device("cuda")
check_device("opencl")
test_prim(tvm.sum, np.sum)
......@@ -106,10 +108,11 @@ def test_rfactor_threads():
# one line to build the function.
def check_target(device, host="stackvm"):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi,
target=device,
......@@ -125,6 +128,7 @@ def test_rfactor_threads():
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
check_target("vulkan")
check_target("cuda")
check_target("metal")
check_target("opencl")
......@@ -159,15 +163,14 @@ def test_rfactor_elemwise_threads():
# one line to build the function.
def check_target(device, host="stackvm"):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A, C])
fsum = tvm.build(fapi,
target=device,
name="mysum")
print(fsum.imported_modules[0].get_source())
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
......@@ -176,6 +179,7 @@ def test_rfactor_elemwise_threads():
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
check_target("vulkan")
check_target("cuda")
check_target("metal")
check_target("opencl")
......@@ -264,10 +268,10 @@ def test_rfactor_argmax():
s[B0].set_store_predicate(thread_x.var.equal(0))
def check_target(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A0, A1, B0, B1])
fargmax = tvm.build(fapi,
target=device,
......@@ -285,6 +289,7 @@ def test_rfactor_argmax():
np.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target("cuda")
check_target("vulkan")
if __name__ == "__main__":
test_rfactor_elemwise_threads()
......
......@@ -24,13 +24,13 @@ def test_scan():
# one line to build the function.
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
fscan = tvm.build(s, [X, res],
device,
name="myscan")
ctx = tvm.context(device, 0)
# launch the kernel.
n = 1024
m = 10
......@@ -41,6 +41,7 @@ def test_scan():
np.testing.assert_allclose(
b.asnumpy(), np.cumsum(a_np, axis=0))
check_device("vulkan")
check_device("cuda")
check_device("metal")
check_device("opencl")
......
......@@ -13,12 +13,12 @@ def test_add_pipeline():
# GPU schedule have to split by gridIdx and threadIdx
num_thread = 256
xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
s[C].bind(xi, tvm.thread_axis("threadIdx.x"))
s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
s[D].bind(xo, tvm.thread_axis("threadIdx.x"))
s[D].bind(xi, tvm.thread_axis("blockIdx.x"))
s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
s[D].bind(xo, tvm.thread_axis("blockIdx.x"))
# compile to IR
s = s.normalize()
......@@ -35,11 +35,11 @@ def test_add_pipeline():
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
def check_target(device, host="stackvm"):
if not tvm.module.enabled(host):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
if not tvm.module.enabled(device):
if not tvm.module.enabled(host):
return
ctx = tvm.context(device, 0)
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)
mhost.import_module(mdev)
......@@ -55,12 +55,12 @@ def test_add_pipeline():
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
def check_module_save(device, host="stackvm"):
if not tvm.module.enabled(host):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
if not tvm.module.enabled(device):
if not tvm.module.enabled(host):
return
ctx = tvm.context(device, 0)
fmt = "ptx" if device == "cuda" else "cl"
fmt = "ptx" if device == "cuda" else device
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)
temp = util.tempdir()
......@@ -82,7 +82,9 @@ def test_add_pipeline():
check_target("cuda", host="llvm")
check_module_save("cuda", host="stackvm")
check_target("nvptx", host="llvm")
check_target("vulkan", host="llvm")
check_target("rocm", host="llvm")
check_module_save("vulkan", host="stackvm")
if __name__ == "__main__":
......
......@@ -110,6 +110,7 @@ def test_device_module_dump():
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_device("cuda")
check_device("vulkan")
check_device("opencl")
check_device("metal")
......
......@@ -7,6 +7,7 @@ def enabled_ctx_list():
('cl', tvm.opencl(0)),
('metal', tvm.metal(0)),
('rocm', tvm.rocm(0)),
('vulkan', tvm.vulkan(0)),
('vpi', tvm.vpi(0))]
for k, v in ctx_list:
assert tvm.context(k, 0) == v
......
......@@ -2,6 +2,7 @@
import tvm
import os
from tvm.contrib import nvcc
from tvm.contrib import spirv
import numpy as np
TASK="gemm"
......@@ -25,6 +26,7 @@ def tvm_callback_cuda_postproc(code):
code = open("perf/%s_manual.cu" % TASK).read()
return code
def test_gemm():
# graph
nn = 2048
......@@ -101,12 +103,12 @@ def test_gemm():
s[BB].double_buffer()
# correctness
def check_device(device):
print("Device %s" % device)
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Device %s" % device)
f = tvm.build(s, [A, B, C], device)
ctx = tvm.context(device, 0)
# launch the kernel.
n, m, l = nn, nn, nn
a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
......@@ -126,7 +128,7 @@ def test_gemm():
GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
for device in ["cuda", "opencl", "rocm", "nvptx"]:
for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
with tvm.build_config(auto_unroll_max_step=128,
unroll_explicit=(device != "cuda")):
check_device(device)
......
......@@ -9,13 +9,13 @@ def verify_broadcast_to_ele(in_shape, out_shape):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.broadcast_to(A, out_shape)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
......@@ -25,6 +25,7 @@ def verify_broadcast_to_ele(in_shape, out_shape):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("vulkan")
check_device("opencl")
check_device("cuda")
check_device("metal")
......@@ -50,13 +51,13 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
else:
raise NotImplementedError
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(C)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
......@@ -82,6 +83,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("vulkan")
check_device("opencl")
check_device("cuda")
check_device("metal")
......@@ -105,5 +107,5 @@ def test_broadcast_binary():
if __name__ == "__main__":
test_broadcast_to()
test_broadcast_binary()
test_broadcast_to()
......@@ -31,11 +31,11 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......@@ -49,7 +49,7 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
......
......@@ -29,14 +29,14 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s1 = topi.generic.schedule_conv2d_nchw([B])
s2 = topi.generic.schedule_conv2d_nchw([C])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......@@ -50,7 +50,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
......
......@@ -29,14 +29,14 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
......@@ -50,7 +50,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
......
......@@ -29,13 +29,13 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_dense(D)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
......@@ -44,7 +44,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
f(a, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
def test_dense():
......
......@@ -23,7 +23,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -32,7 +33,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
ctx = tvm.context(device, 0)
# build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
......@@ -90,6 +90,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("vulkan")
def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
in_width = in_height
......@@ -108,7 +110,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
# schedule
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
......@@ -117,7 +120,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
ctx = tvm.context(device, 0)
# build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
......@@ -177,6 +179,7 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("vulkan")
def test_depthwise_conv2d():
print("testing nchw")
......
......@@ -32,11 +32,11 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
schedule = schedule_depthwise_conv2d_backward_input_nhwc(In_grad)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
ctx = tvm.context(device, 0)
# build the kernel
f = tvm.build(schedule, [Filter, Out_grad, In_grad], device)
# prepare pod type for test data closure
......@@ -85,6 +85,7 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("vulkan")
def test_topi_depthwise_conv2d_backward_input_nhwc():
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)
......
......@@ -32,11 +32,11 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
schedule = schedule_depthwise_conv2d_backward_weight_nhwc(Weight_grad)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
ctx = tvm.context(device, 0)
# build the kernel
f = tvm.build(schedule, [Input, Out_grad, Weight_grad], device)
# prepare pod type for test data closure
......@@ -78,6 +78,7 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
check_device("cuda")
check_device("metal")
check_device("rocm")
check_device("vulkan")
def test_topi_depthwise_conv2d_backward_weight_nhwc():
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
......
......@@ -44,20 +44,21 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_pool(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
def test_pool():
......@@ -82,20 +83,20 @@ def verify_global_pool(n, c, h, w, pool_type):
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_global_pool(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
def test_global_pool():
......
......@@ -47,13 +47,14 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
raise NotImplementedError
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_reduce(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name=type)
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
......@@ -90,7 +91,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
else:
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
for device in ["cuda", "opencl", "metal", "llvm", "rocm", "vulkan"]:
check_device(device)
......
......@@ -13,20 +13,21 @@ def verify_relu(m, n):
b_np = a_np * (a_np > 0)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_elemwise(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="relu")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
......
......@@ -17,20 +17,21 @@ def verify_softmax(m, n):
b_np = topi.testing.softmax_python(a_np)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
def test_softmax():
......@@ -48,20 +49,20 @@ def verify_log_softmax(m, n):
b_np = topi.testing.log_softmax_python(a_np)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="log_softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ["cuda", "opencl", "metal", "rocm"]:
for device in ["cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
......
......@@ -7,13 +7,13 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.expand_dims(A, axis, num_newaxis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="expand_dims")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = data_npy.reshape(out_shape)
......@@ -22,7 +22,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
......@@ -30,13 +30,13 @@ def verify_tranpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="tranpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
out_npy = data_npy.transpose(axes)
......@@ -45,7 +45,7 @@ def verify_tranpose(in_shape, axes):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
......@@ -53,13 +53,13 @@ def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.reshape(A, dst_shape)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="reshape")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.reshape(data_npy, newshape=dst_shape)
......@@ -68,7 +68,7 @@ def verify_reshape(src_shape, dst_shape):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
......@@ -76,13 +76,14 @@ def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.squeeze(A, axis=axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)
......@@ -95,7 +96,7 @@ def verify_squeeze(src_shape, axis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
def verify_concatenate(shapes, axis):
......@@ -104,13 +105,14 @@ def verify_concatenate(shapes, axis):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)
ctx = tvm.context(device, 0)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
out_npy = np.concatenate(data_npys, axis=axis)
......@@ -119,7 +121,7 @@ def verify_concatenate(shapes, axis):
foo(*(data_nds + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
......@@ -127,13 +129,14 @@ def verify_split(src_shape, indices_or_sections, axis):
A = tvm.placeholder(shape=src_shape, name="A")
tensor_l = topi.split(A, indices_or_sections, axis=axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(tensor_l)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A] + tensor_l, device, name="split")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npys = np.split(data_npy, indices_or_sections, axis=axis)
......@@ -143,7 +146,7 @@ def verify_split(src_shape, indices_or_sections, axis):
for out_nd, out_npy in zip(out_nds, out_npys):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
......
......@@ -14,13 +14,13 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):
b_np = topi.testing.upsampling_python(a_np, scale)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
......@@ -28,7 +28,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda']:
for device in ['llvm', 'cuda', 'vulkan']:
check_device(device)
def test_upsampling():
......