Commit 79d503fd by Tianqi Chen Committed by GitHub

[BACKEND] Vulkan Runtime and SPIRV Codegen (#861)

* [BACKEND] Vulkan Runtime and SPIRV Codegen

* fix doc
parent 108e9f3f
-cmake_minimum_required(VERSION 3.5)
cmake_minimum_required(VERSION 3.7)
project(tvm C CXX)
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
@@ -22,6 +22,7 @@ endif()
tvm_option(USE_CUDA "Build with CUDA" OFF)
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_VULKAN "Build with Vulkan" OFF)
tvm_option(USE_OPENGL "Build with OpenGL" OFF)
tvm_option(USE_METAL "Build with Metal" OFF)
tvm_option(USE_RPC "Build with RPC" ON)
@@ -88,9 +89,11 @@ file(GLOB_RECURSE HALIDEIR_SRCS HalideIR/src/*.cpp)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS src/runtime/*.cc)
file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc)
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
file(GLOB RUNTIME_OPENGL_SRCS src/runtime/opengl/*.cc)
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
file(GLOB RUNTIME_METAL_SRCS src/runtime/metal/*.mm)
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
file(GLOB RUNTIME_GRAPH_SRCS src/runtime/graph/*.cc)
@@ -151,6 +154,22 @@ else(USE_OPENGL)
add_definitions(-DTVM_OPENGL_RUNTIME=0)
endif(USE_OPENGL)
if(USE_VULKAN)
find_package(Vulkan REQUIRED)
message(STATUS "Build with VULKAN support")
include_directories(${Vulkan_INCLUDE_DIRS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARIES})
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
get_filename_component(VULKAN_LIB_PATH ${Vulkan_LIBRARY} DIRECTORY)
find_library(SPIRV_TOOLS_LIB SPIRV-Tools
${VULKAN_LIB_PATH}/spirv-tools)
list(APPEND TVM_LINKER_LIBS ${SPIRV_TOOLS_LIB})
add_definitions(-DTVM_VULKAN_RUNTIME=1)
else(USE_VULKAN)
add_definitions(-DTVM_VULKAN_RUNTIME=0)
endif(USE_VULKAN)
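The find_package(Vulkan) call above relies on the FindVulkan module that ships with CMake 3.7, which is why the minimum version is raised in this commit. A hypothetical configure invocation with the new option (paths illustrative):

# sketch: enable the Vulkan backend when generating build files
cd build && cmake -DUSE_VULKAN=ON ..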
if(USE_METAL)
find_package(OpenCL QUIET REQUIRED)
message(STATUS "Build with Metal support")
@@ -174,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
endif(USE_GRAPH_RUNTIME)
if(USE_LLVM)
find_package(LLVM CONFIG REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
......
-Subproject commit 87b089a0ba20f2e8257038ee9211d6816088ce95
Subproject commit aadbf02d6bd7a545edbf6652494a7b07a97a06c1
@@ -56,6 +56,7 @@ CUDA_SRC = $(wildcard src/runtime/cuda/*.cc)
ROCM_SRC = $(wildcard src/runtime/rocm/*.cc)
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
OPENGL_SRC = $(wildcard src/runtime/opengl/*.cc)
VULKAN_SRC = $(wildcard src/runtime/vulkan/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
GRAPH_SRC = $(wildcard src/runtime/graph/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
@@ -69,6 +70,7 @@ CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC))
ROCM_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCM_SRC))
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
OPENGL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENGL_SRC))
VULKAN_OBJ = $(patsubst src/%.cc, build/%.o, $(VULKAN_SRC))
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ)
@@ -129,6 +131,20 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif
ifdef VULKAN_SDK
CFLAGS += -I$(VULKAN_SDK)/include
LDFLAGS += -L$(VULKAN_SDK)/lib
LDFLAGS += -L$(VULKAN_SDK)/lib/spirv-tools
endif
ifeq ($(USE_VULKAN), 1)
CFLAGS += -DTVM_VULKAN_RUNTIME=1
LDFLAGS += -lvulkan -lSPIRV-Tools
RUNTIME_DEP += $(VULKAN_OBJ)
else
CFLAGS += -DTVM_VULKAN_RUNTIME=0
endif
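A sketch of the corresponding Makefile-based build, assuming a LunarG Vulkan SDK install (path illustrative):

# sketch: point at the SDK, then build with the Vulkan runtime enabled
export VULKAN_SDK=/opt/vulkan-sdk/x86_64
make USE_VULKAN=1 -j8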
ifeq ($(USE_OPENGL), 1)
CFLAGS += -DTVM_OPENGL_RUNTIME=1
EMCC_FLAGS += -DTVM_OPENGL_RUNTIME=1
......
@@ -422,6 +422,18 @@ LoweredFunc LowerTVMBuiltin(LoweredFunc f);
LoweredFunc CombineContextCall(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemented in storage_rewrite.cc
* \param f The function to be transformed
* \return Transformed function.
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
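Backends apply this pass per function before device codegen; a minimal sketch of the call pattern, mirroring its use in build_vulkan.cc later in this commit:

// sketch: rewrite pointer content types, then feed the function to a codegen
LoweredFunc fv = PointerValueTypeRewrite(f);
std::vector<uint32_t> spirv_words = cg.BuildFunction(fv);  // cg: CodeGenSPIRV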
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
......
@@ -55,8 +55,8 @@ typedef int64_t tvm_index_t;
/*! \brief Extension device types in TVM */
typedef enum {
kDLVulkan = 7,
kOpenGL = 11,
// Extension DRAM type, used to quickly test extension devices.
// The device api can differ depending on the xpu driver registered.
kExtDev = 12,
......
@@ -17,7 +17,8 @@ from . import ir_builder
from . import target
from . import ndarray as nd
-from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, opengl, ext_dev
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev
from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function
......
@@ -94,6 +94,7 @@ class TVMContext(ctypes.Structure):
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
7 : 'vulkan',
8 : 'metal',
9 : 'vpi',
10: 'rocm',
@@ -109,6 +110,7 @@ class TVMContext(ctypes.Structure):
'nvptx': 2,
'cl': 4,
'opencl': 4,
'vulkan': 7,
'metal': 8,
'vpi': 9,
'rocm': 10,
......
"""Utility for Interacting with SPIRV Tools"""
import subprocess
import os
from . import util
def optimize(spv_bin):
    """Optimize SPIRV using spirv-opt via CLI

    Note that the spirv-opt is still experimental.

    Parameters
    ----------
    spv_bin : bytearray
        The spirv file

    Returns
    -------
    opt_bin : bytearray
        The optimized SPIRV binary
    """
    tmp_dir = util.tempdir()
    tmp_in = tmp_dir.relpath("input.spv")
    tmp_out = tmp_dir.relpath("output.spv")
    with open(tmp_in, "wb") as out_file:
        out_file.write(bytes(spv_bin))
    sdk = os.environ.get("VULKAN_SDK", None)
    cmd = os.path.join(sdk, "bin/spirv-opt") if sdk else "spirv-opt"
    args = [cmd, "-O", tmp_in, "-o", tmp_out]
    proc = subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        msg = "Optimization error using spirv-opt:\n"
        msg += str(out)
        raise RuntimeError(msg)
    return bytearray(open(tmp_out, "rb").read())
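A minimal usage sketch, assuming spirv-opt is on PATH or VULKAN_SDK is set:

from tvm.contrib import spirv
opt_bin = spirv.optimize(spv_bin)  # spv_bin: bytearray holding a SPIR-V binary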
@@ -120,6 +120,23 @@ def vpi(dev_id=0):
    """
    return TVMContext(9, dev_id)
def vulkan(dev_id=0):
    """Construct a Vulkan device

    Parameters
    ----------
    dev_id : int, optional
        The integer device id

    Returns
    -------
    ctx : TVMContext
        The created context
    """
    return TVMContext(7, dev_id)
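Both forms below construct the same context (device type 7), as exercised by the context-list test further down:

import tvm
ctx = tvm.vulkan(0)
assert ctx == tvm.context("vulkan", 0)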
def opengl(dev_id=0):
    """Construct an OpenGL device
@@ -135,6 +152,7 @@ def opengl(dev_id=0):
    """
    return TVMContext(11, dev_id)

def ext_dev(dev_id=0):
    """Construct an extension device
......
@@ -116,7 +116,7 @@ class Target(object):
# For now assume rocm schedule for opencl
self.keys += ("rocm", "gpu")
self.max_num_threads = 256
-elif target_name in ("metal",):
elif target_name in ("metal", "vulkan"):
self.keys += ("gpu",)
self.max_num_threads = 256
elif target_name in ("opengl",):
......
@@ -666,6 +666,8 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
}
void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
// constraint of current logic
CHECK_EQ(op->base.type(), Int(32));
os << "((int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
......
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_common.h
* \brief Common utility for codegen.
*/
#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_
#define TVM_CODEGEN_CODEGEN_COMMON_H_
#include <tvm/arithmetic.h>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
/*!
* \brief Visit AssertStmt recursively, update align_map from condition.
* \param op The AssertStmt
* \param align_map The alignment map
* \param fvisit The recursive visitor
* \tparam FVisit the recursive visitor
*/
template<typename FVisit>
inline void VisitAssert(
const ir::AssertStmt* op,
std::unordered_map<const Variable*, arith::ModularEntry>* align_map,
FVisit fvisit) {
using namespace ir;
auto& align_map_ = *align_map;
// Detect useful invariant patterns and use them to visit child.
// Pattern: Var % const == 0
// TODO(tqchen) merge these patterns into a generic scope info visitor.
if (const EQ* eq = op->condition.as<EQ>()) {
const Mod* mod = eq->a.as<Mod>();
int64_t factor = 0, offset = 0;
if (mod && arith::GetConst(eq->b, &offset)) {
const Variable *var = mod->a.as<Variable>();
if (var && arith::GetConst(mod->b, &factor)) {
arith::ModularEntry old = align_map_[var];
if (factor > old.coeff) {
arith::ModularEntry e;
e.coeff = static_cast<int>(factor);
e.base = static_cast<int>(offset);
// new alignment info,
align_map_[var] = e;
fvisit(op->body);
// restore old info
align_map_[var] = old;
return;
}
}
}
}
fvisit(op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_COMMON_H_
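A call-site sketch, matching how codegen_llvm.cc below delegates its AssertStmt visitor to this helper:

// inside a codegen visitor that owns an align_map_ member
void VisitStmt_(const ir::AssertStmt* op) {
  VisitAssert(op, &align_map_, [this](const Stmt& body) {
    this->VisitStmt(body);
  });
}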
@@ -9,6 +9,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include "./codegen_llvm.h"
#include "./codegen_cpu.h"
#include "../codegen_common.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"
@@ -341,7 +342,7 @@ void CodeGenLLVM::GetAlignment(Type t,
int align_bits = t.bits();
while (align_bits < max_align_bits &&
me.base % 2 == 0 &&
-me.coeff %2 == 0) {
me.coeff % 2 == 0) {
me.base = me.base / 2;
me.coeff = me.coeff / 2;
align_bits *= 2;
@@ -1026,31 +1027,9 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
}
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
-// Detect useful invariant pattern and use them to visit child.
-// Pattern: Var % const == 0
-// TODO(tqchen) move these pattern to a generic scope info visitor.
-if (const EQ* eq = op->condition.as<EQ>()) {
-  const Mod* mod = eq->a.as<Mod>();
-  int64_t factor = 0, offset = 0;
-  if (mod && arith::GetConst(eq->b, &offset)) {
-    const Variable *var = mod->a.as<Variable>();
-    if (var && arith::GetConst(mod->b, &factor)) {
-      arith::ModularEntry old = align_map_[var];
-      if (factor > old.coeff) {
-        arith::ModularEntry e;
-        e.coeff = static_cast<int>(factor);
-        e.base = static_cast<int>(offset);
-        // new alignment info,
-        align_map_[var] = e;
-        this->VisitStmt(op->body);
-        // restore old info
-        align_map_[var] = old;
-        return;
-      }
-    }
-  }
-}
-this->VisitStmt(op->body);
VisitAssert(op, &align_map_, [this](const Stmt& body) {
  this->VisitStmt(body);
});
}
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
......
/*!
* Copyright (c) 2018 by Contributors
* \file build_vulkan.cc
* \brief Build SPIRV block
*/
#if TVM_VULKAN_RUNTIME
// Use libspirv for parsing and validating code.
#include <vulkan/libspirv.h>
#include <dmlc/memory_io.h>
#include <tvm/ir_pass.h>
#include "./codegen_spirv.h"
#include "../build_common.h"
#include "../../runtime/vulkan/vulkan_module.h"
namespace tvm {
namespace codegen {
class SPIRVTools {
public:
SPIRVTools() {
ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0);
}
~SPIRVTools() {
spvContextDestroy(ctx_);
}
std::string BinaryToText(const std::vector<uint32_t>& bin) {
spv_text text = nullptr;
spv_diagnostic diagnostic;
spv_const_binary_t spv_bin{bin.data(), bin.size()};
spv_result_t res;
res = spvBinaryToText(
ctx_, spv_bin.code, spv_bin.wordCount,
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
SPV_BINARY_TO_TEXT_OPTION_INDENT,
&text, &diagnostic);
CHECK_EQ(res, SPV_SUCCESS)
<< " line=" << diagnostic->position.line
<< " column=" << diagnostic->position.column
<< " index=" << diagnostic->position.index
<< " error:" << diagnostic->error;
std::string ret(text->str);
spvTextDestroy(text);
return ret;
}
private:
spv_context ctx_;
};
runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;
std::ostringstream code_data;
static SPIRVTools spirv_tools;
std::unordered_map<std::string, VulkanShader> smap;
const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
CodeGenSPIRV cg;
for (LoweredFunc f : funcs) {
f = PointerValueTypeRewrite(f);
VulkanShader shader;
shader.data = cg.BuildFunction(f);
if (postproc != nullptr) {
TVMByteArray arr;
arr.data = reinterpret_cast<const char*>(dmlc::BeginPtr(shader.data));
arr.size = shader.data.size() * sizeof(uint32_t);
std::string transformed = (*postproc)(arr);
CHECK_EQ(transformed.length() % 4U, 0U);
shader.data.resize(transformed.size() / 4U);
std::copy(transformed.begin(), transformed.end(),
reinterpret_cast<char*>(dmlc::BeginPtr(shader.data)));
}
code_data << spirv_tools.BinaryToText(shader.data);
smap[f->name] = std::move(shader);
}
return runtime::VulkanModuleCreate(
smap, ExtractFuncInfo(funcs), code_data.str());
}
TVM_REGISTER_API("codegen.build_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildSPIRV(args[0]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
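The tvm_callback_vulkan_postproc hook lets Python post-process each emitted SPIR-V binary; a hypothetical sketch wiring it to the spirv-opt helper added in this commit:

# hypothetical: optimize every Vulkan shader as it is built
import tvm
from tvm.contrib import spirv

@tvm.register_func("tvm_callback_vulkan_postproc")
def postproc(spv_bin):
    return bytes(spirv.optimize(spv_bin))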
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_spirv.h
* \brief SPIRV code generator for lowered functions
*/
#ifndef TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h>
#include <vector>
#include "./ir_builder.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief Code generator into SPIRV
*/
class CodeGenSPIRV:
public ExprFunctor<spirv::Value(const Expr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
* \return The final spirv module.
*/
virtual std::vector<uint32_t> BuildFunction(const LoweredFunc& f);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
* \return created value.
*/
spirv::Value MakeValue(const Expr& e) {
return VisitExpr(e);
}
// override codegen
spirv::Value VisitExpr_(const Variable* op) override;
spirv::Value VisitExpr_(const Cast* op) override;
spirv::Value VisitExpr_(const IntImm* op) override;
spirv::Value VisitExpr_(const UIntImm* op) override;
spirv::Value VisitExpr_(const FloatImm* op) override;
spirv::Value VisitExpr_(const StringImm* op) override;
spirv::Value VisitExpr_(const Add* op) override;
spirv::Value VisitExpr_(const Sub* op) override;
spirv::Value VisitExpr_(const Mul* op) override;
spirv::Value VisitExpr_(const Div* op) override;
spirv::Value VisitExpr_(const Mod* op) override;
spirv::Value VisitExpr_(const Min* op) override;
spirv::Value VisitExpr_(const Max* op) override;
spirv::Value VisitExpr_(const LT* op) override;
spirv::Value VisitExpr_(const LE* op) override;
spirv::Value VisitExpr_(const GT* op) override;
spirv::Value VisitExpr_(const GE* op) override;
spirv::Value VisitExpr_(const EQ* op) override;
spirv::Value VisitExpr_(const NE* op) override;
spirv::Value VisitExpr_(const And* op) override;
spirv::Value VisitExpr_(const Or* op) override;
spirv::Value VisitExpr_(const Not* op) override;
spirv::Value VisitExpr_(const Select* op) override;
spirv::Value VisitExpr_(const Let* op) override;
spirv::Value VisitExpr_(const Call* op) override;
spirv::Value VisitExpr_(const Ramp* op) override;
spirv::Value VisitExpr_(const Broadcast* op) override;
spirv::Value VisitExpr_(const Load* op) override;
// stmt
void VisitStmt_(const Store* op) override;
void VisitStmt_(const For* op) override;
void VisitStmt_(const IfThenElse* op) override;
void VisitStmt_(const Allocate* op) override;
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
protected:
/*! \brief The storage information */
struct StorageInfo {
/*! \brief The storage scope */
runtime::StorageScope scope;
/*! \brief Whether it is volatile */
bool is_volatile{false};
/*! \brief Whether the content type is fixed */
bool content_fixed{false};
/*! \brief Current content type */
Type content_type{Handle()};
// Update content type if it hasn't been updated.
void UpdateContentType(Type type) {
if (content_fixed) {
CHECK_EQ(type, content_type)
<< "Cannot use two different content types in GLSL model";
} else {
this->content_type = type;
content_fixed = true;
}
}
};
// Reset the state so it works for a new function.
void InitFuncState();
// Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
spirv::Value CreateStorageSync(const Call* op);
void Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f);
// The builder
std::unique_ptr<spirv::IRBuilder> builder_;
// Work group size in three dimensions
uint32_t workgroup_size_[3];
// Likely branch
uint32_t weight_likely_branch_{128};
// the storage scope of allocation
std::unordered_map<const Variable*, StorageInfo> storage_info_;
// The definition of local variable.
std::unordered_map<const Variable*, spirv::Value> var_map_;
// The alignment information
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc
*/
#if TVM_VULKAN_RUNTIME
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <vulkan/GLSL.std.450.h>
namespace tvm {
namespace codegen {
namespace spirv {
using namespace runtime;
// num_signature means number of arguments used to query signature
template<unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const ir::Call* call = e.as<ir::Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(ir::UIntImm::make(UInt(32), id));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
} // namespace spirv
} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
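Additional GLSL.std.450 instructions can be mapped with the same dispatcher; for example (hypothetical, not part of this commit):

// hypothetical: route a floor intrinsic through the shared dispatcher
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);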
@@ -59,7 +59,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
return VisitExpr(op->a);
}
bool VisitExpr_(const Let* op) final {
-return VisitExpr(op->body) && VisitExpr(op->value);
return VisitExpr(op->body) || VisitExpr(op->value);
}
bool VisitExpr_(const Cast* op) final {
return VisitExpr(op->value);
@@ -84,7 +84,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
private:
template<typename T>
bool BinaryOp(const T* op) {
-return VisitExpr(op->a) && VisitExpr(op->b);
return VisitExpr(op->a) || VisitExpr(op->b);
}
};
......
@@ -903,20 +903,42 @@ class VectorAllocRewriter : public IRMutator {
return stmt;
}
-private:
void UpdateTypeMap(const Variable* buffer, Type t) {
auto& tvec = acc_map_[buffer];
if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
tvec.push_back(t);
}
}
// Internal access map
-std::unordered_map<const Variable*,
-    std::vector<Type> > acc_map_;
std::unordered_map<const Variable*, std::vector<Type> > acc_map_;
};
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
std::shared_ptr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->());
VectorAllocRewriter rewriter;
n->body = rewriter.Mutate(n->body);
for (Var arg : f->args) {
if (arg.type().is_handle()) {
const auto& tvec = rewriter.acc_map_[arg.get()];
if (tvec.size() == 1) {
Expr dtype = make_const(tvec[0], 0);
n->handle_data_type.Set(arg, dtype);
} else {
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
Expr dtype = make_const(tvec[0].with_lanes(1), 0);
n->handle_data_type.Set(arg, dtype);
}
}
}
}
return LoweredFunc(n);
}
Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(stmt, true);
return VectorAllocRewriter().Mutate(stmt);
......
@@ -28,6 +28,7 @@ inline std::string DeviceName(int type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
......
@@ -119,6 +119,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "device_api.opengl";
} else if (target == "mtl" || target == "metal") {
f_name = "device_api.metal";
} else if (target == "vulkan") {
f_name = "device_api.vulkan";
} else if (target == "stackvm") {
f_name = "codegen.build_stackvm";
} else if (target == "rpc") {
......
@@ -44,12 +44,13 @@ class ROCMDeviceAPI final : public DeviceAPI {
value = 64;
break;
}
-case kComputeVersion:
case kComputeVersion: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
*rv = prop.gcnArch;
return;
}
}
*rv = value;
}
void* AllocDataSpace(TVMContext ctx,
......
/*!
* Copyright (c) 2017 by Contributors
* \file vulkan_common.h
* \brief Vulkan common header
*/
#ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
#define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#if TVM_VULKAN_RUNTIME
#include <vulkan/vulkan.h>
#include <mutex>
#include <string>
#include <vector>
#include "../workspace_pool.h"
namespace tvm {
namespace runtime {
namespace vulkan {
inline const char* VKGetErrorString(VkResult error) {
switch (error) {
case VK_SUCCESS: return "VK_SUCCESS";
case VK_NOT_READY: return "VK_NOT_READY";
case VK_TIMEOUT: return "VK_TIMEOUT";
case VK_EVENT_SET: return "VK_EVENT_SET";
case VK_EVENT_RESET: return "VK_EVENT_RESET";
case VK_INCOMPLETE: return "VK_INCOMPLETE";
case VK_ERROR_OUT_OF_HOST_MEMORY: return "VK_ERROR_OUT_OF_HOST_MEMORY";
case VK_ERROR_OUT_OF_DEVICE_MEMORY: return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
case VK_ERROR_INITIALIZATION_FAILED: return "VK_ERROR_INITIALIZATION_FAILED";
case VK_ERROR_DEVICE_LOST: return "VK_ERROR_DEVICE_LOST";
case VK_ERROR_MEMORY_MAP_FAILED: return "VK_ERROR_MEMORY_MAP_FAILED";
case VK_ERROR_LAYER_NOT_PRESENT: return "VK_ERROR_LAYER_NOT_PRESENT";
case VK_ERROR_EXTENSION_NOT_PRESENT: return "VK_ERROR_EXTENSION_NOT_PRESENT";
case VK_ERROR_FEATURE_NOT_PRESENT: return "VK_ERROR_FEATURE_NOT_PRESENT";
case VK_ERROR_INCOMPATIBLE_DRIVER: return "VK_ERROR_INCOMPATIBLE_DRIVER";
case VK_ERROR_TOO_MANY_OBJECTS: return "VK_ERROR_TOO_MANY_OBJECTS";
case VK_ERROR_FORMAT_NOT_SUPPORTED: return "VK_ERROR_FORMAT_NOT_SUPPORTED";
case VK_ERROR_FRAGMENTED_POOL: return "VK_ERROR_FRAGMENTED_POOL";
default: return "Unknown Vulkan error code";
}
}
/*!
* \brief Protected Vulkan call
* \param func Expression to call.
*/
#define VULKAN_CHECK_ERROR(__e) \
{ \
CHECK(__e == VK_SUCCESS) \
<< "Vulkan Error, code=" << __e << ": " << vulkan::VKGetErrorString(__e); \
}
#define VULKAN_CALL(func) \
{ \
VkResult __e = (func); \
VULKAN_CHECK_ERROR(__e); \
}
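Usage sketch: wrap any Vulkan API call that returns a VkResult so a failure aborts with a decoded error string:

// sketch, assuming a valid VkDevice and VkBufferCreateInfo
VkBuffer buffer;
VULKAN_CALL(vkCreateBuffer(device, &buffer_create_info, nullptr, &buffer));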
/*! \brief Auxiliary context structure for vulkan */
struct VulkanContext {
// physical device
VkPhysicalDevice phy_device{nullptr};
// Physical device properties
VkPhysicalDeviceProperties phy_device_prop;
// Memory type index for staging.
uint32_t staging_mtype_index{0};
// whether staging is coherent
bool coherent_staging{false};
// Memory type index for compute
uint32_t compute_mtype_index{0};
// The logical device
VkDevice device{nullptr};
// command queue
VkQueue queue{nullptr};
// queue family index
uint32_t queue_family_index{0};
// Queue family properties.
VkQueueFamilyProperties queue_prop;
};
/*! \brief The buffer object */
struct VulkanBuffer {
/*! \brief underlying buffer */
VkBuffer buffer{nullptr};
/*! \brief underlying device memory */
VkDeviceMemory memory{nullptr};
};
/*! \brief Buffer only used for staging */
struct VulkanStagingBuffer {
/*! \brief the corresponding device */
VkDevice device{nullptr};
/*! \brief underlying buffer */
VkBuffer buffer{nullptr};
/*! \brief underlying device memory */
VkDeviceMemory memory{nullptr};
/*! \brief host address */
void* host_addr{nullptr};
/*! \brief size of the memory */
size_t size{0};
};
/*!
* \brief Process global Vulkan workspace.
*/
class VulkanWorkspace final : public DeviceAPI {
public:
// global mutex
std::mutex mu;
// whether the workspace is initialized.
bool initialized_{false};
// vulkan instance
VkInstance instance_{nullptr};
// Per-device Vulkan contexts, with a 1-to-1 mapping to devices
std::vector<VulkanContext> context_;
// Destructor
~VulkanWorkspace();
// Initialize the workspace; no-op if it has already been initialized.
void Init();
// override device API
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from,
size_t from_size,
void* to,
size_t to_size,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace
static const std::shared_ptr<VulkanWorkspace>& Global();
};
/*! \brief Helper command buffer resource */
struct VulkanCommandBuffer {
/*! \brief fence to signal the resource is ready to use */
VkFence fence{nullptr};
/*! \brief The internal command buffer */
VkCommandBuffer cmd_buffer{nullptr};
/*! \brief Descriptor set used to bind arguments */
VkDescriptorSet descriptor_set{nullptr};
/*! \brief Internal utilities for write command */
VkWriteDescriptorSet write_descriptor_set;
VulkanCommandBuffer() {
write_descriptor_set.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write_descriptor_set.pNext = nullptr;
write_descriptor_set.dstSet = nullptr;
write_descriptor_set.dstBinding = 0;
write_descriptor_set.dstArrayElement = 0;
write_descriptor_set.descriptorCount = 1;
write_descriptor_set.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
write_descriptor_set.pImageInfo = nullptr;
write_descriptor_set.pBufferInfo = nullptr;
write_descriptor_set.pTexelBufferView = nullptr;
}
};
/*!
* \brief Command pool backed by a fixed size ring buffer.
*
* Vulkan requires us not to reuse a command buffer until
* all its corresponding jobs have finished.
*
* This class facilitates automatic management
* of the command buffers. A fence is created
* for each launch of command buffer jobs
* and when we try to reuse the same entry
* in the ring, we need to make sure that
* the previous pending job has already finished.
*
*/
class VulkanCommandPool {
public:
/*! \brief Maximum number of pending jobs in the pool */
static constexpr const int kMaxPending = 4;
/*! \brief Maximum number of arguments per kernel */
static constexpr const int kMaxNumArgs = 16;
/*!
* \brief constructor
* \param vctx The corresponding vulkan context.
*/
explicit VulkanCommandPool(const VulkanContext& vctx);
/*! \brief destructor */
~VulkanCommandPool();
/*!
* \brief Allocate a new command buffer entry
*
* The caller must only submit the entry once
* with the given fence in the entry,
* before calling next Alloc.
*
* This function may block to wait for a
* previously unfinished command when
* there is more than kMaxPending jobs.
*
* \returns The allocated entry.
*/
VulkanCommandBuffer* Alloc();
/*!
* \brief Allocate a new command buffer entry
* \param dlayout the descriptor layout.
*
* \returns The allocated entry.
*/
VulkanCommandBuffer* Alloc(const VkDescriptorSetLayout* dlayout);
private:
/*! \brief Local ring buffer */
std::vector<VulkanCommandBuffer> ring_;
/*! \brief clock pointer */
size_t clock_ptr_{0};
/*! \brief the corresponding device*/
VkDevice device_{nullptr};
/*! \brief internal command buffer pool */
VkCommandPool cmd_pool_{nullptr};
/*! \brief Descriptor pool */
VkDescriptorPool descriptor_pool_{nullptr};
};
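A sketch of the intended allocate/submit cycle (names illustrative, error handling elided):

// hypothetical: one dispatch through the ring-buffer pool
VulkanCommandBuffer* entry = pool.Alloc(&descriptor_set_layout);
// ... record the dispatch into entry->cmd_buffer ...
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &submit_info, entry->fence));
// a later Alloc() blocks on the oldest fence once kMaxPending jobs are in flight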
/*! \brief Thread local workspace */
class VulkanThreadEntry {
public:
/*! \brief The current context */
TVMContext context;
/*! \brief workspace pool */
WorkspacePool pool;
/*! \brief The staging buffers */
std::vector<VulkanStagingBuffer> staging_buffer_;
/*!
* \brief Get the command pool of the corresponding device.
* \param device_id The device id
* \return The corresponding command pool.
*/
VulkanCommandPool* CommandPool(int device_id);
/*!
* \brief Get the staging buffer.
* \param device_id The device id
* \return The corresponding staging buffer.
*/
VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
// constructor
VulkanThreadEntry()
: pool(static_cast<DLDeviceType>(kDLVulkan), VulkanWorkspace::Global()) {
context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kDLVulkan);
}
~VulkanThreadEntry();
// get the global workspace
static VulkanThreadEntry* ThreadLocal();
private:
/*! \brief the command pools */
std::vector<std::unique_ptr<VulkanCommandPool> > pool_;
};
// inline implementation
} // namespace vulkan
} // namespace runtime
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
#endif // TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
/*!
* Copyright (c) 2017 by Contributors
* \file vulkan_module.h
* \brief Execution handling of Vulkan kernels
*/
#ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/type_traits.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "../meta_data.h"
namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPUs supported in VulkanModule. */
static constexpr const int kVulkanMaxNumDevice = 8;
/*! \brief TVM Vulkan binary pack magic number */
static constexpr const int kVulkanModuleMagic = 0x02700027;
/*!
* \brief A single VK shader program
*
* Due to the global resource declaration, current SPIRV
* only allows one entry program per shader, making it less
* useful for a module-like system.
*
* Instead we pass in a map of str->VulkanShader until
* there is a native solution available.
*/
struct VulkanShader {
/*! \brief header flag */
uint32_t flag{0};
/*! \brief Data segment */
std::vector<uint32_t> data;
void Save(dmlc::Stream *writer) const;
bool Load(dmlc::Stream *reader);
};
/*!
* \brief Create a Vulkan module from data.
*
* \param smap The map from function name to shader.
* \param fmap The function information map.
* \param source Optional, source code.
*/
Module VulkanModuleCreate(
std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source);
} // namespace runtime
} // namespace tvm
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::VulkanShader, true);
} // namespace dmlc
#endif // TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
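A sketch of module assembly as performed by BuildSPIRV in build_vulkan.cc above (the shader name is illustrative):

// hypothetical: package one compiled shader into a runtime module
std::unordered_map<std::string, VulkanShader> smap = {{"myadd_kernel0", shader}};
runtime::Module mod = runtime::VulkanModuleCreate(smap, fmap, /*source=*/"");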
@@ -18,7 +18,8 @@ def test_exp():
def check_device(device, host="stackvm"):
if not tvm.module.enabled(host):
return
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
fexp = tvm.build(s, [A, B],
device, host,
@@ -33,6 +34,7 @@ def test_exp():
b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
check_device("cuda", "llvm")
check_device("vulkan")
check_device("opencl")
@@ -75,11 +77,12 @@ def test_popcount():
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
def check_device(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
-ctx = tvm.context(device, 0)
target = tvm.target.create(device)
-if str(ctx).startswith('gpu'):
if "cpu" not in target.keys:
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
func = tvm.build(s, [A, B], device)
@@ -95,6 +98,8 @@ def test_popcount():
check_device("cuda")
check_device("opencl")
check_device("metal")
if dtype == "uint32":
    check_device("vulkan")
run('uint32')
run('uint64')
@@ -121,14 +126,14 @@ def test_add():
# one line to build the function.
def check_device(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
fadd = tvm.build(s, [A, B, C],
device,
name="myadd")
-print(fadd.imported_modules[0].get_source())
-ctx = tvm.context(device, 0)
# launch the kernel.
n = 1024
a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx)
@@ -142,6 +147,8 @@ def test_add():
check_device("opencl")
check_device("metal")
check_device("cuda")
check_device("vulkan")
run("float32")
run("int32")
run("int64")
@@ -149,7 +156,7 @@ def test_add():
if __name__ == "__main__":
-test_add()
test_log_pow_llvm()
test_exp()
test_add()
test_popcount()
@@ -2,6 +2,7 @@ import tvm
import numpy as np
import time

def test_gemm():
# graph
nn = 1024
@@ -64,13 +65,14 @@ def test_gemm():
# one line to build the function.
def check_device(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
with tvm.target.create(device):
f = tvm.build(s, [A, B, C])
-ctx = tvm.context(device, 0)
# launch the kernel.
n = nn
m = n
@@ -86,12 +88,12 @@ def test_gemm():
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("vulkan")
check_device("nvptx -mcpu=sm_20")
check_device("rocm")
check_device("metal")
check_device("opencl")
check_device("cuda")
-#check_device("nvptx -mcpu=sm_20")
if __name__ == "__main__":
test_gemm()
import tvm
import numpy as np

def test_reduce_prims():
def test_prim(reducer, np_reducer):
# graph
@@ -21,12 +22,12 @@ def test_reduce_prims():
# one line to build the function.
def check_device(device, host="stackvm"):
ctx = tvm.context(device, 0)
if not tvm.module.enabled(host):
return
-if not tvm.module.enabled(device):
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
-ctx = tvm.context(device, 0)
freduce = tvm.build(s,
args=[A, B],
target=device, target_host=host,
@@ -44,6 +45,7 @@ def test_reduce_prims():
np.testing.assert_allclose(npy, res, rtol=1e-4)
check_device("metal")
check_device("vulkan")
check_device("cuda")
check_device("opencl")
test_prim(tvm.sum, np.sum)
@@ -106,10 +108,11 @@ def test_rfactor_threads():
# one line to build the function.
def check_target(device, host="stackvm"):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
-ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi,
target=device,
@@ -125,6 +128,7 @@ def test_rfactor_threads():
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
check_target("vulkan")
check_target("cuda")
check_target("metal")
check_target("opencl")
@@ -159,15 +163,14 @@ def test_rfactor_elemwise_threads():
# one line to build the function.
def check_target(device, host="stackvm"):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
-ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A, C])
fsum = tvm.build(fapi,
target=device,
name="mysum")
-print(fsum.imported_modules[0].get_source())
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
@@ -176,6 +179,7 @@ def test_rfactor_elemwise_threads():
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
check_target("vulkan")
check_target("cuda")
check_target("metal")
check_target("opencl")
@@ -264,10 +268,10 @@ def test_rfactor_argmax():
s[B0].set_store_predicate(thread_x.var.equal(0))
def check_target(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
-ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A0, A1, B0, B1])
fargmax = tvm.build(fapi,
target=device,
@@ -285,6 +289,7 @@ def test_rfactor_argmax():
np.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target("cuda")
check_target("vulkan")
if __name__ == "__main__":
test_rfactor_elemwise_threads()
......
@@ -24,13 +24,13 @@ def test_scan():
# one line to build the function.
def check_device(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
fscan = tvm.build(s, [X, res],
device,
name="myscan")
-ctx = tvm.context(device, 0)
# launch the kernel.
n = 1024
m = 10
@@ -41,6 +41,7 @@ def test_scan():
np.testing.assert_allclose(
b.asnumpy(), np.cumsum(a_np, axis=0))
check_device("vulkan")
check_device("cuda")
check_device("metal")
check_device("opencl")
......
@@ -13,12 +13,12 @@ def test_add_pipeline():
# GPU schedule have to split by gridIdx and threadIdx
num_thread = 256
xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
-s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
-s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
s[C].bind(xi, tvm.thread_axis("threadIdx.x"))
s[C].bind(xo, tvm.thread_axis("blockIdx.x"))
xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
-s[D].bind(xo, tvm.thread_axis("threadIdx.x"))
-s[D].bind(xi, tvm.thread_axis("blockIdx.x"))
s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
s[D].bind(xo, tvm.thread_axis("blockIdx.x"))
# compile to IR
s = s.normalize()
@@ -35,11 +35,11 @@ def test_add_pipeline():
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
def check_target(device, host="stackvm"):
-if not tvm.module.enabled(host):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
-if not tvm.module.enabled(device):
if not tvm.module.enabled(host):
return
-ctx = tvm.context(device, 0)
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)
mhost.import_module(mdev)
@@ -55,12 +55,12 @@ def test_add_pipeline():
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
def check_module_save(device, host="stackvm"):
-if not tvm.module.enabled(host):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
-if not tvm.module.enabled(device):
if not tvm.module.enabled(host):
return
-ctx = tvm.context(device, 0)
-fmt = "ptx" if device == "cuda" else "cl"
fmt = "ptx" if device == "cuda" else device
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)
temp = util.tempdir()
@@ -82,7 +82,9 @@ def test_add_pipeline():
check_target("cuda", host="llvm")
check_module_save("cuda", host="stackvm")
check_target("nvptx", host="llvm")
check_target("vulkan", host="llvm")
check_target("rocm", host="llvm")
check_module_save("vulkan", host="stackvm")
if __name__ == "__main__":
......
@@ -110,6 +110,7 @@ def test_device_module_dump():
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_device("cuda")
check_device("vulkan")
check_device("opencl")
check_device("metal")
......
@@ -7,6 +7,7 @@ def enabled_ctx_list():
('cl', tvm.opencl(0)),
('metal', tvm.metal(0)),
('rocm', tvm.rocm(0)),
('vulkan', tvm.vulkan(0)),
('vpi', tvm.vpi(0))]
for k, v in ctx_list:
assert tvm.context(k, 0) == v
......
@@ -2,6 +2,7 @@
import tvm
import os
from tvm.contrib import nvcc
from tvm.contrib import spirv
import numpy as np
TASK="gemm"
@@ -25,6 +26,7 @@ def tvm_callback_cuda_postproc(code):
code = open("perf/%s_manual.cu" % TASK).read()
return code

def test_gemm():
# graph
nn = 2048
@@ -101,12 +103,12 @@ def test_gemm():
s[BB].double_buffer()
# correctness
def check_device(device):
-print("Device %s" % device)
ctx = tvm.context(device, 0)
-if not tvm.module.enabled(device):
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Device %s" % device)
f = tvm.build(s, [A, B, C], device)
-ctx = tvm.context(device, 0)
# launch the kernel.
n, m, l = nn, nn, nn
a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
@@ -126,7 +128,7 @@ def test_gemm():
GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
-for device in ["cuda", "opencl", "rocm", "nvptx"]:
for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
with tvm.build_config(auto_unroll_max_step=128,
unroll_explicit=(device != "cuda")):
check_device(device)
......
@@ -9,13 +9,13 @@ def verify_broadcast_to_ele(in_shape, out_shape):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.broadcast_to(A, out_shape)
def check_device(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
-ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
@@ -25,6 +25,7 @@ def verify_broadcast_to_ele(in_shape, out_shape):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("vulkan")
check_device("opencl")
check_device("cuda")
check_device("metal")
@@ -50,13 +51,13 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
else:
raise NotImplementedError
def check_device(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(C)
-ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
@@ -82,6 +83,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("vulkan")
check_device("opencl")
check_device("cuda")
check_device("metal")
@@ -105,5 +107,5 @@ def test_broadcast_binary():
if __name__ == "__main__":
-test_broadcast_to()
test_broadcast_binary()
test_broadcast_to()
@@ -31,11 +31,11 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
-if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
-ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
@@ -49,7 +49,7 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
......
@@ -29,14 +29,14 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
     a_np, w_np, b_np, c_np = get_ref_data()
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s1 = topi.generic.schedule_conv2d_nchw([B])
             s2 = topi.generic.schedule_conv2d_nchw([C])
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
@@ -50,7 +50,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
         check_device(device)
...
@@ -29,14 +29,14 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
     a_np, w_np, b_np, c_np = get_ref_data()
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
             s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
@@ -50,7 +50,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
         check_device(device)
...
@@ -29,13 +29,13 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
     a_np, b_np, c_np, d_np = get_ref_data()
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_dense(D)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
@@ -44,7 +44,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
         f(a, b, c, d)
         np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
         check_device(device)

 def test_dense():
...
@@ -23,7 +23,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -32,7 +33,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
             s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
             s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
             s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
-        ctx = tvm.context(device, 0)
         # build the kernels
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
@@ -90,6 +90,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     check_device("cuda")
     check_device("metal")
     check_device("rocm")
+    check_device("vulkan")

 def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
     in_width = in_height
@@ -108,7 +110,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     # schedule
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
@@ -117,7 +120,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
             s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
             s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
             s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
-        ctx = tvm.context(device, 0)
         # build the kernels
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
@@ -177,6 +179,7 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     check_device("cuda")
     check_device("metal")
     check_device("rocm")
+    check_device("vulkan")

 def test_depthwise_conv2d():
     print("testing nchw")
...
@@ -32,11 +32,11 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
     schedule = schedule_depthwise_conv2d_backward_input_nhwc(In_grad)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
-        ctx = tvm.context(device, 0)
         # build the kernel
         f = tvm.build(schedule, [Filter, Out_grad, In_grad], device)
         # prepare pod type for test data closure
@@ -85,6 +85,7 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
     check_device("cuda")
     check_device("metal")
     check_device("rocm")
+    check_device("vulkan")

 def test_topi_depthwise_conv2d_backward_input_nhwc():
     verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)
...
@@ -32,11 +32,11 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
     schedule = schedule_depthwise_conv2d_backward_weight_nhwc(Weight_grad)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
-        ctx = tvm.context(device, 0)
         # build the kernel
         f = tvm.build(schedule, [Input, Out_grad, Weight_grad], device)
         # prepare pod type for test data closure
@@ -78,6 +78,7 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
     check_device("cuda")
     check_device("metal")
     check_device("rocm")
+    check_device("vulkan")

 def test_topi_depthwise_conv2d_backward_weight_nhwc():
     verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
...
@@ -44,20 +44,21 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
     b_np = np.maximum(b_np, 0.0)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_pool(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
         check_device(device)

 def test_pool():
@@ -82,20 +83,20 @@ def verify_global_pool(n, c, h, w, pool_type):
     b_np = np.maximum(b_np, 0.0)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_global_pool(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         f = tvm.build(s, [A, B], device)
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
         check_device(device)

 def test_global_pool():
...
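Because the same skip logic is repeated in every test file above, it can be captured in a small helper. This is a hypothetical refactor for illustration only, not part of the patch; `get_ctx_or_skip` is an invented name:

```python
import tvm

def get_ctx_or_skip(device):
    """Return a TVMContext for `device`, or None when the test should skip.

    Hypothetical helper showing the idiom this patch uses everywhere:
    create the context first, and let ctx.exist gate all further work
    (schedule construction, kernel builds, array allocation).
    """
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return None
    print("Running on target: %s" % device)
    return ctx
```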
@@ -47,13 +47,14 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         raise NotImplementedError
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_reduce(B)
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name=type)
         # Test
         in_npy = np.random.uniform(size=in_shape).astype(np.float32)
@@ -90,7 +91,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
             np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
         else:
             np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
-    for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
+    for device in ["cuda", "opencl", "metal", "llvm", "rocm", "vulkan"]:
         check_device(device)
...
@@ -13,20 +13,21 @@ def verify_relu(m, n):
     b_np = a_np * (a_np > 0)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_elemwise(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         foo = tvm.build(s, [A, B], device, name="relu")
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
         check_device(device)
...
@@ -17,20 +17,21 @@ def verify_softmax(m, n):
     b_np = topi.testing.softmax_python(a_np)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_softmax(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         foo = tvm.build(s, [A, B], device, name="softmax")
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
         check_device(device)

 def test_softmax():
@@ -48,20 +49,20 @@ def verify_log_softmax(m, n):
     b_np = topi.testing.log_softmax_python(a_np)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_softmax(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         foo = tvm.build(s, [A, B], device, name="log_softmax")
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ["cuda", "opencl", "metal", "rocm"]:
+    for device in ["cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
...
@@ -7,13 +7,13 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.expand_dims(A, axis, num_newaxis)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_broadcast(B)
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="expand_dims")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = data_npy.reshape(out_shape)
@@ -22,7 +22,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
@@ -30,13 +30,13 @@ def verify_tranpose(in_shape, axes):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.transpose(A, axes)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_injective(B)
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="tranpose")
         data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
         out_npy = data_npy.transpose(axes)
@@ -45,7 +45,7 @@ def verify_tranpose(in_shape, axes):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
@@ -53,13 +53,13 @@ def verify_reshape(src_shape, dst_shape):
     A = tvm.placeholder(shape=src_shape, name="A")
     B = topi.reshape(A, dst_shape)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_injective(B)
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="reshape")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.reshape(data_npy, newshape=dst_shape)
@@ -68,7 +68,7 @@ def verify_reshape(src_shape, dst_shape):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
@@ -76,13 +76,14 @@ def verify_squeeze(src_shape, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
     B = topi.squeeze(A, axis=axis)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_injective(B)
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="squeeze")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.squeeze(data_npy, axis=axis)
@@ -95,7 +96,7 @@ def verify_squeeze(src_shape, axis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)

 def verify_concatenate(shapes, axis):
@@ -104,13 +105,14 @@ def verify_concatenate(shapes, axis):
         tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
     out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_injective(out_tensor)
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
         data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
         out_npy = np.concatenate(data_npys, axis=axis)
@@ -119,7 +121,7 @@ def verify_concatenate(shapes, axis):
         foo(*(data_nds + [out_nd]))
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
@@ -127,13 +129,14 @@ def verify_split(src_shape, indices_or_sections, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
     tensor_l = topi.split(A, indices_or_sections, axis=axis)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_injective(tensor_l)
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A] + tensor_l, device, name="split")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npys = np.split(data_npy, indices_or_sections, axis=axis)
@@ -143,7 +146,7 @@ def verify_split(src_shape, indices_or_sections, axis):
         for out_nd, out_npy in zip(out_nds, out_npys):
             np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
         check_device(device)
...
@@ -14,13 +14,13 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):
     b_np = topi.testing.upsampling_python(a_np, scale)
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_injective(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
@@ -28,7 +28,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['llvm', 'cuda']:
+    for device in ['llvm', 'cuda', 'vulkan']:
         check_device(device)

 def test_upsampling():
...
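Taken together, the pattern these tests follow reduces to a self-contained sketch: build a trivial kernel for each GPU target, skip targets whose device is absent, and validate against NumPy. The op, names, and shapes below are illustrative, not taken from the patch:

```python
import numpy as np
import tvm

def verify_add_one(n=1024):
    # Declare a trivial elementwise op: B[i] = A[i] + 1.
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    # GPU targets require the iteration axes to be bound to the
    # block/thread hierarchy before tvm.build will accept the schedule.
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    for device in ["cuda", "opencl", "metal", "rocm", "vulkan"]:
        # The device-presence check introduced by this patch.
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            continue
        print("Running on target: %s" % device)
        f = tvm.build(s, [A, B], device, name="add_one")
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0, rtol=1e-5)

if __name__ == "__main__":
    verify_add_one()
```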